1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/DLTI/DLTI.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Reducer/ReductionPatternInterface.h"
21 #include "mlir/Transforms/FoldUtils.h"
22 #include "mlir/Transforms/InliningUtils.h"
23 #include "llvm/ADT/StringSwitch.h"
24 
25 // Include this before the using namespace lines below to
26 // test that we don't have namespace dependencies.
27 #include "TestOpsDialect.cpp.inc"
28 
29 using namespace mlir;
30 using namespace test;
31 
32 void test::registerTestDialect(DialectRegistry &registry) {
33   registry.insert<TestDialect>();
34 }
35 
36 //===----------------------------------------------------------------------===//
37 // TestDialect Interfaces
38 //===----------------------------------------------------------------------===//
39 
40 namespace {
41 
/// Compile-time checks that the implicit-terminator trait helpers
/// (`has_implicit_terminator_t` and `hasSingleBlockImplicitTerminator`) both
/// recognize SingleBlockImplicitTerminatorOp.
static_assert(
    llvm::is_detected<OpTrait::has_implicit_terminator_t,
                      SingleBlockImplicitTerminatorOp>::value,
    "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
static_assert(OpTrait::hasSingleBlockImplicitTerminator<
                  SingleBlockImplicitTerminatorOp>::value,
              "hasSingleBlockImplicitTerminator does not match "
              "SingleBlockImplicitTerminatorOp");
51 
52 // Test support for interacting with the AsmPrinter.
53 struct TestOpAsmInterface : public OpAsmDialectInterface {
54   using OpAsmDialectInterface::OpAsmDialectInterface;
55 
56   AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
57     StringAttr strAttr = attr.dyn_cast<StringAttr>();
58     if (!strAttr)
59       return AliasResult::NoAlias;
60 
61     // Check the contents of the string attribute to see what the test alias
62     // should be named.
63     Optional<StringRef> aliasName =
64         StringSwitch<Optional<StringRef>>(strAttr.getValue())
65             .Case("alias_test:dot_in_name", StringRef("test.alias"))
66             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
67             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
68             .Case("alias_test:sanitize_conflict_a",
69                   StringRef("test_alias_conflict0"))
70             .Case("alias_test:sanitize_conflict_b",
71                   StringRef("test_alias_conflict0_"))
72             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
73             .Default(llvm::None);
74     if (!aliasName)
75       return AliasResult::NoAlias;
76 
77     os << *aliasName;
78     return AliasResult::FinalAlias;
79   }
80 
81   AliasResult getAlias(Type type, raw_ostream &os) const final {
82     if (auto tupleType = type.dyn_cast<TupleType>()) {
83       if (tupleType.size() > 0 &&
84           llvm::all_of(tupleType.getTypes(), [](Type elemType) {
85             return elemType.isa<SimpleAType>();
86           })) {
87         os << "test_tuple";
88         return AliasResult::FinalAlias;
89       }
90     }
91     return AliasResult::NoAlias;
92   }
93 
94   void getAsmResultNames(Operation *op,
95                          OpAsmSetValueNameFn setNameFn) const final {
96     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
97       setNameFn(asmOp, "result");
98   }
99 
100   void getAsmBlockArgumentNames(Block *block,
101                                 OpAsmSetValueNameFn setNameFn) const final {
102     auto op = block->getParentOp();
103     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
104     if (!arrayAttr)
105       return;
106     auto args = block->getArguments();
107     auto e = std::min(arrayAttr.size(), args.size());
108     for (unsigned i = 0; i < e; ++i) {
109       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
110         setNameFn(args[i], strAttr.getValue());
111     }
112   }
113 };
114 
115 struct TestDialectFoldInterface : public DialectFoldInterface {
116   using DialectFoldInterface::DialectFoldInterface;
117 
118   /// Registered hook to check if the given region, which is attached to an
119   /// operation that is *not* isolated from above, should be used when
120   /// materializing constants.
121   bool shouldMaterializeInto(Region *region) const final {
122     // If this is a one region operation, then insert into it.
123     return isa<OneRegionOp>(region->getParentOp());
124   }
125 };
126 
127 /// This class defines the interface for handling inlining with standard
128 /// operations.
129 struct TestInlinerInterface : public DialectInlinerInterface {
130   using DialectInlinerInterface::DialectInlinerInterface;
131 
132   //===--------------------------------------------------------------------===//
133   // Analysis Hooks
134   //===--------------------------------------------------------------------===//
135 
136   bool isLegalToInline(Operation *call, Operation *callable,
137                        bool wouldBeCloned) const final {
138     // Don't allow inlining calls that are marked `noinline`.
139     return !call->hasAttr("noinline");
140   }
141   bool isLegalToInline(Region *, Region *, bool,
142                        BlockAndValueMapping &) const final {
143     // Inlining into test dialect regions is legal.
144     return true;
145   }
146   bool isLegalToInline(Operation *, Region *, bool,
147                        BlockAndValueMapping &) const final {
148     return true;
149   }
150 
151   bool shouldAnalyzeRecursively(Operation *op) const final {
152     // Analyze recursively if this is not a functional region operation, it
153     // froms a separate functional scope.
154     return !isa<FunctionalRegionOp>(op);
155   }
156 
157   //===--------------------------------------------------------------------===//
158   // Transformation Hooks
159   //===--------------------------------------------------------------------===//
160 
161   /// Handle the given inlined terminator by replacing it with a new operation
162   /// as necessary.
163   void handleTerminator(Operation *op,
164                         ArrayRef<Value> valuesToRepl) const final {
165     // Only handle "test.return" here.
166     auto returnOp = dyn_cast<TestReturnOp>(op);
167     if (!returnOp)
168       return;
169 
170     // Replace the values directly with the return operands.
171     assert(returnOp.getNumOperands() == valuesToRepl.size());
172     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
173       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
174   }
175 
176   /// Attempt to materialize a conversion for a type mismatch between a call
177   /// from this dialect, and a callable region. This method should generate an
178   /// operation that takes 'input' as the only operand, and produces a single
179   /// result of 'resultType'. If a conversion can not be generated, nullptr
180   /// should be returned.
181   Operation *materializeCallConversion(OpBuilder &builder, Value input,
182                                        Type resultType,
183                                        Location conversionLoc) const final {
184     // Only allow conversion for i16/i32 types.
185     if (!(resultType.isSignlessInteger(16) ||
186           resultType.isSignlessInteger(32)) ||
187         !(input.getType().isSignlessInteger(16) ||
188           input.getType().isSignlessInteger(32)))
189       return nullptr;
190     return builder.create<TestCastOp>(conversionLoc, resultType, input);
191   }
192 
193   void processInlinedCallBlocks(
194       Operation *call,
195       iterator_range<Region::iterator> inlinedBlocks) const final {
196     if (!isa<ConversionCallOp>(call))
197       return;
198 
199     // Set attributed on all ops in the inlined blocks.
200     for (Block &block : inlinedBlocks) {
201       block.walk([&](Operation *op) {
202         op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
203       });
204     }
205   }
206 };
207 
208 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
209 public:
210   TestReductionPatternInterface(Dialect *dialect)
211       : DialectReductionPatternInterface(dialect) {}
212 
213   void populateReductionPatterns(RewritePatternSet &patterns) const final {
214     populateTestReductionPatterns(patterns);
215   }
216 };
217 
218 } // end anonymous namespace
219 
220 //===----------------------------------------------------------------------===//
221 // TestDialect
222 //===----------------------------------------------------------------------===//
223 
224 static void testSideEffectOpGetEffect(
225     Operation *op,
226     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
227 
228 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
229 struct TestOpEffectInterfaceFallback
230     : public TestEffectOpInterface::FallbackModel<
231           TestOpEffectInterfaceFallback> {
232   static bool classof(Operation *op) {
233     bool isSupportedOp =
234         op->getName().getStringRef() == "test.unregistered_side_effect_op";
235     assert(isSupportedOp && "Unexpected dispatch");
236     return isSupportedOp;
237   }
238 
239   void
240   getEffects(Operation *op,
241              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
242                  &effects) const {
243     testSideEffectOpGetEffect(op, effects);
244   }
245 };
246 
/// Register the test dialect's attributes, types, operations and dialect
/// interfaces, and create the fallback effect-interface model.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  // The operation list comes from the generated TestOps.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface, TestReductionPatternInterface>();
  // Accept operations in this dialect's namespace that are not registered
  // (e.g. `test.unregistered_side_effect_op`, served by the fallback below).
  allowUnknownOperations();

  // Instantiate our fallback op interface that we'll use on specific
  // unregistered op. It is heap-allocated here and deleted in ~TestDialect.
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
262 TestDialect::~TestDialect() {
263   delete static_cast<TestOpEffectInterfaceFallback *>(
264       fallbackEffectOpInterfaces);
265 }
266 
267 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
268                                             Type type, Location loc) {
269   return builder.create<TestOpConstant>(loc, type, value);
270 }
271 
272 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
273                                                OperationName opName) {
274   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
275       typeID == TypeID::get<TestEffectOpInterface>())
276     return fallbackEffectOpInterfaces;
277   return nullptr;
278 }
279 
280 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
281                                                     NamedAttribute namedAttr) {
282   if (namedAttr.first == "test.invalid_attr")
283     return op->emitError() << "invalid to use 'test.invalid_attr'";
284   return success();
285 }
286 
287 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
288                                                     unsigned regionIndex,
289                                                     unsigned argIndex,
290                                                     NamedAttribute namedAttr) {
291   if (namedAttr.first == "test.invalid_attr")
292     return op->emitError() << "invalid to use 'test.invalid_attr'";
293   return success();
294 }
295 
296 LogicalResult
297 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
298                                          unsigned resultIndex,
299                                          NamedAttribute namedAttr) {
300   if (namedAttr.first == "test.invalid_attr")
301     return op->emitError() << "invalid to use 'test.invalid_attr'";
302   return success();
303 }
304 
305 Optional<Dialect::ParseOpHook>
306 TestDialect::getParseOperationHook(StringRef opName) const {
307   if (opName == "test.dialect_custom_printer") {
308     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
309       return parser.parseKeyword("custom_format");
310     }};
311   }
312   return None;
313 }
314 
315 llvm::unique_function<void(Operation *, OpAsmPrinter &)>
316 TestDialect::getOperationPrinter(Operation *op) const {
317   StringRef opName = op->getName().getStringRef();
318   if (opName == "test.dialect_custom_printer") {
319     return [](Operation *op, OpAsmPrinter &printer) {
320       printer.getStream() << " custom_format";
321     };
322   }
323   return {};
324 }
325 
326 //===----------------------------------------------------------------------===//
327 // TestBranchOp
328 //===----------------------------------------------------------------------===//
329 
330 Optional<MutableOperandRange>
331 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
332   assert(index == 0 && "invalid successor index");
333   return targetOperandsMutable();
334 }
335 
336 //===----------------------------------------------------------------------===//
337 // TestDialectCanonicalizerOp
338 //===----------------------------------------------------------------------===//
339 
340 static LogicalResult
341 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
342                                PatternRewriter &rewriter) {
343   rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(),
344                                           rewriter.getI32IntegerAttr(42));
345   return success();
346 }
347 
348 void TestDialect::getCanonicalizationPatterns(
349     RewritePatternSet &results) const {
350   results.add(&dialectCanonicalizationPattern);
351 }
352 
353 //===----------------------------------------------------------------------===//
354 // TestFoldToCallOp
355 //===----------------------------------------------------------------------===//
356 
357 namespace {
358 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
359   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
360 
361   LogicalResult matchAndRewrite(FoldToCallOp op,
362                                 PatternRewriter &rewriter) const override {
363     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
364                                         ValueRange());
365     return success();
366   }
367 };
368 } // end anonymous namespace
369 
370 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
371                                                MLIRContext *context) {
372   results.add<FoldToCallOpPattern>(context);
373 }
374 
375 //===----------------------------------------------------------------------===//
376 // Test Format* operations
377 //===----------------------------------------------------------------------===//
378 
379 //===----------------------------------------------------------------------===//
380 // Parsing
381 
382 static ParseResult parseCustomDirectiveOperands(
383     OpAsmParser &parser, OpAsmParser::OperandType &operand,
384     Optional<OpAsmParser::OperandType> &optOperand,
385     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
386   if (parser.parseOperand(operand))
387     return failure();
388   if (succeeded(parser.parseOptionalComma())) {
389     optOperand.emplace();
390     if (parser.parseOperand(*optOperand))
391       return failure();
392   }
393   if (parser.parseArrow() || parser.parseLParen() ||
394       parser.parseOperandList(varOperands) || parser.parseRParen())
395     return failure();
396   return success();
397 }
398 static ParseResult
399 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
400                             Type &optOperandType,
401                             SmallVectorImpl<Type> &varOperandTypes) {
402   if (parser.parseColon())
403     return failure();
404 
405   if (parser.parseType(operandType))
406     return failure();
407   if (succeeded(parser.parseOptionalComma())) {
408     if (parser.parseType(optOperandType))
409       return failure();
410   }
411   if (parser.parseArrow() || parser.parseLParen() ||
412       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
413     return failure();
414   return success();
415 }
416 static ParseResult
417 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
418                                  Type optOperandType,
419                                  const SmallVectorImpl<Type> &varOperandTypes) {
420   if (parser.parseKeyword("type_refs_capture"))
421     return failure();
422 
423   Type operandType2, optOperandType2;
424   SmallVector<Type, 1> varOperandTypes2;
425   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
426                                   varOperandTypes2))
427     return failure();
428 
429   if (operandType != operandType2 || optOperandType != optOperandType2 ||
430       varOperandTypes != varOperandTypes2)
431     return failure();
432 
433   return success();
434 }
435 static ParseResult parseCustomDirectiveOperandsAndTypes(
436     OpAsmParser &parser, OpAsmParser::OperandType &operand,
437     Optional<OpAsmParser::OperandType> &optOperand,
438     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
439     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
440   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
441       parseCustomDirectiveResults(parser, operandType, optOperandType,
442                                   varOperandTypes))
443     return failure();
444   return success();
445 }
446 static ParseResult parseCustomDirectiveRegions(
447     OpAsmParser &parser, Region &region,
448     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
449   if (parser.parseRegion(region))
450     return failure();
451   if (failed(parser.parseOptionalComma()))
452     return success();
453   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
454   if (parser.parseRegion(*varRegion))
455     return failure();
456   varRegions.emplace_back(std::move(varRegion));
457   return success();
458 }
459 static ParseResult
460 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
461                                SmallVectorImpl<Block *> &varSuccessors) {
462   if (parser.parseSuccessor(successor))
463     return failure();
464   if (failed(parser.parseOptionalComma()))
465     return success();
466   Block *varSuccessor;
467   if (parser.parseSuccessor(varSuccessor))
468     return failure();
469   varSuccessors.append(2, varSuccessor);
470   return success();
471 }
472 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
473                                                   IntegerAttr &attr,
474                                                   IntegerAttr &optAttr) {
475   if (parser.parseAttribute(attr))
476     return failure();
477   if (succeeded(parser.parseOptionalComma())) {
478     if (parser.parseAttribute(optAttr))
479       return failure();
480   }
481   return success();
482 }
483 
484 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
485                                                 NamedAttrList &attrs) {
486   return parser.parseOptionalAttrDict(attrs);
487 }
488 static ParseResult parseCustomDirectiveOptionalOperandRef(
489     OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
490   int64_t operandCount = 0;
491   if (parser.parseInteger(operandCount))
492     return failure();
493   bool expectedOptionalOperand = operandCount == 0;
494   return success(expectedOptionalOperand != optOperand.hasValue());
495 }
496 
497 //===----------------------------------------------------------------------===//
498 // Printing
499 
500 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
501                                          Value operand, Value optOperand,
502                                          OperandRange varOperands) {
503   printer << operand;
504   if (optOperand)
505     printer << ", " << optOperand;
506   printer << " -> (" << varOperands << ")";
507 }
508 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
509                                         Type operandType, Type optOperandType,
510                                         TypeRange varOperandTypes) {
511   printer << " : " << operandType;
512   if (optOperandType)
513     printer << ", " << optOperandType;
514   printer << " -> (" << varOperandTypes << ")";
515 }
516 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
517                                              Operation *op, Type operandType,
518                                              Type optOperandType,
519                                              TypeRange varOperandTypes) {
520   printer << " type_refs_capture ";
521   printCustomDirectiveResults(printer, op, operandType, optOperandType,
522                               varOperandTypes);
523 }
524 static void printCustomDirectiveOperandsAndTypes(
525     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
526     OperandRange varOperands, Type operandType, Type optOperandType,
527     TypeRange varOperandTypes) {
528   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
529   printCustomDirectiveResults(printer, op, operandType, optOperandType,
530                               varOperandTypes);
531 }
532 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
533                                         Region &region,
534                                         MutableArrayRef<Region> varRegions) {
535   printer.printRegion(region);
536   if (!varRegions.empty()) {
537     printer << ", ";
538     for (Region &region : varRegions)
539       printer.printRegion(region);
540   }
541 }
542 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
543                                            Block *successor,
544                                            SuccessorRange varSuccessors) {
545   printer << successor;
546   if (!varSuccessors.empty())
547     printer << ", " << varSuccessors.front();
548 }
549 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
550                                            Attribute attribute,
551                                            Attribute optAttribute) {
552   printer << attribute;
553   if (optAttribute)
554     printer << ", " << optAttribute;
555 }
556 
557 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
558                                          DictionaryAttr attrs) {
559   printer.printOptionalAttrDict(attrs.getValue());
560 }
561 
562 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
563                                                    Operation *op,
564                                                    Value optOperand) {
565   printer << (optOperand ? "1" : "0");
566 }
567 
568 //===----------------------------------------------------------------------===//
569 // Test IsolatedRegionOp - parse passthrough region arguments.
570 //===----------------------------------------------------------------------===//
571 
572 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
573                                          OperationState &result) {
574   OpAsmParser::OperandType argInfo;
575   Type argType = parser.getBuilder().getIndexType();
576 
577   // Parse the input operand.
578   if (parser.parseOperand(argInfo) ||
579       parser.resolveOperand(argInfo, argType, result.operands))
580     return failure();
581 
582   // Parse the body region, and reuse the operand info as the argument info.
583   Region *body = result.addRegion();
584   return parser.parseRegion(*body, argInfo, argType,
585                             /*enableNameShadowing=*/true);
586 }
587 
588 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
589   p << "test.isolated_region ";
590   p.printOperand(op.getOperand());
591   p.shadowRegionArgs(op.region(), op.getOperand());
592   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
593 }
594 
595 //===----------------------------------------------------------------------===//
596 // Test SSACFGRegionOp
597 //===----------------------------------------------------------------------===//
598 
599 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
600   return RegionKind::SSACFG;
601 }
602 
603 //===----------------------------------------------------------------------===//
604 // Test GraphRegionOp
605 //===----------------------------------------------------------------------===//
606 
607 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
608                                       OperationState &result) {
609   // Parse the body region, and reuse the operand info as the argument info.
610   Region *body = result.addRegion();
611   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
612 }
613 
614 static void print(OpAsmPrinter &p, GraphRegionOp op) {
615   p << "test.graph_region ";
616   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
617 }
618 
619 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
620   return RegionKind::Graph;
621 }
622 
623 //===----------------------------------------------------------------------===//
624 // Test AffineScopeOp
625 //===----------------------------------------------------------------------===//
626 
627 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
628                                       OperationState &result) {
629   // Parse the body region, and reuse the operand info as the argument info.
630   Region *body = result.addRegion();
631   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
632 }
633 
634 static void print(OpAsmPrinter &p, AffineScopeOp op) {
635   p << "test.affine_scope ";
636   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
637 }
638 
639 //===----------------------------------------------------------------------===//
640 // Test parser.
641 //===----------------------------------------------------------------------===//
642 
643 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
644                                               OperationState &result) {
645   if (parser.parseOptionalColon())
646     return success();
647   uint64_t numResults;
648   if (parser.parseInteger(numResults))
649     return failure();
650 
651   IndexType type = parser.getBuilder().getIndexType();
652   for (unsigned i = 0; i < numResults; ++i)
653     result.addTypes(type);
654   return success();
655 }
656 
657 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
658   if (unsigned numResults = op->getNumResults())
659     p << " : " << numResults;
660 }
661 
662 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
663                                               OperationState &result) {
664   StringRef keyword;
665   if (parser.parseKeyword(&keyword))
666     return failure();
667   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
668   return success();
669 }
670 
671 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
672   p << " " << op.keyword();
673 }
674 
675 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
677 
678 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
679                                          OperationState &result) {
680   if (parser.parseKeyword("wraps"))
681     return failure();
682 
683   // Parse the wrapped op in a region
684   Region &body = *result.addRegion();
685   body.push_back(new Block);
686   Block &block = body.back();
687   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
688   if (!wrapped_op)
689     return failure();
690 
691   // Create a return terminator in the inner region, pass as operand to the
692   // terminator the returned values from the wrapped operation.
693   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
694   OpBuilder builder(parser.getContext());
695   builder.setInsertionPointToEnd(&block);
696   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
697 
698   // Get the results type for the wrapping op from the terminator operands.
699   Operation &return_op = body.back().back();
700   result.types.append(return_op.operand_type_begin(),
701                       return_op.operand_type_end());
702 
703   // Use the location of the wrapped op for the "test.wrapping_region" op.
704   result.location = wrapped_op->getLoc();
705 
706   return success();
707 }
708 
709 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
710   p << " wraps ";
711   p.printGenericOp(&op.region().front().front());
712 }
713 
714 //===----------------------------------------------------------------------===//
715 // Test PolyForOp - parse list of region arguments.
716 //===----------------------------------------------------------------------===//
717 
718 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
719   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
720   // Parse list of region arguments without a delimiter.
721   if (parser.parseRegionArgumentList(ivsInfo))
722     return failure();
723 
724   // Parse the body region.
725   Region *body = result.addRegion();
726   auto &builder = parser.getBuilder();
727   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
728   return parser.parseRegion(*body, ivsInfo, argTypes);
729 }
730 
731 //===----------------------------------------------------------------------===//
732 // Test removing op with inner ops.
733 //===----------------------------------------------------------------------===//
734 
735 namespace {
736 struct TestRemoveOpWithInnerOps
737     : public OpRewritePattern<TestOpWithRegionPattern> {
738   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
739 
740   void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
741 
742   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
743                                 PatternRewriter &rewriter) const override {
744     rewriter.eraseOp(op);
745     return success();
746   }
747 };
748 } // end anonymous namespace
749 
750 void TestOpWithRegionPattern::getCanonicalizationPatterns(
751     RewritePatternSet &results, MLIRContext *context) {
752   results.add<TestRemoveOpWithInnerOps>(context);
753 }
754 
755 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
756   return operand();
757 }
758 
759 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
760   return getValue();
761 }
762 
763 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
764     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
765   for (Value input : this->operands()) {
766     results.push_back(input);
767   }
768   return success();
769 }
770 
771 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
772   assert(operands.size() == 1);
773   if (operands.front()) {
774     (*this)->setAttr("attr", operands.front());
775     return getResult();
776   }
777   return {};
778 }
779 
780 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
781   return getOperand();
782 }
783 
784 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
785     MLIRContext *, Optional<Location> location, ValueRange operands,
786     DictionaryAttr attributes, RegionRange regions,
787     SmallVectorImpl<Type> &inferredReturnTypes) {
788   if (operands[0].getType() != operands[1].getType()) {
789     return emitOptionalError(location, "operand type mismatch ",
790                              operands[0].getType(), " vs ",
791                              operands[1].getType());
792   }
793   inferredReturnTypes.assign({operands[0].getType()});
794   return success();
795 }
796 
797 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
798     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
799     DictionaryAttr attributes, RegionRange regions,
800     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
801   // Create return type consisting of the last element of the first operand.
802   auto operandType = operands.front().getType();
803   auto sval = operandType.dyn_cast<ShapedType>();
804   if (!sval) {
805     return emitOptionalError(location, "only shaped type operands allowed");
806   }
807   int64_t dim =
808       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
809   auto type = IntegerType::get(context, 17);
810   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
811   return success();
812 }
813 
814 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
815     OpBuilder &builder, ValueRange operands,
816     llvm::SmallVectorImpl<Value> &shapes) {
817   shapes = SmallVector<Value, 1>{
818       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
819   return success();
820 }
821 
822 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
823     OpBuilder &builder, ValueRange operands,
824     llvm::SmallVectorImpl<Value> &shapes) {
825   Location loc = getLoc();
826   shapes.reserve(operands.size());
827   for (Value operand : llvm::reverse(operands)) {
828     auto currShape = llvm::to_vector<4>(llvm::map_range(
829         llvm::seq<int64_t>(
830             0, operand.getType().cast<RankedTensorType>().getRank()),
831         [&](int64_t dim) -> Value {
832           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
833         }));
834     shapes.push_back(builder.create<tensor::FromElementsOp>(
835         getLoc(), builder.getIndexType(), currShape));
836   }
837   return success();
838 }
839 
840 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
841     OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
842   Location loc = getLoc();
843   shapes.reserve(getNumOperands());
844   for (Value operand : llvm::reverse(getOperands())) {
845     auto currShape = llvm::to_vector<4>(llvm::map_range(
846         llvm::seq<int64_t>(
847             0, operand.getType().cast<RankedTensorType>().getRank()),
848         [&](int64_t dim) -> Value {
849           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
850         }));
851     shapes.emplace_back(std::move(currShape));
852   }
853   return success();
854 }
855 
856 //===----------------------------------------------------------------------===//
857 // Test SideEffect interfaces
858 //===----------------------------------------------------------------------===//
859 
860 namespace {
861 /// A test resource for side effects.
862 struct TestResource : public SideEffects::Resource::Base<TestResource> {
863   StringRef getName() final { return "<Test>"; }
864 };
865 } // end anonymous namespace
866 
867 static void testSideEffectOpGetEffect(
868     Operation *op,
869     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
870         &effects) {
871   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
872   if (!effectsAttr)
873     return;
874 
875   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
876 }
877 
878 void SideEffectOp::getEffects(
879     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
880   // Check for an effects attribute on the op instance.
881   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
882   if (!effectsAttr)
883     return;
884 
885   // If there is one, it is an array of dictionary attributes that hold
886   // information on the effects of this operation.
887   for (Attribute element : effectsAttr) {
888     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
889 
890     // Get the specific memory effect.
891     MemoryEffects::Effect *effect =
892         StringSwitch<MemoryEffects::Effect *>(
893             effectElement.get("effect").cast<StringAttr>().getValue())
894             .Case("allocate", MemoryEffects::Allocate::get())
895             .Case("free", MemoryEffects::Free::get())
896             .Case("read", MemoryEffects::Read::get())
897             .Case("write", MemoryEffects::Write::get());
898 
899     // Check for a non-default resource to use.
900     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
901     if (effectElement.get("test_resource"))
902       resource = TestResource::get();
903 
904     // Check for a result to affect.
905     if (effectElement.get("on_result"))
906       effects.emplace_back(effect, getResult(), resource);
907     else if (Attribute ref = effectElement.get("on_reference"))
908       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
909     else
910       effects.emplace_back(effect, resource);
911   }
912 }
913 
914 void SideEffectOp::getEffects(
915     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
916   testSideEffectOpGetEffect(getOperation(), effects);
917 }
918 
919 //===----------------------------------------------------------------------===//
920 // StringAttrPrettyNameOp
921 //===----------------------------------------------------------------------===//
922 
923 // This op has fancy handling of its SSA result name.
924 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
925                                                OperationState &result) {
926   // Add the result types.
927   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
928     result.addTypes(parser.getBuilder().getIntegerType(32));
929 
930   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
931     return failure();
932 
933   // If the attribute dictionary contains no 'names' attribute, infer it from
934   // the SSA name (if specified).
935   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
936     return attr.first == "names";
937   });
938 
939   // If there was no name specified, check to see if there was a useful name
940   // specified in the asm file.
941   if (hadNames || parser.getNumResults() == 0)
942     return success();
943 
944   SmallVector<StringRef, 4> names;
945   auto *context = result.getContext();
946 
947   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
948     auto resultName = parser.getResultName(i);
949     StringRef nameStr;
950     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
951       nameStr = resultName.first;
952 
953     names.push_back(nameStr);
954   }
955 
956   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
957   result.attributes.push_back({Identifier::get("names", context), namesAttr});
958   return success();
959 }
960 
961 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
962   // Note that we only need to print the "name" attribute if the asmprinter
963   // result name disagrees with it.  This can happen in strange cases, e.g.
964   // when there are conflicts.
965   bool namesDisagree = op.names().size() != op.getNumResults();
966 
967   SmallString<32> resultNameStr;
968   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
969     resultNameStr.clear();
970     llvm::raw_svector_ostream tmpStream(resultNameStr);
971     p.printOperand(op.getResult(i), tmpStream);
972 
973     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
974     if (!expectedName ||
975         tmpStream.str().drop_front() != expectedName.getValue()) {
976       namesDisagree = true;
977     }
978   }
979 
980   if (namesDisagree)
981     p.printOptionalAttrDictWithKeyword(op->getAttrs());
982   else
983     p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
984 }
985 
986 // We set the SSA name in the asm syntax to the contents of the name
987 // attribute.
988 void StringAttrPrettyNameOp::getAsmResultNames(
989     function_ref<void(Value, StringRef)> setNameFn) {
990 
991   auto value = names();
992   for (size_t i = 0, e = value.size(); i != e; ++i)
993     if (auto str = value[i].dyn_cast<StringAttr>())
994       if (!str.getValue().empty())
995         setNameFn(getResult(i), str.getValue());
996 }
997 
998 //===----------------------------------------------------------------------===//
999 // RegionIfOp
1000 //===----------------------------------------------------------------------===//
1001 
1002 static void print(OpAsmPrinter &p, RegionIfOp op) {
1003   p << " ";
1004   p.printOperands(op.getOperands());
1005   p << ": " << op.getOperandTypes();
1006   p.printArrowTypeList(op.getResultTypes());
1007   p << " then";
1008   p.printRegion(op.thenRegion(),
1009                 /*printEntryBlockArgs=*/true,
1010                 /*printBlockTerminators=*/true);
1011   p << " else";
1012   p.printRegion(op.elseRegion(),
1013                 /*printEntryBlockArgs=*/true,
1014                 /*printBlockTerminators=*/true);
1015   p << " join";
1016   p.printRegion(op.joinRegion(),
1017                 /*printEntryBlockArgs=*/true,
1018                 /*printBlockTerminators=*/true);
1019 }
1020 
1021 static ParseResult parseRegionIfOp(OpAsmParser &parser,
1022                                    OperationState &result) {
1023   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
1024   SmallVector<Type, 2> operandTypes;
1025 
1026   result.regions.reserve(3);
1027   Region *thenRegion = result.addRegion();
1028   Region *elseRegion = result.addRegion();
1029   Region *joinRegion = result.addRegion();
1030 
1031   // Parse operand, type and arrow type lists.
1032   if (parser.parseOperandList(operandInfos) ||
1033       parser.parseColonTypeList(operandTypes) ||
1034       parser.parseArrowTypeList(result.types))
1035     return failure();
1036 
1037   // Parse all attached regions.
1038   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1039       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1040       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1041     return failure();
1042 
1043   return parser.resolveOperands(operandInfos, operandTypes,
1044                                 parser.getCurrentLocation(), result.operands);
1045 }
1046 
1047 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
1048   assert(index < 2 && "invalid region index");
1049   return getOperands();
1050 }
1051 
1052 void RegionIfOp::getSuccessorRegions(
1053     Optional<unsigned> index, ArrayRef<Attribute> operands,
1054     SmallVectorImpl<RegionSuccessor> &regions) {
1055   // We always branch to the join region.
1056   if (index.hasValue()) {
1057     if (index.getValue() < 2)
1058       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
1059     else
1060       regions.push_back(RegionSuccessor(getResults()));
1061     return;
1062   }
1063 
1064   // The then and else regions are the entry regions of this op.
1065   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
1066   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
1067 }
1068 
1069 //===----------------------------------------------------------------------===//
1070 // SingleNoTerminatorCustomAsmOp
1071 //===----------------------------------------------------------------------===//
1072 
1073 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser,
1074                                                       OperationState &state) {
1075   Region *body = state.addRegion();
1076   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1077     return failure();
1078   return success();
1079 }
1080 
1081 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) {
1082   printer.printRegion(
1083       op.getRegion(), /*printEntryBlockArgs=*/false,
1084       // This op has a single block without terminators. But explicitly mark
1085       // as not printing block terminators for testing.
1086       /*printBlockTerminators=*/false);
1087 }
1088 
1089 #include "TestOpEnums.cpp.inc"
1090 #include "TestOpInterfaces.cpp.inc"
1091 #include "TestOpStructs.cpp.inc"
1092 #include "TestTypeInterfaces.cpp.inc"
1093 
1094 #define GET_OP_CLASSES
1095 #include "TestOps.cpp.inc"
1096