1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/DLTI/DLTI.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Reducer/ReductionPatternInterface.h"
21 #include "mlir/Transforms/FoldUtils.h"
22 #include "mlir/Transforms/InliningUtils.h"
23 #include "llvm/ADT/StringSwitch.h"
24 
25 // Include this before the using namespace lines below to
26 // test that we don't have namespace dependencies.
27 #include "TestOpsDialect.cpp.inc"
28 
29 using namespace mlir;
30 using namespace test;
31 
/// Adds `TestDialect` to `registry` so contexts built from the registry can
/// load it.
void test::registerTestDialect(DialectRegistry &registry) {
  registry.insert<TestDialect>();
}
35 
36 //===----------------------------------------------------------------------===//
37 // TestDialect Interfaces
38 //===----------------------------------------------------------------------===//
39 
40 namespace {
41 
/// Compile-time checks of trait detection: `SingleBlockImplicitTerminatorOp`
/// must be recognized both by the `has_implicit_terminator_t` detector and by
/// the `hasSingleBlockImplicitTerminator` trait query.
static_assert(
    llvm::is_detected<OpTrait::has_implicit_terminator_t,
                      SingleBlockImplicitTerminatorOp>::value,
    "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
static_assert(OpTrait::hasSingleBlockImplicitTerminator<
                  SingleBlockImplicitTerminatorOp>::value,
              "hasSingleBlockImplicitTerminator does not match "
              "SingleBlockImplicitTerminatorOp");
51 
52 // Test support for interacting with the AsmPrinter.
53 struct TestOpAsmInterface : public OpAsmDialectInterface {
54   using OpAsmDialectInterface::OpAsmDialectInterface;
55 
56   AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
57     StringAttr strAttr = attr.dyn_cast<StringAttr>();
58     if (!strAttr)
59       return AliasResult::NoAlias;
60 
61     // Check the contents of the string attribute to see what the test alias
62     // should be named.
63     Optional<StringRef> aliasName =
64         StringSwitch<Optional<StringRef>>(strAttr.getValue())
65             .Case("alias_test:dot_in_name", StringRef("test.alias"))
66             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
67             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
68             .Case("alias_test:sanitize_conflict_a",
69                   StringRef("test_alias_conflict0"))
70             .Case("alias_test:sanitize_conflict_b",
71                   StringRef("test_alias_conflict0_"))
72             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
73             .Default(llvm::None);
74     if (!aliasName)
75       return AliasResult::NoAlias;
76 
77     os << *aliasName;
78     return AliasResult::FinalAlias;
79   }
80 
81   AliasResult getAlias(Type type, raw_ostream &os) const final {
82     if (auto tupleType = type.dyn_cast<TupleType>()) {
83       if (tupleType.size() > 0 &&
84           llvm::all_of(tupleType.getTypes(), [](Type elemType) {
85             return elemType.isa<SimpleAType>();
86           })) {
87         os << "test_tuple";
88         return AliasResult::FinalAlias;
89       }
90     }
91     return AliasResult::NoAlias;
92   }
93 
94   void getAsmResultNames(Operation *op,
95                          OpAsmSetValueNameFn setNameFn) const final {
96     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
97       setNameFn(asmOp, "result");
98   }
99 
100   void getAsmBlockArgumentNames(Block *block,
101                                 OpAsmSetValueNameFn setNameFn) const final {
102     auto op = block->getParentOp();
103     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
104     if (!arrayAttr)
105       return;
106     auto args = block->getArguments();
107     auto e = std::min(arrayAttr.size(), args.size());
108     for (unsigned i = 0; i < e; ++i) {
109       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
110         setNameFn(args[i], strAttr.getValue());
111     }
112   }
113 };
114 
115 struct TestDialectFoldInterface : public DialectFoldInterface {
116   using DialectFoldInterface::DialectFoldInterface;
117 
118   /// Registered hook to check if the given region, which is attached to an
119   /// operation that is *not* isolated from above, should be used when
120   /// materializing constants.
121   bool shouldMaterializeInto(Region *region) const final {
122     // If this is a one region operation, then insert into it.
123     return isa<OneRegionOp>(region->getParentOp());
124   }
125 };
126 
127 /// This class defines the interface for handling inlining with standard
128 /// operations.
129 struct TestInlinerInterface : public DialectInlinerInterface {
130   using DialectInlinerInterface::DialectInlinerInterface;
131 
132   //===--------------------------------------------------------------------===//
133   // Analysis Hooks
134   //===--------------------------------------------------------------------===//
135 
136   bool isLegalToInline(Operation *call, Operation *callable,
137                        bool wouldBeCloned) const final {
138     // Don't allow inlining calls that are marked `noinline`.
139     return !call->hasAttr("noinline");
140   }
141   bool isLegalToInline(Region *, Region *, bool,
142                        BlockAndValueMapping &) const final {
143     // Inlining into test dialect regions is legal.
144     return true;
145   }
146   bool isLegalToInline(Operation *, Region *, bool,
147                        BlockAndValueMapping &) const final {
148     return true;
149   }
150 
151   bool shouldAnalyzeRecursively(Operation *op) const final {
152     // Analyze recursively if this is not a functional region operation, it
153     // froms a separate functional scope.
154     return !isa<FunctionalRegionOp>(op);
155   }
156 
157   //===--------------------------------------------------------------------===//
158   // Transformation Hooks
159   //===--------------------------------------------------------------------===//
160 
161   /// Handle the given inlined terminator by replacing it with a new operation
162   /// as necessary.
163   void handleTerminator(Operation *op,
164                         ArrayRef<Value> valuesToRepl) const final {
165     // Only handle "test.return" here.
166     auto returnOp = dyn_cast<TestReturnOp>(op);
167     if (!returnOp)
168       return;
169 
170     // Replace the values directly with the return operands.
171     assert(returnOp.getNumOperands() == valuesToRepl.size());
172     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
173       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
174   }
175 
176   /// Attempt to materialize a conversion for a type mismatch between a call
177   /// from this dialect, and a callable region. This method should generate an
178   /// operation that takes 'input' as the only operand, and produces a single
179   /// result of 'resultType'. If a conversion can not be generated, nullptr
180   /// should be returned.
181   Operation *materializeCallConversion(OpBuilder &builder, Value input,
182                                        Type resultType,
183                                        Location conversionLoc) const final {
184     // Only allow conversion for i16/i32 types.
185     if (!(resultType.isSignlessInteger(16) ||
186           resultType.isSignlessInteger(32)) ||
187         !(input.getType().isSignlessInteger(16) ||
188           input.getType().isSignlessInteger(32)))
189       return nullptr;
190     return builder.create<TestCastOp>(conversionLoc, resultType, input);
191   }
192 
193   void processInlinedCallBlocks(
194       Operation *call,
195       iterator_range<Region::iterator> inlinedBlocks) const final {
196     if (!isa<ConversionCallOp>(call))
197       return;
198 
199     // Set attributed on all ops in the inlined blocks.
200     for (Block &block : inlinedBlocks) {
201       block.walk([&](Operation *op) {
202         op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
203       });
204     }
205   }
206 };
207 
208 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
209 public:
210   TestReductionPatternInterface(Dialect *dialect)
211       : DialectReductionPatternInterface(dialect) {}
212 
213   virtual void
214   populateReductionPatterns(RewritePatternSet &patterns) const final {
215     populateTestReductionPatterns(patterns);
216   }
217 };
218 
219 } // end anonymous namespace
220 
221 //===----------------------------------------------------------------------===//
222 // TestDialect
223 //===----------------------------------------------------------------------===//
224 
/// Forward declaration so that `TestOpEffectInterfaceFallback::getEffects` can
/// call it. Being `static`, it is defined elsewhere in this translation unit
/// (not visible in this view). It presumably fills `effects` with the side
/// effects of `op`; confirm against the definition.
static void testSideEffectOpGetEffect(
    Operation *op,
    SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
228 
// This is the implementation of a dialect fallback for `TestEffectOpInterface`.
// `TestDialect::initialize` allocates one instance, and
// `TestDialect::getRegisteredInterfaceForOp` hands it out for the
// unregistered op `test.unregistered_side_effect_op`.
struct TestOpEffectInterfaceFallback
    : public TestEffectOpInterface::FallbackModel<
          TestOpEffectInterfaceFallback> {
  /// Only `test.unregistered_side_effect_op` is expected to dispatch here.
  /// Any other op trips the assertion in builds with assertions enabled, and
  /// is rejected (returns false) otherwise.
  static bool classof(Operation *op) {
    bool isSupportedOp =
        op->getName().getStringRef() == "test.unregistered_side_effect_op";
    assert(isSupportedOp && "Unexpected dispatch");
    return isSupportedOp;
  }

  /// Reports the effects of `op` by delegating to `testSideEffectOpGetEffect`.
  void
  getEffects(Operation *op,
             SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
                 &effects) const {
    testSideEffectOpGetEffect(op, effects);
  }
};
247 
/// Sets up the dialect:
///   - registers its attributes and types;
///   - registers the operations listed in the generated `TestOps.cpp.inc`;
///   - registers its dialect interfaces;
///   - permits unknown operations;
///   - heap-allocates the fallback model for `TestEffectOpInterface`.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface, TestReductionPatternInterface>();
  allowUnknownOperations();

  // Instantiate our fallback op interface that we'll use on specific
  // unregistered op. It is released in ~TestDialect.
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
/// Releases the fallback model allocated in `initialize`. The member's
/// declaration is not visible here. The cast suggests it is stored without
/// its concrete type (presumably `void *`), so it is cast back before
/// deletion.
TestDialect::~TestDialect() {
  delete static_cast<TestOpEffectInterfaceFallback *>(
      fallbackEffectOpInterfaces);
}
267 
268 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
269                                             Type type, Location loc) {
270   return builder.create<TestOpConstant>(loc, type, value);
271 }
272 
273 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
274                                                OperationName opName) {
275   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
276       typeID == TypeID::get<TestEffectOpInterface>())
277     return fallbackEffectOpInterfaces;
278   return nullptr;
279 }
280 
281 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
282                                                     NamedAttribute namedAttr) {
283   if (namedAttr.first == "test.invalid_attr")
284     return op->emitError() << "invalid to use 'test.invalid_attr'";
285   return success();
286 }
287 
288 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
289                                                     unsigned regionIndex,
290                                                     unsigned argIndex,
291                                                     NamedAttribute namedAttr) {
292   if (namedAttr.first == "test.invalid_attr")
293     return op->emitError() << "invalid to use 'test.invalid_attr'";
294   return success();
295 }
296 
297 LogicalResult
298 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
299                                          unsigned resultIndex,
300                                          NamedAttribute namedAttr) {
301   if (namedAttr.first == "test.invalid_attr")
302     return op->emitError() << "invalid to use 'test.invalid_attr'";
303   return success();
304 }
305 
306 Optional<Dialect::ParseOpHook>
307 TestDialect::getParseOperationHook(StringRef opName) const {
308   if (opName == "test.dialect_custom_printer") {
309     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
310       return parser.parseKeyword("custom_format");
311     }};
312   }
313   return None;
314 }
315 
316 llvm::unique_function<void(Operation *, OpAsmPrinter &)>
317 TestDialect::getOperationPrinter(Operation *op) const {
318   StringRef opName = op->getName().getStringRef();
319   if (opName == "test.dialect_custom_printer") {
320     return [](Operation *op, OpAsmPrinter &printer) {
321       printer.getStream() << " custom_format";
322     };
323   }
324   return {};
325 }
326 
327 //===----------------------------------------------------------------------===//
328 // TestBranchOp
329 //===----------------------------------------------------------------------===//
330 
/// Returns the operands forwarded to successor `index`. Only index 0 is
/// accepted (asserted); its forwarded operands are `targetOperands`.
Optional<MutableOperandRange>
TestBranchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index == 0 && "invalid successor index");
  return targetOperandsMutable();
}
336 
337 //===----------------------------------------------------------------------===//
338 // TestDialectCanonicalizerOp
339 //===----------------------------------------------------------------------===//
340 
341 static LogicalResult
342 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
343                                PatternRewriter &rewriter) {
344   rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(),
345                                           rewriter.getI32IntegerAttr(42));
346   return success();
347 }
348 
349 void TestDialect::getCanonicalizationPatterns(
350     RewritePatternSet &results) const {
351   results.add(&dialectCanonicalizationPattern);
352 }
353 
354 //===----------------------------------------------------------------------===//
355 // TestFoldToCallOp
356 //===----------------------------------------------------------------------===//
357 
358 namespace {
359 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
360   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
361 
362   LogicalResult matchAndRewrite(FoldToCallOp op,
363                                 PatternRewriter &rewriter) const override {
364     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
365                                         ValueRange());
366     return success();
367   }
368 };
369 } // end anonymous namespace
370 
371 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
372                                                MLIRContext *context) {
373   results.add<FoldToCallOpPattern>(context);
374 }
375 
376 //===----------------------------------------------------------------------===//
377 // Test Format* operations
378 //===----------------------------------------------------------------------===//
379 
380 //===----------------------------------------------------------------------===//
381 // Parsing
382 
383 static ParseResult parseCustomDirectiveOperands(
384     OpAsmParser &parser, OpAsmParser::OperandType &operand,
385     Optional<OpAsmParser::OperandType> &optOperand,
386     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
387   if (parser.parseOperand(operand))
388     return failure();
389   if (succeeded(parser.parseOptionalComma())) {
390     optOperand.emplace();
391     if (parser.parseOperand(*optOperand))
392       return failure();
393   }
394   if (parser.parseArrow() || parser.parseLParen() ||
395       parser.parseOperandList(varOperands) || parser.parseRParen())
396     return failure();
397   return success();
398 }
399 static ParseResult
400 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
401                             Type &optOperandType,
402                             SmallVectorImpl<Type> &varOperandTypes) {
403   if (parser.parseColon())
404     return failure();
405 
406   if (parser.parseType(operandType))
407     return failure();
408   if (succeeded(parser.parseOptionalComma())) {
409     if (parser.parseType(optOperandType))
410       return failure();
411   }
412   if (parser.parseArrow() || parser.parseLParen() ||
413       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
414     return failure();
415   return success();
416 }
417 static ParseResult
418 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
419                                  Type optOperandType,
420                                  const SmallVectorImpl<Type> &varOperandTypes) {
421   if (parser.parseKeyword("type_refs_capture"))
422     return failure();
423 
424   Type operandType2, optOperandType2;
425   SmallVector<Type, 1> varOperandTypes2;
426   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
427                                   varOperandTypes2))
428     return failure();
429 
430   if (operandType != operandType2 || optOperandType != optOperandType2 ||
431       varOperandTypes != varOperandTypes2)
432     return failure();
433 
434   return success();
435 }
436 static ParseResult parseCustomDirectiveOperandsAndTypes(
437     OpAsmParser &parser, OpAsmParser::OperandType &operand,
438     Optional<OpAsmParser::OperandType> &optOperand,
439     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
440     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
441   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
442       parseCustomDirectiveResults(parser, operandType, optOperandType,
443                                   varOperandTypes))
444     return failure();
445   return success();
446 }
447 static ParseResult parseCustomDirectiveRegions(
448     OpAsmParser &parser, Region &region,
449     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
450   if (parser.parseRegion(region))
451     return failure();
452   if (failed(parser.parseOptionalComma()))
453     return success();
454   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
455   if (parser.parseRegion(*varRegion))
456     return failure();
457   varRegions.emplace_back(std::move(varRegion));
458   return success();
459 }
/// Parses `successor [, varSuccessor]`. When the optional second successor is
/// present, it is appended to `varSuccessors` *twice* (`append(2, ...)`). The
/// matching printer emits only the first variadic successor, so the textual
/// form still round-trips.
/// NOTE(review): the duplication looks deliberate (it gives the variadic
/// successor list more than one entry). Confirm against the op's tests before
/// changing it.
static ParseResult
parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
                               SmallVectorImpl<Block *> &varSuccessors) {
  if (parser.parseSuccessor(successor))
    return failure();
  if (failed(parser.parseOptionalComma()))
    return success();
  Block *varSuccessor;
  if (parser.parseSuccessor(varSuccessor))
    return failure();
  varSuccessors.append(2, varSuccessor);
  return success();
}
473 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
474                                                   IntegerAttr &attr,
475                                                   IntegerAttr &optAttr) {
476   if (parser.parseAttribute(attr))
477     return failure();
478   if (succeeded(parser.parseOptionalComma())) {
479     if (parser.parseAttribute(optAttr))
480       return failure();
481   }
482   return success();
483 }
484 
485 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
486                                                 NamedAttrList &attrs) {
487   return parser.parseOptionalAttrDict(attrs);
488 }
489 static ParseResult parseCustomDirectiveOptionalOperandRef(
490     OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
491   int64_t operandCount = 0;
492   if (parser.parseInteger(operandCount))
493     return failure();
494   bool expectedOptionalOperand = operandCount == 0;
495   return success(expectedOptionalOperand != optOperand.hasValue());
496 }
497 
498 //===----------------------------------------------------------------------===//
499 // Printing
500 
501 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
502                                          Value operand, Value optOperand,
503                                          OperandRange varOperands) {
504   printer << operand;
505   if (optOperand)
506     printer << ", " << optOperand;
507   printer << " -> (" << varOperands << ")";
508 }
509 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
510                                         Type operandType, Type optOperandType,
511                                         TypeRange varOperandTypes) {
512   printer << " : " << operandType;
513   if (optOperandType)
514     printer << ", " << optOperandType;
515   printer << " -> (" << varOperandTypes << ")";
516 }
517 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
518                                              Operation *op, Type operandType,
519                                              Type optOperandType,
520                                              TypeRange varOperandTypes) {
521   printer << " type_refs_capture ";
522   printCustomDirectiveResults(printer, op, operandType, optOperandType,
523                               varOperandTypes);
524 }
525 static void printCustomDirectiveOperandsAndTypes(
526     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
527     OperandRange varOperands, Type operandType, Type optOperandType,
528     TypeRange varOperandTypes) {
529   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
530   printCustomDirectiveResults(printer, op, operandType, optOperandType,
531                               varOperandTypes);
532 }
533 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
534                                         Region &region,
535                                         MutableArrayRef<Region> varRegions) {
536   printer.printRegion(region);
537   if (!varRegions.empty()) {
538     printer << ", ";
539     for (Region &region : varRegions)
540       printer.printRegion(region);
541   }
542 }
543 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
544                                            Block *successor,
545                                            SuccessorRange varSuccessors) {
546   printer << successor;
547   if (!varSuccessors.empty())
548     printer << ", " << varSuccessors.front();
549 }
550 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
551                                            Attribute attribute,
552                                            Attribute optAttribute) {
553   printer << attribute;
554   if (optAttribute)
555     printer << ", " << optAttribute;
556 }
557 
558 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
559                                          DictionaryAttr attrs) {
560   printer.printOptionalAttrDict(attrs.getValue());
561 }
562 
563 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
564                                                    Operation *op,
565                                                    Value optOperand) {
566   printer << (optOperand ? "1" : "0");
567 }
568 
569 //===----------------------------------------------------------------------===//
570 // Test IsolatedRegionOp - parse passthrough region arguments.
571 //===----------------------------------------------------------------------===//
572 
573 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
574                                          OperationState &result) {
575   OpAsmParser::OperandType argInfo;
576   Type argType = parser.getBuilder().getIndexType();
577 
578   // Parse the input operand.
579   if (parser.parseOperand(argInfo) ||
580       parser.resolveOperand(argInfo, argType, result.operands))
581     return failure();
582 
583   // Parse the body region, and reuse the operand info as the argument info.
584   Region *body = result.addRegion();
585   return parser.parseRegion(*body, argInfo, argType,
586                             /*enableNameShadowing=*/true);
587 }
588 
589 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
590   p << "test.isolated_region ";
591   p.printOperand(op.getOperand());
592   p.shadowRegionArgs(op.region(), op.getOperand());
593   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
594 }
595 
596 //===----------------------------------------------------------------------===//
597 // Test SSACFGRegionOp
598 //===----------------------------------------------------------------------===//
599 
/// Reports `RegionKind::SSACFG` for every region, regardless of `index`.
RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
  return RegionKind::SSACFG;
}
603 
604 //===----------------------------------------------------------------------===//
605 // Test GraphRegionOp
606 //===----------------------------------------------------------------------===//
607 
608 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
609                                       OperationState &result) {
610   // Parse the body region, and reuse the operand info as the argument info.
611   Region *body = result.addRegion();
612   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
613 }
614 
615 static void print(OpAsmPrinter &p, GraphRegionOp op) {
616   p << "test.graph_region ";
617   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
618 }
619 
620 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
621   return RegionKind::Graph;
622 }
623 
624 //===----------------------------------------------------------------------===//
625 // Test AffineScopeOp
626 //===----------------------------------------------------------------------===//
627 
628 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
629                                       OperationState &result) {
630   // Parse the body region, and reuse the operand info as the argument info.
631   Region *body = result.addRegion();
632   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
633 }
634 
635 static void print(OpAsmPrinter &p, AffineScopeOp op) {
636   p << "test.affine_scope ";
637   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
638 }
639 
640 //===----------------------------------------------------------------------===//
641 // Test parser.
642 //===----------------------------------------------------------------------===//
643 
644 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
645                                               OperationState &result) {
646   if (parser.parseOptionalColon())
647     return success();
648   uint64_t numResults;
649   if (parser.parseInteger(numResults))
650     return failure();
651 
652   IndexType type = parser.getBuilder().getIndexType();
653   for (unsigned i = 0; i < numResults; ++i)
654     result.addTypes(type);
655   return success();
656 }
657 
658 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
659   if (unsigned numResults = op->getNumResults())
660     p << " : " << numResults;
661 }
662 
663 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
664                                               OperationState &result) {
665   StringRef keyword;
666   if (parser.parseKeyword(&keyword))
667     return failure();
668   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
669   return success();
670 }
671 
672 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
673   p << " " << op.keyword();
674 }
675 
676 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
678 
679 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
680                                          OperationState &result) {
681   if (parser.parseKeyword("wraps"))
682     return failure();
683 
684   // Parse the wrapped op in a region
685   Region &body = *result.addRegion();
686   body.push_back(new Block);
687   Block &block = body.back();
688   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
689   if (!wrapped_op)
690     return failure();
691 
692   // Create a return terminator in the inner region, pass as operand to the
693   // terminator the returned values from the wrapped operation.
694   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
695   OpBuilder builder(parser.getBuilder().getContext());
696   builder.setInsertionPointToEnd(&block);
697   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
698 
699   // Get the results type for the wrapping op from the terminator operands.
700   Operation &return_op = body.back().back();
701   result.types.append(return_op.operand_type_begin(),
702                       return_op.operand_type_end());
703 
704   // Use the location of the wrapped op for the "test.wrapping_region" op.
705   result.location = wrapped_op->getLoc();
706 
707   return success();
708 }
709 
710 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
711   p << " wraps ";
712   p.printGenericOp(&op.region().front().front());
713 }
714 
715 //===----------------------------------------------------------------------===//
716 // Test PolyForOp - parse list of region arguments.
717 //===----------------------------------------------------------------------===//
718 
719 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
720   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
721   // Parse list of region arguments without a delimiter.
722   if (parser.parseRegionArgumentList(ivsInfo))
723     return failure();
724 
725   // Parse the body region.
726   Region *body = result.addRegion();
727   auto &builder = parser.getBuilder();
728   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
729   return parser.parseRegion(*body, ivsInfo, argTypes);
730 }
731 
732 //===----------------------------------------------------------------------===//
733 // Test removing op with inner ops.
734 //===----------------------------------------------------------------------===//
735 
736 namespace {
737 struct TestRemoveOpWithInnerOps
738     : public OpRewritePattern<TestOpWithRegionPattern> {
739   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
740 
741   void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
742 
743   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
744                                 PatternRewriter &rewriter) const override {
745     rewriter.eraseOp(op);
746     return success();
747   }
748 };
749 } // end anonymous namespace
750 
751 void TestOpWithRegionPattern::getCanonicalizationPatterns(
752     RewritePatternSet &results, MLIRContext *context) {
753   results.add<TestRemoveOpWithInnerOps>(context);
754 }
755 
756 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
757   return operand();
758 }
759 
760 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
761   return getValue();
762 }
763 
764 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
765     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
766   for (Value input : this->operands()) {
767     results.push_back(input);
768   }
769   return success();
770 }
771 
772 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
773   assert(operands.size() == 1);
774   if (operands.front()) {
775     (*this)->setAttr("attr", operands.front());
776     return getResult();
777   }
778   return {};
779 }
780 
781 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
782   return getOperand();
783 }
784 
785 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
786     MLIRContext *, Optional<Location> location, ValueRange operands,
787     DictionaryAttr attributes, RegionRange regions,
788     SmallVectorImpl<Type> &inferredReturnTypes) {
789   if (operands[0].getType() != operands[1].getType()) {
790     return emitOptionalError(location, "operand type mismatch ",
791                              operands[0].getType(), " vs ",
792                              operands[1].getType());
793   }
794   inferredReturnTypes.assign({operands[0].getType()});
795   return success();
796 }
797 
798 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
799     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
800     DictionaryAttr attributes, RegionRange regions,
801     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
802   // Create return type consisting of the last element of the first operand.
803   auto operandType = operands.front().getType();
804   auto sval = operandType.dyn_cast<ShapedType>();
805   if (!sval) {
806     return emitOptionalError(location, "only shaped type operands allowed");
807   }
808   int64_t dim =
809       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
810   auto type = IntegerType::get(context, 17);
811   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
812   return success();
813 }
814 
815 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
816     OpBuilder &builder, ValueRange operands,
817     llvm::SmallVectorImpl<Value> &shapes) {
818   shapes = SmallVector<Value, 1>{
819       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
820   return success();
821 }
822 
823 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
824     OpBuilder &builder, ValueRange operands,
825     llvm::SmallVectorImpl<Value> &shapes) {
826   Location loc = getLoc();
827   shapes.reserve(operands.size());
828   for (Value operand : llvm::reverse(operands)) {
829     auto currShape = llvm::to_vector<4>(llvm::map_range(
830         llvm::seq<int64_t>(
831             0, operand.getType().cast<RankedTensorType>().getRank()),
832         [&](int64_t dim) -> Value {
833           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
834         }));
835     shapes.push_back(builder.create<tensor::FromElementsOp>(
836         getLoc(), builder.getIndexType(), currShape));
837   }
838   return success();
839 }
840 
841 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
842     OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
843   Location loc = getLoc();
844   shapes.reserve(getNumOperands());
845   for (Value operand : llvm::reverse(getOperands())) {
846     auto currShape = llvm::to_vector<4>(llvm::map_range(
847         llvm::seq<int64_t>(
848             0, operand.getType().cast<RankedTensorType>().getRank()),
849         [&](int64_t dim) -> Value {
850           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
851         }));
852     shapes.emplace_back(std::move(currShape));
853   }
854   return success();
855 }
856 
857 //===----------------------------------------------------------------------===//
858 // Test SideEffect interfaces
859 //===----------------------------------------------------------------------===//
860 
861 namespace {
862 /// A test resource for side effects.
863 struct TestResource : public SideEffects::Resource::Base<TestResource> {
864   StringRef getName() final { return "<Test>"; }
865 };
866 } // end anonymous namespace
867 
868 static void testSideEffectOpGetEffect(
869     Operation *op,
870     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
871         &effects) {
872   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
873   if (!effectsAttr)
874     return;
875 
876   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
877 }
878 
879 void SideEffectOp::getEffects(
880     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
881   // Check for an effects attribute on the op instance.
882   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
883   if (!effectsAttr)
884     return;
885 
886   // If there is one, it is an array of dictionary attributes that hold
887   // information on the effects of this operation.
888   for (Attribute element : effectsAttr) {
889     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
890 
891     // Get the specific memory effect.
892     MemoryEffects::Effect *effect =
893         StringSwitch<MemoryEffects::Effect *>(
894             effectElement.get("effect").cast<StringAttr>().getValue())
895             .Case("allocate", MemoryEffects::Allocate::get())
896             .Case("free", MemoryEffects::Free::get())
897             .Case("read", MemoryEffects::Read::get())
898             .Case("write", MemoryEffects::Write::get());
899 
900     // Check for a non-default resource to use.
901     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
902     if (effectElement.get("test_resource"))
903       resource = TestResource::get();
904 
905     // Check for a result to affect.
906     if (effectElement.get("on_result"))
907       effects.emplace_back(effect, getResult(), resource);
908     else if (Attribute ref = effectElement.get("on_reference"))
909       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
910     else
911       effects.emplace_back(effect, resource);
912   }
913 }
914 
915 void SideEffectOp::getEffects(
916     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
917   testSideEffectOpGetEffect(getOperation(), effects);
918 }
919 
920 //===----------------------------------------------------------------------===//
921 // StringAttrPrettyNameOp
922 //===----------------------------------------------------------------------===//
923 
924 // This op has fancy handling of its SSA result name.
925 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
926                                                OperationState &result) {
927   // Add the result types.
928   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
929     result.addTypes(parser.getBuilder().getIntegerType(32));
930 
931   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
932     return failure();
933 
934   // If the attribute dictionary contains no 'names' attribute, infer it from
935   // the SSA name (if specified).
936   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
937     return attr.first == "names";
938   });
939 
940   // If there was no name specified, check to see if there was a useful name
941   // specified in the asm file.
942   if (hadNames || parser.getNumResults() == 0)
943     return success();
944 
945   SmallVector<StringRef, 4> names;
946   auto *context = result.getContext();
947 
948   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
949     auto resultName = parser.getResultName(i);
950     StringRef nameStr;
951     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
952       nameStr = resultName.first;
953 
954     names.push_back(nameStr);
955   }
956 
957   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
958   result.attributes.push_back({Identifier::get("names", context), namesAttr});
959   return success();
960 }
961 
962 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
963   // Note that we only need to print the "name" attribute if the asmprinter
964   // result name disagrees with it.  This can happen in strange cases, e.g.
965   // when there are conflicts.
966   bool namesDisagree = op.names().size() != op.getNumResults();
967 
968   SmallString<32> resultNameStr;
969   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
970     resultNameStr.clear();
971     llvm::raw_svector_ostream tmpStream(resultNameStr);
972     p.printOperand(op.getResult(i), tmpStream);
973 
974     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
975     if (!expectedName ||
976         tmpStream.str().drop_front() != expectedName.getValue()) {
977       namesDisagree = true;
978     }
979   }
980 
981   if (namesDisagree)
982     p.printOptionalAttrDictWithKeyword(op->getAttrs());
983   else
984     p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
985 }
986 
// Set the SSA result names in the asm syntax from the entries of the `names`
// attribute (one entry per result).
989 void StringAttrPrettyNameOp::getAsmResultNames(
990     function_ref<void(Value, StringRef)> setNameFn) {
991 
992   auto value = names();
993   for (size_t i = 0, e = value.size(); i != e; ++i)
994     if (auto str = value[i].dyn_cast<StringAttr>())
995       if (!str.getValue().empty())
996         setNameFn(getResult(i), str.getValue());
997 }
998 
999 //===----------------------------------------------------------------------===//
1000 // RegionIfOp
1001 //===----------------------------------------------------------------------===//
1002 
1003 static void print(OpAsmPrinter &p, RegionIfOp op) {
1004   p << " ";
1005   p.printOperands(op.getOperands());
1006   p << ": " << op.getOperandTypes();
1007   p.printArrowTypeList(op.getResultTypes());
1008   p << " then";
1009   p.printRegion(op.thenRegion(),
1010                 /*printEntryBlockArgs=*/true,
1011                 /*printBlockTerminators=*/true);
1012   p << " else";
1013   p.printRegion(op.elseRegion(),
1014                 /*printEntryBlockArgs=*/true,
1015                 /*printBlockTerminators=*/true);
1016   p << " join";
1017   p.printRegion(op.joinRegion(),
1018                 /*printEntryBlockArgs=*/true,
1019                 /*printBlockTerminators=*/true);
1020 }
1021 
1022 static ParseResult parseRegionIfOp(OpAsmParser &parser,
1023                                    OperationState &result) {
1024   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
1025   SmallVector<Type, 2> operandTypes;
1026 
1027   result.regions.reserve(3);
1028   Region *thenRegion = result.addRegion();
1029   Region *elseRegion = result.addRegion();
1030   Region *joinRegion = result.addRegion();
1031 
1032   // Parse operand, type and arrow type lists.
1033   if (parser.parseOperandList(operandInfos) ||
1034       parser.parseColonTypeList(operandTypes) ||
1035       parser.parseArrowTypeList(result.types))
1036     return failure();
1037 
1038   // Parse all attached regions.
1039   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1040       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1041       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1042     return failure();
1043 
1044   return parser.resolveOperands(operandInfos, operandTypes,
1045                                 parser.getCurrentLocation(), result.operands);
1046 }
1047 
1048 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
1049   assert(index < 2 && "invalid region index");
1050   return getOperands();
1051 }
1052 
1053 void RegionIfOp::getSuccessorRegions(
1054     Optional<unsigned> index, ArrayRef<Attribute> operands,
1055     SmallVectorImpl<RegionSuccessor> &regions) {
1056   // We always branch to the join region.
1057   if (index.hasValue()) {
1058     if (index.getValue() < 2)
1059       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
1060     else
1061       regions.push_back(RegionSuccessor(getResults()));
1062     return;
1063   }
1064 
1065   // The then and else regions are the entry regions of this op.
1066   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
1067   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
1068 }
1069 
1070 //===----------------------------------------------------------------------===//
1071 // SingleNoTerminatorCustomAsmOp
1072 //===----------------------------------------------------------------------===//
1073 
1074 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser,
1075                                                       OperationState &state) {
1076   Region *body = state.addRegion();
1077   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1078     return failure();
1079   return success();
1080 }
1081 
1082 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) {
1083   printer.printRegion(
1084       op.getRegion(), /*printEntryBlockArgs=*/false,
1085       // This op has a single block without terminators. But explicitly mark
1086       // as not printing block terminators for testing.
1087       /*printBlockTerminators=*/false);
1088 }
1089 
1090 #include "TestOpEnums.cpp.inc"
1091 #include "TestOpInterfaces.cpp.inc"
1092 #include "TestOpStructs.cpp.inc"
1093 #include "TestTypeInterfaces.cpp.inc"
1094 
1095 #define GET_OP_CLASSES
1096 #include "TestOps.cpp.inc"
1097