1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/DLTI/DLTI.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/Reducer/ReductionPatternInterface.h"
22 #include "mlir/Transforms/FoldUtils.h"
23 #include "mlir/Transforms/InliningUtils.h"
24 #include "llvm/ADT/StringSwitch.h"
25 
26 // Include this before the using namespace lines below to
27 // test that we don't have namespace dependencies.
28 #include "TestOpsDialect.cpp.inc"
29 
30 using namespace mlir;
31 using namespace test;
32 
/// Inserts `TestDialect` into the given dialect registry.
void test::registerTestDialect(DialectRegistry &registry) {
  registry.insert<TestDialect>();
}
36 
37 //===----------------------------------------------------------------------===//
38 // TestDialect Interfaces
39 //===----------------------------------------------------------------------===//
40 
41 namespace {
42 
/// Testing the correctness of some traits: these compile-time checks verify
/// that both implicit-terminator detection helpers recognize
/// `SingleBlockImplicitTerminatorOp`.
static_assert(
    llvm::is_detected<OpTrait::has_implicit_terminator_t,
                      SingleBlockImplicitTerminatorOp>::value,
    "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
static_assert(OpTrait::hasSingleBlockImplicitTerminator<
                  SingleBlockImplicitTerminatorOp>::value,
              "hasSingleBlockImplicitTerminator does not match "
              "SingleBlockImplicitTerminatorOp");
52 
53 // Test support for interacting with the AsmPrinter.
54 struct TestOpAsmInterface : public OpAsmDialectInterface {
55   using OpAsmDialectInterface::OpAsmDialectInterface;
56 
57   AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
58     StringAttr strAttr = attr.dyn_cast<StringAttr>();
59     if (!strAttr)
60       return AliasResult::NoAlias;
61 
62     // Check the contents of the string attribute to see what the test alias
63     // should be named.
64     Optional<StringRef> aliasName =
65         StringSwitch<Optional<StringRef>>(strAttr.getValue())
66             .Case("alias_test:dot_in_name", StringRef("test.alias"))
67             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
68             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
69             .Case("alias_test:sanitize_conflict_a",
70                   StringRef("test_alias_conflict0"))
71             .Case("alias_test:sanitize_conflict_b",
72                   StringRef("test_alias_conflict0_"))
73             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
74             .Default(llvm::None);
75     if (!aliasName)
76       return AliasResult::NoAlias;
77 
78     os << *aliasName;
79     return AliasResult::FinalAlias;
80   }
81 
82   AliasResult getAlias(Type type, raw_ostream &os) const final {
83     if (auto tupleType = type.dyn_cast<TupleType>()) {
84       if (tupleType.size() > 0 &&
85           llvm::all_of(tupleType.getTypes(), [](Type elemType) {
86             return elemType.isa<SimpleAType>();
87           })) {
88         os << "test_tuple";
89         return AliasResult::FinalAlias;
90       }
91     }
92     if (auto intType = type.dyn_cast<TestIntegerType>()) {
93       if (intType.getSignedness() ==
94               TestIntegerType::SignednessSemantics::Unsigned &&
95           intType.getWidth() == 8) {
96         os << "test_ui8";
97         return AliasResult::FinalAlias;
98       }
99     }
100     return AliasResult::NoAlias;
101   }
102 
103   void getAsmResultNames(Operation *op,
104                          OpAsmSetValueNameFn setNameFn) const final {
105     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
106       setNameFn(asmOp, "result");
107   }
108 };
109 
110 struct TestDialectFoldInterface : public DialectFoldInterface {
111   using DialectFoldInterface::DialectFoldInterface;
112 
113   /// Registered hook to check if the given region, which is attached to an
114   /// operation that is *not* isolated from above, should be used when
115   /// materializing constants.
116   bool shouldMaterializeInto(Region *region) const final {
117     // If this is a one region operation, then insert into it.
118     return isa<OneRegionOp>(region->getParentOp());
119   }
120 };
121 
122 /// This class defines the interface for handling inlining with standard
123 /// operations.
124 struct TestInlinerInterface : public DialectInlinerInterface {
125   using DialectInlinerInterface::DialectInlinerInterface;
126 
127   //===--------------------------------------------------------------------===//
128   // Analysis Hooks
129   //===--------------------------------------------------------------------===//
130 
131   bool isLegalToInline(Operation *call, Operation *callable,
132                        bool wouldBeCloned) const final {
133     // Don't allow inlining calls that are marked `noinline`.
134     return !call->hasAttr("noinline");
135   }
136   bool isLegalToInline(Region *, Region *, bool,
137                        BlockAndValueMapping &) const final {
138     // Inlining into test dialect regions is legal.
139     return true;
140   }
141   bool isLegalToInline(Operation *, Region *, bool,
142                        BlockAndValueMapping &) const final {
143     return true;
144   }
145 
146   bool shouldAnalyzeRecursively(Operation *op) const final {
147     // Analyze recursively if this is not a functional region operation, it
148     // froms a separate functional scope.
149     return !isa<FunctionalRegionOp>(op);
150   }
151 
152   //===--------------------------------------------------------------------===//
153   // Transformation Hooks
154   //===--------------------------------------------------------------------===//
155 
156   /// Handle the given inlined terminator by replacing it with a new operation
157   /// as necessary.
158   void handleTerminator(Operation *op,
159                         ArrayRef<Value> valuesToRepl) const final {
160     // Only handle "test.return" here.
161     auto returnOp = dyn_cast<TestReturnOp>(op);
162     if (!returnOp)
163       return;
164 
165     // Replace the values directly with the return operands.
166     assert(returnOp.getNumOperands() == valuesToRepl.size());
167     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
168       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
169   }
170 
171   /// Attempt to materialize a conversion for a type mismatch between a call
172   /// from this dialect, and a callable region. This method should generate an
173   /// operation that takes 'input' as the only operand, and produces a single
174   /// result of 'resultType'. If a conversion can not be generated, nullptr
175   /// should be returned.
176   Operation *materializeCallConversion(OpBuilder &builder, Value input,
177                                        Type resultType,
178                                        Location conversionLoc) const final {
179     // Only allow conversion for i16/i32 types.
180     if (!(resultType.isSignlessInteger(16) ||
181           resultType.isSignlessInteger(32)) ||
182         !(input.getType().isSignlessInteger(16) ||
183           input.getType().isSignlessInteger(32)))
184       return nullptr;
185     return builder.create<TestCastOp>(conversionLoc, resultType, input);
186   }
187 
188   void processInlinedCallBlocks(
189       Operation *call,
190       iterator_range<Region::iterator> inlinedBlocks) const final {
191     if (!isa<ConversionCallOp>(call))
192       return;
193 
194     // Set attributed on all ops in the inlined blocks.
195     for (Block &block : inlinedBlocks) {
196       block.walk([&](Operation *op) {
197         op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
198       });
199     }
200   }
201 };
202 
203 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
204 public:
205   TestReductionPatternInterface(Dialect *dialect)
206       : DialectReductionPatternInterface(dialect) {}
207 
208   void populateReductionPatterns(RewritePatternSet &patterns) const final {
209     populateTestReductionPatterns(patterns);
210   }
211 };
212 
213 } // namespace
214 
215 //===----------------------------------------------------------------------===//
216 // TestDialect
217 //===----------------------------------------------------------------------===//
218 
/// Forward declaration of the effect-collection helper that
/// `TestOpEffectInterfaceFallback::getEffects` delegates to.
static void testSideEffectOpGetEffect(
    Operation *op,
    SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
222 
223 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
224 struct TestOpEffectInterfaceFallback
225     : public TestEffectOpInterface::FallbackModel<
226           TestOpEffectInterfaceFallback> {
227   static bool classof(Operation *op) {
228     bool isSupportedOp =
229         op->getName().getStringRef() == "test.unregistered_side_effect_op";
230     assert(isSupportedOp && "Unexpected dispatch");
231     return isSupportedOp;
232   }
233 
234   void
235   getEffects(Operation *op,
236              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
237                  &effects) const {
238     testSideEffectOpGetEffect(op, effects);
239   }
240 };
241 
/// Registers the dialect's attributes, types, operations, and interfaces, and
/// allocates the fallback effect interface object that the destructor frees.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  // The operation list is expanded from the generated "TestOps.cpp.inc" file.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface, TestReductionPatternInterface>();
  allowUnknownOperations();

  // Instantiate our fallback op interface that we'll use on specific
  // unregistered op.
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
257 TestDialect::~TestDialect() {
258   delete static_cast<TestOpEffectInterfaceFallback *>(
259       fallbackEffectOpInterfaces);
260 }
261 
/// Materializes a constant by creating a `TestOpConstant` at `loc` with the
/// given result type and attribute value.
Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  return builder.create<TestOpConstant>(loc, type, value);
}
266 
267 ::mlir::LogicalResult FormatInferType2Op::inferReturnTypes(
268     ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location,
269     ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
270     ::mlir::RegionRange regions,
271     ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
272   inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
273   return ::mlir::success();
274 }
275 
276 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
277                                                OperationName opName) {
278   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
279       typeID == TypeID::get<TestEffectOpInterface>())
280     return fallbackEffectOpInterfaces;
281   return nullptr;
282 }
283 
284 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
285                                                     NamedAttribute namedAttr) {
286   if (namedAttr.getName() == "test.invalid_attr")
287     return op->emitError() << "invalid to use 'test.invalid_attr'";
288   return success();
289 }
290 
291 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
292                                                     unsigned regionIndex,
293                                                     unsigned argIndex,
294                                                     NamedAttribute namedAttr) {
295   if (namedAttr.getName() == "test.invalid_attr")
296     return op->emitError() << "invalid to use 'test.invalid_attr'";
297   return success();
298 }
299 
300 LogicalResult
301 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
302                                          unsigned resultIndex,
303                                          NamedAttribute namedAttr) {
304   if (namedAttr.getName() == "test.invalid_attr")
305     return op->emitError() << "invalid to use 'test.invalid_attr'";
306   return success();
307 }
308 
309 Optional<Dialect::ParseOpHook>
310 TestDialect::getParseOperationHook(StringRef opName) const {
311   if (opName == "test.dialect_custom_printer") {
312     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
313       return parser.parseKeyword("custom_format");
314     }};
315   }
316   if (opName == "test.dialect_custom_format_fallback") {
317     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
318       return parser.parseKeyword("custom_format_fallback");
319     }};
320   }
321   return None;
322 }
323 
324 llvm::unique_function<void(Operation *, OpAsmPrinter &)>
325 TestDialect::getOperationPrinter(Operation *op) const {
326   StringRef opName = op->getName().getStringRef();
327   if (opName == "test.dialect_custom_printer") {
328     return [](Operation *op, OpAsmPrinter &printer) {
329       printer.getStream() << " custom_format";
330     };
331   }
332   if (opName == "test.dialect_custom_format_fallback") {
333     return [](Operation *op, OpAsmPrinter &printer) {
334       printer.getStream() << " custom_format_fallback";
335     };
336   }
337   return {};
338 }
339 
340 //===----------------------------------------------------------------------===//
341 // TestBranchOp
342 //===----------------------------------------------------------------------===//
343 
/// Returns the mutable range of target operands for the successor at `index`.
/// Only index 0 is valid; any other index trips the assertion.
Optional<MutableOperandRange>
TestBranchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index == 0 && "invalid successor index");
  return getTargetOperandsMutable();
}
349 
350 //===----------------------------------------------------------------------===//
351 // TestDialectCanonicalizerOp
352 //===----------------------------------------------------------------------===//
353 
354 static LogicalResult
355 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
356                                PatternRewriter &rewriter) {
357   rewriter.replaceOpWithNewOp<arith::ConstantOp>(
358       op, rewriter.getI32IntegerAttr(42));
359   return success();
360 }
361 
/// Adds `dialectCanonicalizationPattern` to the given pattern set.
void TestDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add(&dialectCanonicalizationPattern);
}
366 
367 //===----------------------------------------------------------------------===//
368 // TestFoldToCallOp
369 //===----------------------------------------------------------------------===//
370 
371 namespace {
372 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
373   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
374 
375   LogicalResult matchAndRewrite(FoldToCallOp op,
376                                 PatternRewriter &rewriter) const override {
377     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.getCalleeAttr(),
378                                         ValueRange());
379     return success();
380   }
381 };
382 } // namespace
383 
/// Adds `FoldToCallOpPattern` to the given canonicalization pattern set.
void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<FoldToCallOpPattern>(context);
}
388 
389 //===----------------------------------------------------------------------===//
390 // Test Format* operations
391 //===----------------------------------------------------------------------===//
392 
393 //===----------------------------------------------------------------------===//
394 // Parsing
395 
396 static ParseResult parseCustomDirectiveOperands(
397     OpAsmParser &parser, OpAsmParser::OperandType &operand,
398     Optional<OpAsmParser::OperandType> &optOperand,
399     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
400   if (parser.parseOperand(operand))
401     return failure();
402   if (succeeded(parser.parseOptionalComma())) {
403     optOperand.emplace();
404     if (parser.parseOperand(*optOperand))
405       return failure();
406   }
407   if (parser.parseArrow() || parser.parseLParen() ||
408       parser.parseOperandList(varOperands) || parser.parseRParen())
409     return failure();
410   return success();
411 }
412 static ParseResult
413 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
414                             Type &optOperandType,
415                             SmallVectorImpl<Type> &varOperandTypes) {
416   if (parser.parseColon())
417     return failure();
418 
419   if (parser.parseType(operandType))
420     return failure();
421   if (succeeded(parser.parseOptionalComma())) {
422     if (parser.parseType(optOperandType))
423       return failure();
424   }
425   if (parser.parseArrow() || parser.parseLParen() ||
426       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
427     return failure();
428   return success();
429 }
430 static ParseResult
431 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
432                                  Type optOperandType,
433                                  const SmallVectorImpl<Type> &varOperandTypes) {
434   if (parser.parseKeyword("type_refs_capture"))
435     return failure();
436 
437   Type operandType2, optOperandType2;
438   SmallVector<Type, 1> varOperandTypes2;
439   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
440                                   varOperandTypes2))
441     return failure();
442 
443   if (operandType != operandType2 || optOperandType != optOperandType2 ||
444       varOperandTypes != varOperandTypes2)
445     return failure();
446 
447   return success();
448 }
449 static ParseResult parseCustomDirectiveOperandsAndTypes(
450     OpAsmParser &parser, OpAsmParser::OperandType &operand,
451     Optional<OpAsmParser::OperandType> &optOperand,
452     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
453     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
454   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
455       parseCustomDirectiveResults(parser, operandType, optOperandType,
456                                   varOperandTypes))
457     return failure();
458   return success();
459 }
460 static ParseResult parseCustomDirectiveRegions(
461     OpAsmParser &parser, Region &region,
462     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
463   if (parser.parseRegion(region))
464     return failure();
465   if (failed(parser.parseOptionalComma()))
466     return success();
467   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
468   if (parser.parseRegion(*varRegion))
469     return failure();
470   varRegions.emplace_back(std::move(varRegion));
471   return success();
472 }
473 static ParseResult
474 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
475                                SmallVectorImpl<Block *> &varSuccessors) {
476   if (parser.parseSuccessor(successor))
477     return failure();
478   if (failed(parser.parseOptionalComma()))
479     return success();
480   Block *varSuccessor;
481   if (parser.parseSuccessor(varSuccessor))
482     return failure();
483   varSuccessors.append(2, varSuccessor);
484   return success();
485 }
486 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
487                                                   IntegerAttr &attr,
488                                                   IntegerAttr &optAttr) {
489   if (parser.parseAttribute(attr))
490     return failure();
491   if (succeeded(parser.parseOptionalComma())) {
492     if (parser.parseAttribute(optAttr))
493       return failure();
494   }
495   return success();
496 }
497 
/// Parses an optional attribute dictionary into `attrs` by forwarding to
/// `OpAsmParser::parseOptionalAttrDict`.
static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
                                                NamedAttrList &attrs) {
  return parser.parseOptionalAttrDict(attrs);
}
502 static ParseResult parseCustomDirectiveOptionalOperandRef(
503     OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
504   int64_t operandCount = 0;
505   if (parser.parseInteger(operandCount))
506     return failure();
507   bool expectedOptionalOperand = operandCount == 0;
508   return success(expectedOptionalOperand != optOperand.hasValue());
509 }
510 
511 //===----------------------------------------------------------------------===//
512 // Printing
513 
514 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
515                                          Value operand, Value optOperand,
516                                          OperandRange varOperands) {
517   printer << operand;
518   if (optOperand)
519     printer << ", " << optOperand;
520   printer << " -> (" << varOperands << ")";
521 }
522 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
523                                         Type operandType, Type optOperandType,
524                                         TypeRange varOperandTypes) {
525   printer << " : " << operandType;
526   if (optOperandType)
527     printer << ", " << optOperandType;
528   printer << " -> (" << varOperandTypes << ")";
529 }
/// Prints the keyword `type_refs_capture` followed by the same type list that
/// `printCustomDirectiveResults` emits.
static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
                                             Operation *op, Type operandType,
                                             Type optOperandType,
                                             TypeRange varOperandTypes) {
  printer << " type_refs_capture ";
  printCustomDirectiveResults(printer, op, operandType, optOperandType,
                              varOperandTypes);
}
/// Prints the operand list followed by the type list, by chaining
/// `printCustomDirectiveOperands` and `printCustomDirectiveResults`.
static void printCustomDirectiveOperandsAndTypes(
    OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
    OperandRange varOperands, Type operandType, Type optOperandType,
    TypeRange varOperandTypes) {
  printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
  printCustomDirectiveResults(printer, op, operandType, optOperandType,
                              varOperandTypes);
}
546 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
547                                         Region &region,
548                                         MutableArrayRef<Region> varRegions) {
549   printer.printRegion(region);
550   if (!varRegions.empty()) {
551     printer << ", ";
552     for (Region &region : varRegions)
553       printer.printRegion(region);
554   }
555 }
556 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
557                                            Block *successor,
558                                            SuccessorRange varSuccessors) {
559   printer << successor;
560   if (!varSuccessors.empty())
561     printer << ", " << varSuccessors.front();
562 }
563 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
564                                            Attribute attribute,
565                                            Attribute optAttribute) {
566   printer << attribute;
567   if (optAttribute)
568     printer << ", " << optAttribute;
569 }
570 
/// Prints the entries of `attrs` as an optional attribute dictionary.
static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
                                         DictionaryAttr attrs) {
  printer.printOptionalAttrDict(attrs.getValue());
}
575 
576 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
577                                                    Operation *op,
578                                                    Value optOperand) {
579   printer << (optOperand ? "1" : "0");
580 }
581 
582 //===----------------------------------------------------------------------===//
583 // Test IsolatedRegionOp - parse passthrough region arguments.
584 //===----------------------------------------------------------------------===//
585 
586 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
587                                          OperationState &result) {
588   OpAsmParser::OperandType argInfo;
589   Type argType = parser.getBuilder().getIndexType();
590 
591   // Parse the input operand.
592   if (parser.parseOperand(argInfo) ||
593       parser.resolveOperand(argInfo, argType, result.operands))
594     return failure();
595 
596   // Parse the body region, and reuse the operand info as the argument info.
597   Region *body = result.addRegion();
598   return parser.parseRegion(*body, argInfo, argType, /*argLocations=*/{},
599                             /*enableNameShadowing=*/true);
600 }
601 
602 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
603   p << "test.isolated_region ";
604   p.printOperand(op.getOperand());
605   p.shadowRegionArgs(op.getRegion(), op.getOperand());
606   p << ' ';
607   p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
608 }
609 
610 //===----------------------------------------------------------------------===//
611 // Test SSACFGRegionOp
612 //===----------------------------------------------------------------------===//
613 
/// Every region of this op is reported as `RegionKind::SSACFG`, regardless of
/// `index`.
RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
  return RegionKind::SSACFG;
}
617 
618 //===----------------------------------------------------------------------===//
619 // Test GraphRegionOp
620 //===----------------------------------------------------------------------===//
621 
622 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
623                                       OperationState &result) {
624   // Parse the body region, and reuse the operand info as the argument info.
625   Region *body = result.addRegion();
626   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
627 }
628 
/// Prints `test.graph_region <region>` without the entry block arguments.
static void print(OpAsmPrinter &p, GraphRegionOp op) {
  p << "test.graph_region ";
  p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
}
633 
/// Every region of this op is reported as `RegionKind::Graph`, regardless of
/// `index`.
RegionKind GraphRegionOp::getRegionKind(unsigned index) {
  return RegionKind::Graph;
}
637 
638 //===----------------------------------------------------------------------===//
639 // Test AffineScopeOp
640 //===----------------------------------------------------------------------===//
641 
642 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
643                                       OperationState &result) {
644   // Parse the body region, and reuse the operand info as the argument info.
645   Region *body = result.addRegion();
646   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
647 }
648 
/// Prints `test.affine_scope <region>` without the entry block arguments.
static void print(OpAsmPrinter &p, AffineScopeOp op) {
  p << "test.affine_scope ";
  p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
}
653 
654 //===----------------------------------------------------------------------===//
655 // Test parser.
656 //===----------------------------------------------------------------------===//
657 
658 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
659                                               OperationState &result) {
660   if (parser.parseOptionalColon())
661     return success();
662   uint64_t numResults;
663   if (parser.parseInteger(numResults))
664     return failure();
665 
666   IndexType type = parser.getBuilder().getIndexType();
667   for (unsigned i = 0; i < numResults; ++i)
668     result.addTypes(type);
669   return success();
670 }
671 
672 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
673   if (unsigned numResults = op->getNumResults())
674     p << " : " << numResults;
675 }
676 
677 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
678                                               OperationState &result) {
679   StringRef keyword;
680   if (parser.parseKeyword(&keyword))
681     return failure();
682   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
683   return success();
684 }
685 
/// Prints a space followed by the value returned by `getKeyword()`.
static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
  p << " " << op.getKeyword();
}
689 
690 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
692 
693 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
694                                          OperationState &result) {
695   if (parser.parseKeyword("wraps"))
696     return failure();
697 
698   // Parse the wrapped op in a region
699   Region &body = *result.addRegion();
700   body.push_back(new Block);
701   Block &block = body.back();
702   Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
703   if (!wrappedOp)
704     return failure();
705 
706   // Create a return terminator in the inner region, pass as operand to the
707   // terminator the returned values from the wrapped operation.
708   SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
709   OpBuilder builder(parser.getContext());
710   builder.setInsertionPointToEnd(&block);
711   builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);
712 
713   // Get the results type for the wrapping op from the terminator operands.
714   Operation &returnOp = body.back().back();
715   result.types.append(returnOp.operand_type_begin(),
716                       returnOp.operand_type_end());
717 
718   // Use the location of the wrapped op for the "test.wrapping_region" op.
719   result.location = wrappedOp->getLoc();
720 
721   return success();
722 }
723 
/// Prints ` wraps ` followed by the generic form of the first operation in
/// the region's first block.
static void print(OpAsmPrinter &p, WrappingRegionOp op) {
  p << " wraps ";
  p.printGenericOp(&op.getRegion().front().front());
}
728 
729 //===----------------------------------------------------------------------===//
730 // Test PrettyPrintedRegionOp -  exercising the following parser APIs
731 //   parseGenericOperationAfterOpName
732 //   parseCustomOperationName
733 //===----------------------------------------------------------------------===//
734 
735 static ParseResult parsePrettyPrintedRegionOp(OpAsmParser &parser,
736                                               OperationState &result) {
737 
738   llvm::SMLoc loc = parser.getCurrentLocation();
739   Location currLocation = parser.getEncodedSourceLoc(loc);
740 
741   // Parse the operands.
742   SmallVector<OpAsmParser::OperandType, 2> operands;
743   if (parser.parseOperandList(operands))
744     return failure();
745 
746   // Check if we are parsing the pretty-printed version
747   //  test.pretty_printed_region start <inner-op> end : <functional-type>
748   // Else fallback to parsing the "non pretty-printed" version.
749   if (!succeeded(parser.parseOptionalKeyword("start")))
750     return parser.parseGenericOperationAfterOpName(
751         result, llvm::makeArrayRef(operands));
752 
753   FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName();
754   if (failed(parseOpNameInfo))
755     return failure();
756 
757   StringRef innerOpName = parseOpNameInfo->getStringRef();
758 
759   FunctionType opFntype;
760   Optional<Location> explicitLoc;
761   if (parser.parseKeyword("end") || parser.parseColon() ||
762       parser.parseType(opFntype) ||
763       parser.parseOptionalLocationSpecifier(explicitLoc))
764     return failure();
765 
766   // If location of the op is explicitly provided, then use it; Else use
767   // the parser's current location.
768   Location opLoc = explicitLoc.getValueOr(currLocation);
769 
770   // Derive the SSA-values for op's operands.
771   if (parser.resolveOperands(operands, opFntype.getInputs(), loc,
772                              result.operands))
773     return failure();
774 
775   // Add a region for op.
776   Region &region = *result.addRegion();
777 
778   // Create a basic-block inside op's region.
779   Block &block = region.emplaceBlock();
780 
781   // Create and insert an "inner-op" operation in the block.
782   // Just for testing purposes, we can assume that inner op is a binary op with
783   // result and operand types all same as the test-op's first operand.
784   Type innerOpType = opFntype.getInput(0);
785   Value lhs = block.addArgument(innerOpType, opLoc);
786   Value rhs = block.addArgument(innerOpType, opLoc);
787 
788   OpBuilder builder(parser.getBuilder().getContext());
789   builder.setInsertionPointToStart(&block);
790 
791   OperationState innerOpState(opLoc, innerOpName);
792   innerOpState.operands.push_back(lhs);
793   innerOpState.operands.push_back(rhs);
794   innerOpState.addTypes(innerOpType);
795 
796   Operation *innerOp = builder.createOperation(innerOpState);
797 
798   // Insert a return statement in the block returning the inner-op's result.
799   builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());
800 
801   // Populate the op operation-state with result-type and location.
802   result.addTypes(opFntype.getResults());
803   result.location = innerOp->getLoc();
804 
805   return success();
806 }
807 
808 static void print(OpAsmPrinter &p, PrettyPrintedRegionOp op) {
809   p << ' ';
810   p.printOperands(op.getOperands());
811 
812   Operation &innerOp = op.getRegion().front().front();
813   // Assuming that region has a single non-terminator inner-op, if the inner-op
814   // meets some criteria (which in this case is a simple one  based on the name
815   // of inner-op), then we can print the entire region in a succinct way.
816   // Here we assume that the prototype of "special.op" can be trivially derived
817   // while parsing it back.
818   if (innerOp.getName().getStringRef().equals("special.op")) {
819     p << " start special.op end";
820   } else {
821     p << " (";
822     p.printRegion(op.getRegion());
823     p << ")";
824   }
825 
826   p << " : ";
827   p.printFunctionalType(op);
828 }
829 
830 //===----------------------------------------------------------------------===//
831 // Test PolyForOp - parse list of region arguments.
832 //===----------------------------------------------------------------------===//
833 
834 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
835   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
836   // Parse list of region arguments without a delimiter.
837   if (parser.parseRegionArgumentList(ivsInfo))
838     return failure();
839 
840   // Parse the body region.
841   Region *body = result.addRegion();
842   auto &builder = parser.getBuilder();
843   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
844   return parser.parseRegion(*body, ivsInfo, argTypes);
845 }
846 
847 void PolyForOp::getAsmBlockArgumentNames(Region &region,
848                                          OpAsmSetValueNameFn setNameFn) {
849   auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
850   if (!arrayAttr)
851     return;
852   auto args = getRegion().front().getArguments();
853   auto e = std::min(arrayAttr.size(), args.size());
854   for (unsigned i = 0; i < e; ++i) {
855     if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
856       setNameFn(args[i], strAttr.getValue());
857   }
858 }
859 
860 //===----------------------------------------------------------------------===//
861 // Test removing op with inner ops.
862 //===----------------------------------------------------------------------===//
863 
864 namespace {
865 struct TestRemoveOpWithInnerOps
866     : public OpRewritePattern<TestOpWithRegionPattern> {
867   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
868 
869   void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
870 
871   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
872                                 PatternRewriter &rewriter) const override {
873     rewriter.eraseOp(op);
874     return success();
875   }
876 };
877 } // namespace
878 
879 void TestOpWithRegionPattern::getCanonicalizationPatterns(
880     RewritePatternSet &results, MLIRContext *context) {
881   results.add<TestRemoveOpWithInnerOps>(context);
882 }
883 
884 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
885   return getOperand();
886 }
887 
888 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
889   return getValue();
890 }
891 
892 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
893     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
894   for (Value input : this->getOperands()) {
895     results.push_back(input);
896   }
897   return success();
898 }
899 
900 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
901   assert(operands.size() == 1);
902   if (operands.front()) {
903     (*this)->setAttr("attr", operands.front());
904     return getResult();
905   }
906   return {};
907 }
908 
909 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
910   return getOperand();
911 }
912 
913 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
914     MLIRContext *, Optional<Location> location, ValueRange operands,
915     DictionaryAttr attributes, RegionRange regions,
916     SmallVectorImpl<Type> &inferredReturnTypes) {
917   if (operands[0].getType() != operands[1].getType()) {
918     return emitOptionalError(location, "operand type mismatch ",
919                              operands[0].getType(), " vs ",
920                              operands[1].getType());
921   }
922   inferredReturnTypes.assign({operands[0].getType()});
923   return success();
924 }
925 
926 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
927     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
928     DictionaryAttr attributes, RegionRange regions,
929     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
930   // Create return type consisting of the last element of the first operand.
931   auto operandType = operands.front().getType();
932   auto sval = operandType.dyn_cast<ShapedType>();
933   if (!sval) {
934     return emitOptionalError(location, "only shaped type operands allowed");
935   }
936   int64_t dim =
937       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
938   auto type = IntegerType::get(context, 17);
939   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
940   return success();
941 }
942 
943 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
944     OpBuilder &builder, ValueRange operands,
945     llvm::SmallVectorImpl<Value> &shapes) {
946   shapes = SmallVector<Value, 1>{
947       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
948   return success();
949 }
950 
951 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
952     OpBuilder &builder, ValueRange operands,
953     llvm::SmallVectorImpl<Value> &shapes) {
954   Location loc = getLoc();
955   shapes.reserve(operands.size());
956   for (Value operand : llvm::reverse(operands)) {
957     auto rank = operand.getType().cast<RankedTensorType>().getRank();
958     auto currShape = llvm::to_vector<4>(
959         llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
960           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
961         }));
962     shapes.push_back(builder.create<tensor::FromElementsOp>(
963         getLoc(), RankedTensorType::get({rank}, builder.getIndexType()),
964         currShape));
965   }
966   return success();
967 }
968 
969 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
970     OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
971   Location loc = getLoc();
972   shapes.reserve(getNumOperands());
973   for (Value operand : llvm::reverse(getOperands())) {
974     auto currShape = llvm::to_vector<4>(llvm::map_range(
975         llvm::seq<int64_t>(
976             0, operand.getType().cast<RankedTensorType>().getRank()),
977         [&](int64_t dim) -> Value {
978           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
979         }));
980     shapes.emplace_back(std::move(currShape));
981   }
982   return success();
983 }
984 
985 //===----------------------------------------------------------------------===//
986 // Test SideEffect interfaces
987 //===----------------------------------------------------------------------===//
988 
989 namespace {
990 /// A test resource for side effects.
991 struct TestResource : public SideEffects::Resource::Base<TestResource> {
992   StringRef getName() final { return "<Test>"; }
993 };
994 } // namespace
995 
996 static void testSideEffectOpGetEffect(
997     Operation *op,
998     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
999         &effects) {
1000   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
1001   if (!effectsAttr)
1002     return;
1003 
1004   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
1005 }
1006 
1007 void SideEffectOp::getEffects(
1008     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1009   // Check for an effects attribute on the op instance.
1010   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
1011   if (!effectsAttr)
1012     return;
1013 
1014   // If there is one, it is an array of dictionary attributes that hold
1015   // information on the effects of this operation.
1016   for (Attribute element : effectsAttr) {
1017     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
1018 
1019     // Get the specific memory effect.
1020     MemoryEffects::Effect *effect =
1021         StringSwitch<MemoryEffects::Effect *>(
1022             effectElement.get("effect").cast<StringAttr>().getValue())
1023             .Case("allocate", MemoryEffects::Allocate::get())
1024             .Case("free", MemoryEffects::Free::get())
1025             .Case("read", MemoryEffects::Read::get())
1026             .Case("write", MemoryEffects::Write::get());
1027 
1028     // Check for a non-default resource to use.
1029     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
1030     if (effectElement.get("test_resource"))
1031       resource = TestResource::get();
1032 
1033     // Check for a result to affect.
1034     if (effectElement.get("on_result"))
1035       effects.emplace_back(effect, getResult(), resource);
1036     else if (Attribute ref = effectElement.get("on_reference"))
1037       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
1038     else
1039       effects.emplace_back(effect, resource);
1040   }
1041 }
1042 
1043 void SideEffectOp::getEffects(
1044     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
1045   testSideEffectOpGetEffect(getOperation(), effects);
1046 }
1047 
1048 //===----------------------------------------------------------------------===//
1049 // StringAttrPrettyNameOp
1050 //===----------------------------------------------------------------------===//
1051 
1052 // This op has fancy handling of its SSA result name.
1053 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
1054                                                OperationState &result) {
1055   // Add the result types.
1056   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
1057     result.addTypes(parser.getBuilder().getIntegerType(32));
1058 
1059   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1060     return failure();
1061 
1062   // If the attribute dictionary contains no 'names' attribute, infer it from
1063   // the SSA name (if specified).
1064   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
1065     return attr.getName() == "names";
1066   });
1067 
1068   // If there was no name specified, check to see if there was a useful name
1069   // specified in the asm file.
1070   if (hadNames || parser.getNumResults() == 0)
1071     return success();
1072 
1073   SmallVector<StringRef, 4> names;
1074   auto *context = result.getContext();
1075 
1076   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
1077     auto resultName = parser.getResultName(i);
1078     StringRef nameStr;
1079     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
1080       nameStr = resultName.first;
1081 
1082     names.push_back(nameStr);
1083   }
1084 
1085   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
1086   result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
1087   return success();
1088 }
1089 
1090 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
1091   // Note that we only need to print the "name" attribute if the asmprinter
1092   // result name disagrees with it.  This can happen in strange cases, e.g.
1093   // when there are conflicts.
1094   bool namesDisagree = op.getNames().size() != op.getNumResults();
1095 
1096   SmallString<32> resultNameStr;
1097   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
1098     resultNameStr.clear();
1099     llvm::raw_svector_ostream tmpStream(resultNameStr);
1100     p.printOperand(op.getResult(i), tmpStream);
1101 
1102     auto expectedName = op.getNames()[i].dyn_cast<StringAttr>();
1103     if (!expectedName ||
1104         tmpStream.str().drop_front() != expectedName.getValue()) {
1105       namesDisagree = true;
1106     }
1107   }
1108 
1109   if (namesDisagree)
1110     p.printOptionalAttrDictWithKeyword(op->getAttrs());
1111   else
1112     p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
1113 }
1114 
1115 // We set the SSA name in the asm syntax to the contents of the name
1116 // attribute.
1117 void StringAttrPrettyNameOp::getAsmResultNames(
1118     function_ref<void(Value, StringRef)> setNameFn) {
1119 
1120   auto value = getNames();
1121   for (size_t i = 0, e = value.size(); i != e; ++i)
1122     if (auto str = value[i].dyn_cast<StringAttr>())
1123       if (!str.getValue().empty())
1124         setNameFn(getResult(i), str.getValue());
1125 }
1126 
1127 //===----------------------------------------------------------------------===//
1128 // RegionIfOp
1129 //===----------------------------------------------------------------------===//
1130 
1131 static void print(OpAsmPrinter &p, RegionIfOp op) {
1132   p << " ";
1133   p.printOperands(op.getOperands());
1134   p << ": " << op.getOperandTypes();
1135   p.printArrowTypeList(op.getResultTypes());
1136   p << " then";
1137   p.printRegion(op.getThenRegion(),
1138                 /*printEntryBlockArgs=*/true,
1139                 /*printBlockTerminators=*/true);
1140   p << " else";
1141   p.printRegion(op.getElseRegion(),
1142                 /*printEntryBlockArgs=*/true,
1143                 /*printBlockTerminators=*/true);
1144   p << " join";
1145   p.printRegion(op.getJoinRegion(),
1146                 /*printEntryBlockArgs=*/true,
1147                 /*printBlockTerminators=*/true);
1148 }
1149 
1150 static ParseResult parseRegionIfOp(OpAsmParser &parser,
1151                                    OperationState &result) {
1152   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
1153   SmallVector<Type, 2> operandTypes;
1154 
1155   result.regions.reserve(3);
1156   Region *thenRegion = result.addRegion();
1157   Region *elseRegion = result.addRegion();
1158   Region *joinRegion = result.addRegion();
1159 
1160   // Parse operand, type and arrow type lists.
1161   if (parser.parseOperandList(operandInfos) ||
1162       parser.parseColonTypeList(operandTypes) ||
1163       parser.parseArrowTypeList(result.types))
1164     return failure();
1165 
1166   // Parse all attached regions.
1167   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1168       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1169       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1170     return failure();
1171 
1172   return parser.resolveOperands(operandInfos, operandTypes,
1173                                 parser.getCurrentLocation(), result.operands);
1174 }
1175 
1176 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
1177   assert(index < 2 && "invalid region index");
1178   return getOperands();
1179 }
1180 
1181 void RegionIfOp::getSuccessorRegions(
1182     Optional<unsigned> index, ArrayRef<Attribute> operands,
1183     SmallVectorImpl<RegionSuccessor> &regions) {
1184   // We always branch to the join region.
1185   if (index.hasValue()) {
1186     if (index.getValue() < 2)
1187       regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
1188     else
1189       regions.push_back(RegionSuccessor(getResults()));
1190     return;
1191   }
1192 
1193   // The then and else regions are the entry regions of this op.
1194   regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
1195   regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
1196 }
1197 
1198 //===----------------------------------------------------------------------===//
1199 // SingleNoTerminatorCustomAsmOp
1200 //===----------------------------------------------------------------------===//
1201 
1202 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser,
1203                                                       OperationState &state) {
1204   Region *body = state.addRegion();
1205   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1206     return failure();
1207   return success();
1208 }
1209 
1210 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) {
1211   printer.printRegion(
1212       op.getRegion(), /*printEntryBlockArgs=*/false,
1213       // This op has a single block without terminators. But explicitly mark
1214       // as not printing block terminators for testing.
1215       /*printBlockTerminators=*/false);
1216 }
1217 
1218 #include "TestOpEnums.cpp.inc"
1219 #include "TestOpInterfaces.cpp.inc"
1220 #include "TestOpStructs.cpp.inc"
1221 #include "TestTypeInterfaces.cpp.inc"
1222 
1223 #define GET_OP_CLASSES
1224 #include "TestOps.cpp.inc"
1225