1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/DLTI/DLTI.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/Reducer/ReductionPatternInterface.h"
22 #include "mlir/Transforms/FoldUtils.h"
23 #include "mlir/Transforms/InliningUtils.h"
24 #include "llvm/ADT/StringSwitch.h"
25 
26 using namespace mlir;
27 using namespace mlir::test;
28 
29 void mlir::test::registerTestDialect(DialectRegistry &registry) {
30   registry.insert<TestDialect>();
31 }
32 
33 //===----------------------------------------------------------------------===//
34 // TestDialect Interfaces
35 //===----------------------------------------------------------------------===//
36 
37 namespace {
38 
39 /// Testing the correctness of some traits.
40 static_assert(
41     llvm::is_detected<OpTrait::has_implicit_terminator_t,
42                       SingleBlockImplicitTerminatorOp>::value,
43     "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
44 static_assert(OpTrait::hasSingleBlockImplicitTerminator<
45                   SingleBlockImplicitTerminatorOp>::value,
46               "hasSingleBlockImplicitTerminator does not match "
47               "SingleBlockImplicitTerminatorOp");
48 
49 // Test support for interacting with the AsmPrinter.
50 struct TestOpAsmInterface : public OpAsmDialectInterface {
51   using OpAsmDialectInterface::OpAsmDialectInterface;
52 
53   LogicalResult getAlias(Attribute attr, raw_ostream &os) const final {
54     StringAttr strAttr = attr.dyn_cast<StringAttr>();
55     if (!strAttr)
56       return failure();
57 
58     // Check the contents of the string attribute to see what the test alias
59     // should be named.
60     Optional<StringRef> aliasName =
61         StringSwitch<Optional<StringRef>>(strAttr.getValue())
62             .Case("alias_test:dot_in_name", StringRef("test.alias"))
63             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
64             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
65             .Case("alias_test:sanitize_conflict_a",
66                   StringRef("test_alias_conflict0"))
67             .Case("alias_test:sanitize_conflict_b",
68                   StringRef("test_alias_conflict0_"))
69             .Default(llvm::None);
70     if (!aliasName)
71       return failure();
72 
73     os << *aliasName;
74     return success();
75   }
76 
77   void getAsmResultNames(Operation *op,
78                          OpAsmSetValueNameFn setNameFn) const final {
79     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
80       setNameFn(asmOp, "result");
81   }
82 
83   void getAsmBlockArgumentNames(Block *block,
84                                 OpAsmSetValueNameFn setNameFn) const final {
85     auto op = block->getParentOp();
86     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
87     if (!arrayAttr)
88       return;
89     auto args = block->getArguments();
90     auto e = std::min(arrayAttr.size(), args.size());
91     for (unsigned i = 0; i < e; ++i) {
92       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
93         setNameFn(args[i], strAttr.getValue());
94     }
95   }
96 };
97 
98 struct TestDialectFoldInterface : public DialectFoldInterface {
99   using DialectFoldInterface::DialectFoldInterface;
100 
101   /// Registered hook to check if the given region, which is attached to an
102   /// operation that is *not* isolated from above, should be used when
103   /// materializing constants.
104   bool shouldMaterializeInto(Region *region) const final {
105     // If this is a one region operation, then insert into it.
106     return isa<OneRegionOp>(region->getParentOp());
107   }
108 };
109 
110 /// This class defines the interface for handling inlining with standard
111 /// operations.
112 struct TestInlinerInterface : public DialectInlinerInterface {
113   using DialectInlinerInterface::DialectInlinerInterface;
114 
115   //===--------------------------------------------------------------------===//
116   // Analysis Hooks
117   //===--------------------------------------------------------------------===//
118 
119   bool isLegalToInline(Operation *call, Operation *callable,
120                        bool wouldBeCloned) const final {
121     // Don't allow inlining calls that are marked `noinline`.
122     return !call->hasAttr("noinline");
123   }
124   bool isLegalToInline(Region *, Region *, bool,
125                        BlockAndValueMapping &) const final {
126     // Inlining into test dialect regions is legal.
127     return true;
128   }
129   bool isLegalToInline(Operation *, Region *, bool,
130                        BlockAndValueMapping &) const final {
131     return true;
132   }
133 
134   bool shouldAnalyzeRecursively(Operation *op) const final {
135     // Analyze recursively if this is not a functional region operation, it
136     // froms a separate functional scope.
137     return !isa<FunctionalRegionOp>(op);
138   }
139 
140   //===--------------------------------------------------------------------===//
141   // Transformation Hooks
142   //===--------------------------------------------------------------------===//
143 
144   /// Handle the given inlined terminator by replacing it with a new operation
145   /// as necessary.
146   void handleTerminator(Operation *op,
147                         ArrayRef<Value> valuesToRepl) const final {
148     // Only handle "test.return" here.
149     auto returnOp = dyn_cast<TestReturnOp>(op);
150     if (!returnOp)
151       return;
152 
153     // Replace the values directly with the return operands.
154     assert(returnOp.getNumOperands() == valuesToRepl.size());
155     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
156       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
157   }
158 
159   /// Attempt to materialize a conversion for a type mismatch between a call
160   /// from this dialect, and a callable region. This method should generate an
161   /// operation that takes 'input' as the only operand, and produces a single
162   /// result of 'resultType'. If a conversion can not be generated, nullptr
163   /// should be returned.
164   Operation *materializeCallConversion(OpBuilder &builder, Value input,
165                                        Type resultType,
166                                        Location conversionLoc) const final {
167     // Only allow conversion for i16/i32 types.
168     if (!(resultType.isSignlessInteger(16) ||
169           resultType.isSignlessInteger(32)) ||
170         !(input.getType().isSignlessInteger(16) ||
171           input.getType().isSignlessInteger(32)))
172       return nullptr;
173     return builder.create<TestCastOp>(conversionLoc, resultType, input);
174   }
175 
176   void processInlinedCallBlocks(
177       Operation *call,
178       iterator_range<Region::iterator> inlinedBlocks) const final {
179     if (!isa<ConversionCallOp>(call))
180       return;
181 
182     // Set attributed on all ops in the inlined blocks.
183     for (Block &block : inlinedBlocks) {
184       block.walk([&](Operation *op) {
185         op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
186       });
187     }
188   }
189 };
190 
191 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
192 public:
193   TestReductionPatternInterface(Dialect *dialect)
194       : DialectReductionPatternInterface(dialect) {}
195 
196   virtual void
197   populateReductionPatterns(RewritePatternSet &patterns) const final {
198     populateTestReductionPatterns(patterns);
199   }
200 };
201 
202 } // end anonymous namespace
203 
204 //===----------------------------------------------------------------------===//
205 // TestDialect
206 //===----------------------------------------------------------------------===//
207 
208 static void testSideEffectOpGetEffect(
209     Operation *op,
210     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
211 
212 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
213 struct TestOpEffectInterfaceFallback
214     : public TestEffectOpInterface::FallbackModel<
215           TestOpEffectInterfaceFallback> {
216   static bool classof(Operation *op) {
217     bool isSupportedOp =
218         op->getName().getStringRef() == "test.unregistered_side_effect_op";
219     assert(isSupportedOp && "Unexpected dispatch");
220     return isSupportedOp;
221   }
222 
223   void
224   getEffects(Operation *op,
225              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
226                  &effects) const {
227     testSideEffectOpGetEffect(op, effects);
228   }
229 };
230 
/// Register the test dialect's attributes, types, operations and dialect
/// interfaces, allow unknown operations, and create the fallback
/// `TestEffectOpInterface` model used for one unregistered operation.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  // The operation classes come from the ODS-generated op list in
  // TestOps.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface, TestReductionPatternInterface>();
  allowUnknownOperations();

  // Instantiate our fallback op interface that we'll use on specific
  // unregistered op. getRegisteredInterfaceForOp hands it out for
  // `test.unregistered_side_effect_op`; the dialect owns the object and
  // ~TestDialect() deletes it.
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
246 TestDialect::~TestDialect() {
247   delete static_cast<TestOpEffectInterfaceFallback *>(
248       fallbackEffectOpInterfaces);
249 }
250 
251 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
252                                             Type type, Location loc) {
253   return builder.create<TestOpConstant>(loc, type, value);
254 }
255 
256 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
257                                                OperationName opName) {
258   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
259       typeID == TypeID::get<TestEffectOpInterface>())
260     return fallbackEffectOpInterfaces;
261   return nullptr;
262 }
263 
264 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
265                                                     NamedAttribute namedAttr) {
266   if (namedAttr.first == "test.invalid_attr")
267     return op->emitError() << "invalid to use 'test.invalid_attr'";
268   return success();
269 }
270 
271 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
272                                                     unsigned regionIndex,
273                                                     unsigned argIndex,
274                                                     NamedAttribute namedAttr) {
275   if (namedAttr.first == "test.invalid_attr")
276     return op->emitError() << "invalid to use 'test.invalid_attr'";
277   return success();
278 }
279 
280 LogicalResult
281 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
282                                          unsigned resultIndex,
283                                          NamedAttribute namedAttr) {
284   if (namedAttr.first == "test.invalid_attr")
285     return op->emitError() << "invalid to use 'test.invalid_attr'";
286   return success();
287 }
288 
289 Optional<Dialect::ParseOpHook>
290 TestDialect::getParseOperationHook(StringRef opName) const {
291   if (opName == "test.dialect_custom_printer") {
292     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
293       return parser.parseKeyword("custom_format");
294     }};
295   }
296   return None;
297 }
298 
299 LogicalResult TestDialect::printOperation(Operation *op,
300                                           OpAsmPrinter &printer) const {
301   StringRef opName = op->getName().getStringRef();
302   if (opName == "test.dialect_custom_printer") {
303     printer.getStream() << opName << " custom_format";
304     return success();
305   }
306   return failure();
307 }
308 
309 //===----------------------------------------------------------------------===//
310 // TestBranchOp
311 //===----------------------------------------------------------------------===//
312 
313 Optional<MutableOperandRange>
314 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
315   assert(index == 0 && "invalid successor index");
316   return targetOperandsMutable();
317 }
318 
319 //===----------------------------------------------------------------------===//
320 // TestDialectCanonicalizerOp
321 //===----------------------------------------------------------------------===//
322 
323 static LogicalResult
324 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
325                                PatternRewriter &rewriter) {
326   rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(),
327                                           rewriter.getI32IntegerAttr(42));
328   return success();
329 }
330 
331 void TestDialect::getCanonicalizationPatterns(
332     RewritePatternSet &results) const {
333   results.add(&dialectCanonicalizationPattern);
334 }
335 
336 //===----------------------------------------------------------------------===//
337 // TestFoldToCallOp
338 //===----------------------------------------------------------------------===//
339 
340 namespace {
341 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
342   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
343 
344   LogicalResult matchAndRewrite(FoldToCallOp op,
345                                 PatternRewriter &rewriter) const override {
346     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
347                                         ValueRange());
348     return success();
349   }
350 };
351 } // end anonymous namespace
352 
353 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
354                                                MLIRContext *context) {
355   results.add<FoldToCallOpPattern>(context);
356 }
357 
358 //===----------------------------------------------------------------------===//
359 // Test Format* operations
360 //===----------------------------------------------------------------------===//
361 
362 //===----------------------------------------------------------------------===//
363 // Parsing
364 
365 static ParseResult parseCustomDirectiveOperands(
366     OpAsmParser &parser, OpAsmParser::OperandType &operand,
367     Optional<OpAsmParser::OperandType> &optOperand,
368     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
369   if (parser.parseOperand(operand))
370     return failure();
371   if (succeeded(parser.parseOptionalComma())) {
372     optOperand.emplace();
373     if (parser.parseOperand(*optOperand))
374       return failure();
375   }
376   if (parser.parseArrow() || parser.parseLParen() ||
377       parser.parseOperandList(varOperands) || parser.parseRParen())
378     return failure();
379   return success();
380 }
381 static ParseResult
382 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
383                             Type &optOperandType,
384                             SmallVectorImpl<Type> &varOperandTypes) {
385   if (parser.parseColon())
386     return failure();
387 
388   if (parser.parseType(operandType))
389     return failure();
390   if (succeeded(parser.parseOptionalComma())) {
391     if (parser.parseType(optOperandType))
392       return failure();
393   }
394   if (parser.parseArrow() || parser.parseLParen() ||
395       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
396     return failure();
397   return success();
398 }
399 static ParseResult
400 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
401                                  Type optOperandType,
402                                  const SmallVectorImpl<Type> &varOperandTypes) {
403   if (parser.parseKeyword("type_refs_capture"))
404     return failure();
405 
406   Type operandType2, optOperandType2;
407   SmallVector<Type, 1> varOperandTypes2;
408   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
409                                   varOperandTypes2))
410     return failure();
411 
412   if (operandType != operandType2 || optOperandType != optOperandType2 ||
413       varOperandTypes != varOperandTypes2)
414     return failure();
415 
416   return success();
417 }
418 static ParseResult parseCustomDirectiveOperandsAndTypes(
419     OpAsmParser &parser, OpAsmParser::OperandType &operand,
420     Optional<OpAsmParser::OperandType> &optOperand,
421     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
422     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
423   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
424       parseCustomDirectiveResults(parser, operandType, optOperandType,
425                                   varOperandTypes))
426     return failure();
427   return success();
428 }
429 static ParseResult parseCustomDirectiveRegions(
430     OpAsmParser &parser, Region &region,
431     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
432   if (parser.parseRegion(region))
433     return failure();
434   if (failed(parser.parseOptionalComma()))
435     return success();
436   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
437   if (parser.parseRegion(*varRegion))
438     return failure();
439   varRegions.emplace_back(std::move(varRegion));
440   return success();
441 }
442 static ParseResult
443 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
444                                SmallVectorImpl<Block *> &varSuccessors) {
445   if (parser.parseSuccessor(successor))
446     return failure();
447   if (failed(parser.parseOptionalComma()))
448     return success();
449   Block *varSuccessor;
450   if (parser.parseSuccessor(varSuccessor))
451     return failure();
452   varSuccessors.append(2, varSuccessor);
453   return success();
454 }
455 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
456                                                   IntegerAttr &attr,
457                                                   IntegerAttr &optAttr) {
458   if (parser.parseAttribute(attr))
459     return failure();
460   if (succeeded(parser.parseOptionalComma())) {
461     if (parser.parseAttribute(optAttr))
462       return failure();
463   }
464   return success();
465 }
466 
467 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
468                                                 NamedAttrList &attrs) {
469   return parser.parseOptionalAttrDict(attrs);
470 }
471 static ParseResult parseCustomDirectiveOptionalOperandRef(
472     OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
473   int64_t operandCount = 0;
474   if (parser.parseInteger(operandCount))
475     return failure();
476   bool expectedOptionalOperand = operandCount == 0;
477   return success(expectedOptionalOperand != optOperand.hasValue());
478 }
479 
480 //===----------------------------------------------------------------------===//
481 // Printing
482 
483 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
484                                          Value operand, Value optOperand,
485                                          OperandRange varOperands) {
486   printer << operand;
487   if (optOperand)
488     printer << ", " << optOperand;
489   printer << " -> (" << varOperands << ")";
490 }
491 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
492                                         Type operandType, Type optOperandType,
493                                         TypeRange varOperandTypes) {
494   printer << " : " << operandType;
495   if (optOperandType)
496     printer << ", " << optOperandType;
497   printer << " -> (" << varOperandTypes << ")";
498 }
499 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
500                                              Operation *op, Type operandType,
501                                              Type optOperandType,
502                                              TypeRange varOperandTypes) {
503   printer << " type_refs_capture ";
504   printCustomDirectiveResults(printer, op, operandType, optOperandType,
505                               varOperandTypes);
506 }
507 static void printCustomDirectiveOperandsAndTypes(
508     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
509     OperandRange varOperands, Type operandType, Type optOperandType,
510     TypeRange varOperandTypes) {
511   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
512   printCustomDirectiveResults(printer, op, operandType, optOperandType,
513                               varOperandTypes);
514 }
515 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
516                                         Region &region,
517                                         MutableArrayRef<Region> varRegions) {
518   printer.printRegion(region);
519   if (!varRegions.empty()) {
520     printer << ", ";
521     for (Region &region : varRegions)
522       printer.printRegion(region);
523   }
524 }
525 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
526                                            Block *successor,
527                                            SuccessorRange varSuccessors) {
528   printer << successor;
529   if (!varSuccessors.empty())
530     printer << ", " << varSuccessors.front();
531 }
532 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
533                                            Attribute attribute,
534                                            Attribute optAttribute) {
535   printer << attribute;
536   if (optAttribute)
537     printer << ", " << optAttribute;
538 }
539 
540 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
541                                          DictionaryAttr attrs) {
542   printer.printOptionalAttrDict(attrs.getValue());
543 }
544 
545 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
546                                                    Operation *op,
547                                                    Value optOperand) {
548   printer << (optOperand ? "1" : "0");
549 }
550 
551 //===----------------------------------------------------------------------===//
552 // Test IsolatedRegionOp - parse passthrough region arguments.
553 //===----------------------------------------------------------------------===//
554 
555 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
556                                          OperationState &result) {
557   OpAsmParser::OperandType argInfo;
558   Type argType = parser.getBuilder().getIndexType();
559 
560   // Parse the input operand.
561   if (parser.parseOperand(argInfo) ||
562       parser.resolveOperand(argInfo, argType, result.operands))
563     return failure();
564 
565   // Parse the body region, and reuse the operand info as the argument info.
566   Region *body = result.addRegion();
567   return parser.parseRegion(*body, argInfo, argType,
568                             /*enableNameShadowing=*/true);
569 }
570 
571 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
572   p << "test.isolated_region ";
573   p.printOperand(op.getOperand());
574   p.shadowRegionArgs(op.region(), op.getOperand());
575   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
576 }
577 
578 //===----------------------------------------------------------------------===//
579 // Test SSACFGRegionOp
580 //===----------------------------------------------------------------------===//
581 
582 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
583   return RegionKind::SSACFG;
584 }
585 
586 //===----------------------------------------------------------------------===//
587 // Test GraphRegionOp
588 //===----------------------------------------------------------------------===//
589 
590 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
591                                       OperationState &result) {
592   // Parse the body region, and reuse the operand info as the argument info.
593   Region *body = result.addRegion();
594   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
595 }
596 
597 static void print(OpAsmPrinter &p, GraphRegionOp op) {
598   p << "test.graph_region ";
599   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
600 }
601 
602 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
603   return RegionKind::Graph;
604 }
605 
606 //===----------------------------------------------------------------------===//
607 // Test AffineScopeOp
608 //===----------------------------------------------------------------------===//
609 
610 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
611                                       OperationState &result) {
612   // Parse the body region, and reuse the operand info as the argument info.
613   Region *body = result.addRegion();
614   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
615 }
616 
617 static void print(OpAsmPrinter &p, AffineScopeOp op) {
618   p << "test.affine_scope ";
619   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
620 }
621 
622 //===----------------------------------------------------------------------===//
623 // Test parser.
624 //===----------------------------------------------------------------------===//
625 
626 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
627                                               OperationState &result) {
628   if (parser.parseOptionalColon())
629     return success();
630   uint64_t numResults;
631   if (parser.parseInteger(numResults))
632     return failure();
633 
634   IndexType type = parser.getBuilder().getIndexType();
635   for (unsigned i = 0; i < numResults; ++i)
636     result.addTypes(type);
637   return success();
638 }
639 
640 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
641   p << ParseIntegerLiteralOp::getOperationName();
642   if (unsigned numResults = op->getNumResults())
643     p << " : " << numResults;
644 }
645 
646 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
647                                               OperationState &result) {
648   StringRef keyword;
649   if (parser.parseKeyword(&keyword))
650     return failure();
651   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
652   return success();
653 }
654 
655 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
656   p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword();
657 }
658 
659 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
661 
662 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
663                                          OperationState &result) {
664   if (parser.parseKeyword("wraps"))
665     return failure();
666 
667   // Parse the wrapped op in a region
668   Region &body = *result.addRegion();
669   body.push_back(new Block);
670   Block &block = body.back();
671   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
672   if (!wrapped_op)
673     return failure();
674 
675   // Create a return terminator in the inner region, pass as operand to the
676   // terminator the returned values from the wrapped operation.
677   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
678   OpBuilder builder(parser.getBuilder().getContext());
679   builder.setInsertionPointToEnd(&block);
680   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
681 
682   // Get the results type for the wrapping op from the terminator operands.
683   Operation &return_op = body.back().back();
684   result.types.append(return_op.operand_type_begin(),
685                       return_op.operand_type_end());
686 
687   // Use the location of the wrapped op for the "test.wrapping_region" op.
688   result.location = wrapped_op->getLoc();
689 
690   return success();
691 }
692 
693 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
694   p << op.getOperationName() << " wraps ";
695   p.printGenericOp(&op.region().front().front());
696 }
697 
698 //===----------------------------------------------------------------------===//
699 // Test PolyForOp - parse list of region arguments.
700 //===----------------------------------------------------------------------===//
701 
702 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
703   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
704   // Parse list of region arguments without a delimiter.
705   if (parser.parseRegionArgumentList(ivsInfo))
706     return failure();
707 
708   // Parse the body region.
709   Region *body = result.addRegion();
710   auto &builder = parser.getBuilder();
711   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
712   return parser.parseRegion(*body, ivsInfo, argTypes);
713 }
714 
715 //===----------------------------------------------------------------------===//
716 // Test removing op with inner ops.
717 //===----------------------------------------------------------------------===//
718 
719 namespace {
720 struct TestRemoveOpWithInnerOps
721     : public OpRewritePattern<TestOpWithRegionPattern> {
722   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
723 
724   void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
725 
726   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
727                                 PatternRewriter &rewriter) const override {
728     rewriter.eraseOp(op);
729     return success();
730   }
731 };
732 } // end anonymous namespace
733 
734 void TestOpWithRegionPattern::getCanonicalizationPatterns(
735     RewritePatternSet &results, MLIRContext *context) {
736   results.add<TestRemoveOpWithInnerOps>(context);
737 }
738 
739 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
740   return operand();
741 }
742 
743 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
744   return getValue();
745 }
746 
747 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
748     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
749   for (Value input : this->operands()) {
750     results.push_back(input);
751   }
752   return success();
753 }
754 
755 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
756   assert(operands.size() == 1);
757   if (operands.front()) {
758     (*this)->setAttr("attr", operands.front());
759     return getResult();
760   }
761   return {};
762 }
763 
764 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
765   return getOperand();
766 }
767 
768 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
769     MLIRContext *, Optional<Location> location, ValueRange operands,
770     DictionaryAttr attributes, RegionRange regions,
771     SmallVectorImpl<Type> &inferredReturnTypes) {
772   if (operands[0].getType() != operands[1].getType()) {
773     return emitOptionalError(location, "operand type mismatch ",
774                              operands[0].getType(), " vs ",
775                              operands[1].getType());
776   }
777   inferredReturnTypes.assign({operands[0].getType()});
778   return success();
779 }
780 
781 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
782     MLIRContext *context, Optional<Location> location, ValueRange operands,
783     DictionaryAttr attributes, RegionRange regions,
784     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
785   // Create return type consisting of the last element of the first operand.
786   auto operandType = *operands.getTypes().begin();
787   auto sval = operandType.dyn_cast<ShapedType>();
788   if (!sval) {
789     return emitOptionalError(location, "only shaped type operands allowed");
790   }
791   int64_t dim =
792       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
793   auto type = IntegerType::get(context, 17);
794   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
795   return success();
796 }
797 
798 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
799     OpBuilder &builder, ValueRange operands,
800     llvm::SmallVectorImpl<Value> &shapes) {
801   shapes = SmallVector<Value, 1>{
802       builder.createOrFold<memref::DimOp>(getLoc(), operands.front(), 0)};
803   return success();
804 }
805 
806 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
807     OpBuilder &builder, ValueRange operands,
808     llvm::SmallVectorImpl<Value> &shapes) {
809   Location loc = getLoc();
810   shapes.reserve(operands.size());
811   for (Value operand : llvm::reverse(operands)) {
812     auto currShape = llvm::to_vector<4>(llvm::map_range(
813         llvm::seq<int64_t>(
814             0, operand.getType().cast<RankedTensorType>().getRank()),
815         [&](int64_t dim) -> Value {
816           return builder.createOrFold<memref::DimOp>(loc, operand, dim);
817         }));
818     shapes.push_back(builder.create<tensor::FromElementsOp>(
819         getLoc(), builder.getIndexType(), currShape));
820   }
821   return success();
822 }
823 
824 LogicalResult
825 OpWithResultShapePerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim(
826     OpBuilder &builder,
827     llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
828   Location loc = getLoc();
829   shapes.reserve(getNumOperands());
830   for (Value operand : llvm::reverse(getOperands())) {
831     auto currShape = llvm::to_vector<4>(llvm::map_range(
832         llvm::seq<int64_t>(
833             0, operand.getType().cast<RankedTensorType>().getRank()),
834         [&](int64_t dim) -> Value {
835           return builder.createOrFold<memref::DimOp>(loc, operand, dim);
836         }));
837     shapes.emplace_back(std::move(currShape));
838   }
839   return success();
840 }
841 
842 LogicalResult OpWithResultShapeAndPerDimInterfaceOp::reifyReturnTypeShapes(
843     OpBuilder &builder, ValueRange operands,
844     llvm::SmallVectorImpl<Value> &shapes) {
845   Location loc = getLoc();
846   shapes.reserve(operands.size());
847   for (Value operand : llvm::reverse(operands)) {
848     auto currShape = llvm::to_vector<4>(llvm::map_range(
849         llvm::seq<int64_t>(
850             0, operand.getType().cast<RankedTensorType>().getRank()),
851         [&](int64_t dim) -> Value {
852           return builder.createOrFold<memref::DimOp>(loc, operand, dim);
853         }));
854     shapes.push_back(builder.create<tensor::FromElementsOp>(
855         getLoc(), builder.getIndexType(), currShape));
856   }
857   return success();
858 }
859 
860 LogicalResult
861 OpWithResultShapeAndPerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim(
862     OpBuilder &builder,
863     llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
864   Location loc = getLoc();
865   shapes.reserve(getNumOperands());
866   for (Value operand : llvm::reverse(getOperands())) {
867     auto currShape = llvm::to_vector<4>(llvm::map_range(
868         llvm::seq<int64_t>(
869             0, operand.getType().cast<RankedTensorType>().getRank()),
870         [&](int64_t dim) -> Value {
871           return builder.createOrFold<memref::DimOp>(loc, operand, dim);
872         }));
873     shapes.emplace_back(std::move(currShape));
874   }
875   return success();
876 }
877 
878 //===----------------------------------------------------------------------===//
879 // Test SideEffect interfaces
880 //===----------------------------------------------------------------------===//
881 
882 namespace {
883 /// A test resource for side effects.
884 struct TestResource : public SideEffects::Resource::Base<TestResource> {
885   StringRef getName() final { return "<Test>"; }
886 };
887 } // end anonymous namespace
888 
889 static void testSideEffectOpGetEffect(
890     Operation *op,
891     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
892         &effects) {
893   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
894   if (!effectsAttr)
895     return;
896 
897   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
898 }
899 
900 void SideEffectOp::getEffects(
901     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
902   // Check for an effects attribute on the op instance.
903   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
904   if (!effectsAttr)
905     return;
906 
907   // If there is one, it is an array of dictionary attributes that hold
908   // information on the effects of this operation.
909   for (Attribute element : effectsAttr) {
910     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
911 
912     // Get the specific memory effect.
913     MemoryEffects::Effect *effect =
914         StringSwitch<MemoryEffects::Effect *>(
915             effectElement.get("effect").cast<StringAttr>().getValue())
916             .Case("allocate", MemoryEffects::Allocate::get())
917             .Case("free", MemoryEffects::Free::get())
918             .Case("read", MemoryEffects::Read::get())
919             .Case("write", MemoryEffects::Write::get());
920 
921     // Check for a non-default resource to use.
922     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
923     if (effectElement.get("test_resource"))
924       resource = TestResource::get();
925 
926     // Check for a result to affect.
927     if (effectElement.get("on_result"))
928       effects.emplace_back(effect, getResult(), resource);
929     else if (Attribute ref = effectElement.get("on_reference"))
930       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
931     else
932       effects.emplace_back(effect, resource);
933   }
934 }
935 
936 void SideEffectOp::getEffects(
937     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
938   testSideEffectOpGetEffect(getOperation(), effects);
939 }
940 
941 //===----------------------------------------------------------------------===//
942 // StringAttrPrettyNameOp
943 //===----------------------------------------------------------------------===//
944 
945 // This op has fancy handling of its SSA result name.
946 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
947                                                OperationState &result) {
948   // Add the result types.
949   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
950     result.addTypes(parser.getBuilder().getIntegerType(32));
951 
952   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
953     return failure();
954 
955   // If the attribute dictionary contains no 'names' attribute, infer it from
956   // the SSA name (if specified).
957   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
958     return attr.first == "names";
959   });
960 
961   // If there was no name specified, check to see if there was a useful name
962   // specified in the asm file.
963   if (hadNames || parser.getNumResults() == 0)
964     return success();
965 
966   SmallVector<StringRef, 4> names;
967   auto *context = result.getContext();
968 
969   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
970     auto resultName = parser.getResultName(i);
971     StringRef nameStr;
972     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
973       nameStr = resultName.first;
974 
975     names.push_back(nameStr);
976   }
977 
978   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
979   result.attributes.push_back({Identifier::get("names", context), namesAttr});
980   return success();
981 }
982 
983 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
984   p << "test.string_attr_pretty_name";
985 
986   // Note that we only need to print the "name" attribute if the asmprinter
987   // result name disagrees with it.  This can happen in strange cases, e.g.
988   // when there are conflicts.
989   bool namesDisagree = op.names().size() != op.getNumResults();
990 
991   SmallString<32> resultNameStr;
992   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
993     resultNameStr.clear();
994     llvm::raw_svector_ostream tmpStream(resultNameStr);
995     p.printOperand(op.getResult(i), tmpStream);
996 
997     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
998     if (!expectedName ||
999         tmpStream.str().drop_front() != expectedName.getValue()) {
1000       namesDisagree = true;
1001     }
1002   }
1003 
1004   if (namesDisagree)
1005     p.printOptionalAttrDictWithKeyword(op->getAttrs());
1006   else
1007     p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
1008 }
1009 
1010 // We set the SSA name in the asm syntax to the contents of the name
1011 // attribute.
1012 void StringAttrPrettyNameOp::getAsmResultNames(
1013     function_ref<void(Value, StringRef)> setNameFn) {
1014 
1015   auto value = names();
1016   for (size_t i = 0, e = value.size(); i != e; ++i)
1017     if (auto str = value[i].dyn_cast<StringAttr>())
1018       if (!str.getValue().empty())
1019         setNameFn(getResult(i), str.getValue());
1020 }
1021 
1022 //===----------------------------------------------------------------------===//
1023 // RegionIfOp
1024 //===----------------------------------------------------------------------===//
1025 
1026 static void print(OpAsmPrinter &p, RegionIfOp op) {
1027   p << RegionIfOp::getOperationName() << " ";
1028   p.printOperands(op.getOperands());
1029   p << ": " << op.getOperandTypes();
1030   p.printArrowTypeList(op.getResultTypes());
1031   p << " then";
1032   p.printRegion(op.thenRegion(),
1033                 /*printEntryBlockArgs=*/true,
1034                 /*printBlockTerminators=*/true);
1035   p << " else";
1036   p.printRegion(op.elseRegion(),
1037                 /*printEntryBlockArgs=*/true,
1038                 /*printBlockTerminators=*/true);
1039   p << " join";
1040   p.printRegion(op.joinRegion(),
1041                 /*printEntryBlockArgs=*/true,
1042                 /*printBlockTerminators=*/true);
1043 }
1044 
1045 static ParseResult parseRegionIfOp(OpAsmParser &parser,
1046                                    OperationState &result) {
1047   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
1048   SmallVector<Type, 2> operandTypes;
1049 
1050   result.regions.reserve(3);
1051   Region *thenRegion = result.addRegion();
1052   Region *elseRegion = result.addRegion();
1053   Region *joinRegion = result.addRegion();
1054 
1055   // Parse operand, type and arrow type lists.
1056   if (parser.parseOperandList(operandInfos) ||
1057       parser.parseColonTypeList(operandTypes) ||
1058       parser.parseArrowTypeList(result.types))
1059     return failure();
1060 
1061   // Parse all attached regions.
1062   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1063       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1064       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1065     return failure();
1066 
1067   return parser.resolveOperands(operandInfos, operandTypes,
1068                                 parser.getCurrentLocation(), result.operands);
1069 }
1070 
1071 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
1072   assert(index < 2 && "invalid region index");
1073   return getOperands();
1074 }
1075 
1076 void RegionIfOp::getSuccessorRegions(
1077     Optional<unsigned> index, ArrayRef<Attribute> operands,
1078     SmallVectorImpl<RegionSuccessor> &regions) {
1079   // We always branch to the join region.
1080   if (index.hasValue()) {
1081     if (index.getValue() < 2)
1082       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
1083     else
1084       regions.push_back(RegionSuccessor(getResults()));
1085     return;
1086   }
1087 
1088   // The then and else regions are the entry regions of this op.
1089   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
1090   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
1091 }
1092 
1093 //===----------------------------------------------------------------------===//
1094 // SingleNoTerminatorCustomAsmOp
1095 //===----------------------------------------------------------------------===//
1096 
1097 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser,
1098                                                       OperationState &state) {
1099   Region *body = state.addRegion();
1100   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1101     return failure();
1102   return success();
1103 }
1104 
1105 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) {
1106   printer << op.getOperationName();
1107   printer.printRegion(
1108       op.getRegion(), /*printEntryBlockArgs=*/false,
1109       // This op has a single block without terminators. But explicitly mark
1110       // as not printing block terminators for testing.
1111       /*printBlockTerminators=*/false);
1112 }
1113 
1114 #include "TestOpEnums.cpp.inc"
1115 #include "TestOpInterfaces.cpp.inc"
1116 #include "TestOpStructs.cpp.inc"
1117 #include "TestTypeInterfaces.cpp.inc"
1118 
1119 #define GET_OP_CLASSES
1120 #include "TestOps.cpp.inc"
1121