1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/DLTI/DLTI.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Reducer/ReductionPatternInterface.h"
21 #include "mlir/Transforms/FoldUtils.h"
22 #include "mlir/Transforms/InliningUtils.h"
23 #include "llvm/ADT/StringSwitch.h"
24 
25 using namespace mlir;
26 using namespace mlir::test;
27 
28 #include "TestOpsDialect.cpp.inc"
29 
30 void mlir::test::registerTestDialect(DialectRegistry &registry) {
31   registry.insert<TestDialect>();
32 }
33 
34 //===----------------------------------------------------------------------===//
35 // TestDialect Interfaces
36 //===----------------------------------------------------------------------===//
37 
38 namespace {
39 
/// Testing the correctness of some traits.
///
/// Both trait-detection helpers from OpTrait must report true for
/// SingleBlockImplicitTerminatorOp. A mismatch fails the build at compile time
/// rather than at test run time.
static_assert(
    llvm::is_detected<OpTrait::has_implicit_terminator_t,
                      SingleBlockImplicitTerminatorOp>::value,
    "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
static_assert(OpTrait::hasSingleBlockImplicitTerminator<
                  SingleBlockImplicitTerminatorOp>::value,
              "hasSingleBlockImplicitTerminator does not match "
              "SingleBlockImplicitTerminatorOp");
49 
50 // Test support for interacting with the AsmPrinter.
51 struct TestOpAsmInterface : public OpAsmDialectInterface {
52   using OpAsmDialectInterface::OpAsmDialectInterface;
53 
54   AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
55     StringAttr strAttr = attr.dyn_cast<StringAttr>();
56     if (!strAttr)
57       return AliasResult::NoAlias;
58 
59     // Check the contents of the string attribute to see what the test alias
60     // should be named.
61     Optional<StringRef> aliasName =
62         StringSwitch<Optional<StringRef>>(strAttr.getValue())
63             .Case("alias_test:dot_in_name", StringRef("test.alias"))
64             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
65             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
66             .Case("alias_test:sanitize_conflict_a",
67                   StringRef("test_alias_conflict0"))
68             .Case("alias_test:sanitize_conflict_b",
69                   StringRef("test_alias_conflict0_"))
70             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
71             .Default(llvm::None);
72     if (!aliasName)
73       return AliasResult::NoAlias;
74 
75     os << *aliasName;
76     return AliasResult::FinalAlias;
77   }
78 
79   AliasResult getAlias(Type type, raw_ostream &os) const final {
80     if (auto tupleType = type.dyn_cast<TupleType>()) {
81       if (tupleType.size() > 0 &&
82           llvm::all_of(tupleType.getTypes(), [](Type elemType) {
83             return elemType.isa<SimpleAType>();
84           })) {
85         os << "test_tuple";
86         return AliasResult::FinalAlias;
87       }
88     }
89     return AliasResult::NoAlias;
90   }
91 
92   void getAsmResultNames(Operation *op,
93                          OpAsmSetValueNameFn setNameFn) const final {
94     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
95       setNameFn(asmOp, "result");
96   }
97 
98   void getAsmBlockArgumentNames(Block *block,
99                                 OpAsmSetValueNameFn setNameFn) const final {
100     auto op = block->getParentOp();
101     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
102     if (!arrayAttr)
103       return;
104     auto args = block->getArguments();
105     auto e = std::min(arrayAttr.size(), args.size());
106     for (unsigned i = 0; i < e; ++i) {
107       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
108         setNameFn(args[i], strAttr.getValue());
109     }
110   }
111 };
112 
113 struct TestDialectFoldInterface : public DialectFoldInterface {
114   using DialectFoldInterface::DialectFoldInterface;
115 
116   /// Registered hook to check if the given region, which is attached to an
117   /// operation that is *not* isolated from above, should be used when
118   /// materializing constants.
119   bool shouldMaterializeInto(Region *region) const final {
120     // If this is a one region operation, then insert into it.
121     return isa<OneRegionOp>(region->getParentOp());
122   }
123 };
124 
125 /// This class defines the interface for handling inlining with standard
126 /// operations.
127 struct TestInlinerInterface : public DialectInlinerInterface {
128   using DialectInlinerInterface::DialectInlinerInterface;
129 
130   //===--------------------------------------------------------------------===//
131   // Analysis Hooks
132   //===--------------------------------------------------------------------===//
133 
134   bool isLegalToInline(Operation *call, Operation *callable,
135                        bool wouldBeCloned) const final {
136     // Don't allow inlining calls that are marked `noinline`.
137     return !call->hasAttr("noinline");
138   }
139   bool isLegalToInline(Region *, Region *, bool,
140                        BlockAndValueMapping &) const final {
141     // Inlining into test dialect regions is legal.
142     return true;
143   }
144   bool isLegalToInline(Operation *, Region *, bool,
145                        BlockAndValueMapping &) const final {
146     return true;
147   }
148 
149   bool shouldAnalyzeRecursively(Operation *op) const final {
150     // Analyze recursively if this is not a functional region operation, it
151     // froms a separate functional scope.
152     return !isa<FunctionalRegionOp>(op);
153   }
154 
155   //===--------------------------------------------------------------------===//
156   // Transformation Hooks
157   //===--------------------------------------------------------------------===//
158 
159   /// Handle the given inlined terminator by replacing it with a new operation
160   /// as necessary.
161   void handleTerminator(Operation *op,
162                         ArrayRef<Value> valuesToRepl) const final {
163     // Only handle "test.return" here.
164     auto returnOp = dyn_cast<TestReturnOp>(op);
165     if (!returnOp)
166       return;
167 
168     // Replace the values directly with the return operands.
169     assert(returnOp.getNumOperands() == valuesToRepl.size());
170     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
171       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
172   }
173 
174   /// Attempt to materialize a conversion for a type mismatch between a call
175   /// from this dialect, and a callable region. This method should generate an
176   /// operation that takes 'input' as the only operand, and produces a single
177   /// result of 'resultType'. If a conversion can not be generated, nullptr
178   /// should be returned.
179   Operation *materializeCallConversion(OpBuilder &builder, Value input,
180                                        Type resultType,
181                                        Location conversionLoc) const final {
182     // Only allow conversion for i16/i32 types.
183     if (!(resultType.isSignlessInteger(16) ||
184           resultType.isSignlessInteger(32)) ||
185         !(input.getType().isSignlessInteger(16) ||
186           input.getType().isSignlessInteger(32)))
187       return nullptr;
188     return builder.create<TestCastOp>(conversionLoc, resultType, input);
189   }
190 
191   void processInlinedCallBlocks(
192       Operation *call,
193       iterator_range<Region::iterator> inlinedBlocks) const final {
194     if (!isa<ConversionCallOp>(call))
195       return;
196 
197     // Set attributed on all ops in the inlined blocks.
198     for (Block &block : inlinedBlocks) {
199       block.walk([&](Operation *op) {
200         op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
201       });
202     }
203   }
204 };
205 
206 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
207 public:
208   TestReductionPatternInterface(Dialect *dialect)
209       : DialectReductionPatternInterface(dialect) {}
210 
211   virtual void
212   populateReductionPatterns(RewritePatternSet &patterns) const final {
213     populateTestReductionPatterns(patterns);
214   }
215 };
216 
217 } // end anonymous namespace
218 
219 //===----------------------------------------------------------------------===//
220 // TestDialect
221 //===----------------------------------------------------------------------===//
222 
223 static void testSideEffectOpGetEffect(
224     Operation *op,
225     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
226 
227 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
228 struct TestOpEffectInterfaceFallback
229     : public TestEffectOpInterface::FallbackModel<
230           TestOpEffectInterfaceFallback> {
231   static bool classof(Operation *op) {
232     bool isSupportedOp =
233         op->getName().getStringRef() == "test.unregistered_side_effect_op";
234     assert(isSupportedOp && "Unexpected dispatch");
235     return isSupportedOp;
236   }
237 
238   void
239   getEffects(Operation *op,
240              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
241                  &effects) const {
242     testSideEffectOpGetEffect(op, effects);
243   }
244 };
245 
/// Set up the test dialect:
///  - register its attributes, types and operations;
///  - attach its dialect interfaces;
///  - allow unregistered operations;
///  - allocate the TestEffectOpInterface fallback model.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  // The operation classes are spliced into the template argument list from
  // the generated TestOps.cpp.inc, with GET_OP_LIST defined.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface, TestReductionPatternInterface>();
  allowUnknownOperations();

  // Instantiate our fallback op interface that we'll use on specific
  // unregistered op.
  // NOTE(review): raw owning allocation. ~TestDialect() releases it and must
  // cast back to TestOpEffectInterfaceFallback before deleting.
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
261 TestDialect::~TestDialect() {
262   delete static_cast<TestOpEffectInterfaceFallback *>(
263       fallbackEffectOpInterfaces);
264 }
265 
266 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
267                                             Type type, Location loc) {
268   return builder.create<TestOpConstant>(loc, type, value);
269 }
270 
271 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
272                                                OperationName opName) {
273   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
274       typeID == TypeID::get<TestEffectOpInterface>())
275     return fallbackEffectOpInterfaces;
276   return nullptr;
277 }
278 
279 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
280                                                     NamedAttribute namedAttr) {
281   if (namedAttr.first == "test.invalid_attr")
282     return op->emitError() << "invalid to use 'test.invalid_attr'";
283   return success();
284 }
285 
286 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
287                                                     unsigned regionIndex,
288                                                     unsigned argIndex,
289                                                     NamedAttribute namedAttr) {
290   if (namedAttr.first == "test.invalid_attr")
291     return op->emitError() << "invalid to use 'test.invalid_attr'";
292   return success();
293 }
294 
295 LogicalResult
296 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
297                                          unsigned resultIndex,
298                                          NamedAttribute namedAttr) {
299   if (namedAttr.first == "test.invalid_attr")
300     return op->emitError() << "invalid to use 'test.invalid_attr'";
301   return success();
302 }
303 
304 Optional<Dialect::ParseOpHook>
305 TestDialect::getParseOperationHook(StringRef opName) const {
306   if (opName == "test.dialect_custom_printer") {
307     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
308       return parser.parseKeyword("custom_format");
309     }};
310   }
311   return None;
312 }
313 
314 LogicalResult TestDialect::printOperation(Operation *op,
315                                           OpAsmPrinter &printer) const {
316   StringRef opName = op->getName().getStringRef();
317   if (opName == "test.dialect_custom_printer") {
318     printer.getStream() << opName << " custom_format";
319     return success();
320   }
321   return failure();
322 }
323 
324 //===----------------------------------------------------------------------===//
325 // TestBranchOp
326 //===----------------------------------------------------------------------===//
327 
328 Optional<MutableOperandRange>
329 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
330   assert(index == 0 && "invalid successor index");
331   return targetOperandsMutable();
332 }
333 
334 //===----------------------------------------------------------------------===//
335 // TestDialectCanonicalizerOp
336 //===----------------------------------------------------------------------===//
337 
338 static LogicalResult
339 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
340                                PatternRewriter &rewriter) {
341   rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(),
342                                           rewriter.getI32IntegerAttr(42));
343   return success();
344 }
345 
346 void TestDialect::getCanonicalizationPatterns(
347     RewritePatternSet &results) const {
348   results.add(&dialectCanonicalizationPattern);
349 }
350 
351 //===----------------------------------------------------------------------===//
352 // TestFoldToCallOp
353 //===----------------------------------------------------------------------===//
354 
355 namespace {
356 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
357   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
358 
359   LogicalResult matchAndRewrite(FoldToCallOp op,
360                                 PatternRewriter &rewriter) const override {
361     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
362                                         ValueRange());
363     return success();
364   }
365 };
366 } // end anonymous namespace
367 
368 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
369                                                MLIRContext *context) {
370   results.add<FoldToCallOpPattern>(context);
371 }
372 
373 //===----------------------------------------------------------------------===//
374 // Test Format* operations
375 //===----------------------------------------------------------------------===//
376 
377 //===----------------------------------------------------------------------===//
378 // Parsing
379 
380 static ParseResult parseCustomDirectiveOperands(
381     OpAsmParser &parser, OpAsmParser::OperandType &operand,
382     Optional<OpAsmParser::OperandType> &optOperand,
383     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
384   if (parser.parseOperand(operand))
385     return failure();
386   if (succeeded(parser.parseOptionalComma())) {
387     optOperand.emplace();
388     if (parser.parseOperand(*optOperand))
389       return failure();
390   }
391   if (parser.parseArrow() || parser.parseLParen() ||
392       parser.parseOperandList(varOperands) || parser.parseRParen())
393     return failure();
394   return success();
395 }
396 static ParseResult
397 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
398                             Type &optOperandType,
399                             SmallVectorImpl<Type> &varOperandTypes) {
400   if (parser.parseColon())
401     return failure();
402 
403   if (parser.parseType(operandType))
404     return failure();
405   if (succeeded(parser.parseOptionalComma())) {
406     if (parser.parseType(optOperandType))
407       return failure();
408   }
409   if (parser.parseArrow() || parser.parseLParen() ||
410       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
411     return failure();
412   return success();
413 }
414 static ParseResult
415 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
416                                  Type optOperandType,
417                                  const SmallVectorImpl<Type> &varOperandTypes) {
418   if (parser.parseKeyword("type_refs_capture"))
419     return failure();
420 
421   Type operandType2, optOperandType2;
422   SmallVector<Type, 1> varOperandTypes2;
423   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
424                                   varOperandTypes2))
425     return failure();
426 
427   if (operandType != operandType2 || optOperandType != optOperandType2 ||
428       varOperandTypes != varOperandTypes2)
429     return failure();
430 
431   return success();
432 }
433 static ParseResult parseCustomDirectiveOperandsAndTypes(
434     OpAsmParser &parser, OpAsmParser::OperandType &operand,
435     Optional<OpAsmParser::OperandType> &optOperand,
436     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
437     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
438   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
439       parseCustomDirectiveResults(parser, operandType, optOperandType,
440                                   varOperandTypes))
441     return failure();
442   return success();
443 }
444 static ParseResult parseCustomDirectiveRegions(
445     OpAsmParser &parser, Region &region,
446     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
447   if (parser.parseRegion(region))
448     return failure();
449   if (failed(parser.parseOptionalComma()))
450     return success();
451   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
452   if (parser.parseRegion(*varRegion))
453     return failure();
454   varRegions.emplace_back(std::move(varRegion));
455   return success();
456 }
457 static ParseResult
458 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
459                                SmallVectorImpl<Block *> &varSuccessors) {
460   if (parser.parseSuccessor(successor))
461     return failure();
462   if (failed(parser.parseOptionalComma()))
463     return success();
464   Block *varSuccessor;
465   if (parser.parseSuccessor(varSuccessor))
466     return failure();
467   varSuccessors.append(2, varSuccessor);
468   return success();
469 }
470 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
471                                                   IntegerAttr &attr,
472                                                   IntegerAttr &optAttr) {
473   if (parser.parseAttribute(attr))
474     return failure();
475   if (succeeded(parser.parseOptionalComma())) {
476     if (parser.parseAttribute(optAttr))
477       return failure();
478   }
479   return success();
480 }
481 
482 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
483                                                 NamedAttrList &attrs) {
484   return parser.parseOptionalAttrDict(attrs);
485 }
486 static ParseResult parseCustomDirectiveOptionalOperandRef(
487     OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
488   int64_t operandCount = 0;
489   if (parser.parseInteger(operandCount))
490     return failure();
491   bool expectedOptionalOperand = operandCount == 0;
492   return success(expectedOptionalOperand != optOperand.hasValue());
493 }
494 
495 //===----------------------------------------------------------------------===//
496 // Printing
497 
498 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
499                                          Value operand, Value optOperand,
500                                          OperandRange varOperands) {
501   printer << operand;
502   if (optOperand)
503     printer << ", " << optOperand;
504   printer << " -> (" << varOperands << ")";
505 }
506 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
507                                         Type operandType, Type optOperandType,
508                                         TypeRange varOperandTypes) {
509   printer << " : " << operandType;
510   if (optOperandType)
511     printer << ", " << optOperandType;
512   printer << " -> (" << varOperandTypes << ")";
513 }
514 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
515                                              Operation *op, Type operandType,
516                                              Type optOperandType,
517                                              TypeRange varOperandTypes) {
518   printer << " type_refs_capture ";
519   printCustomDirectiveResults(printer, op, operandType, optOperandType,
520                               varOperandTypes);
521 }
522 static void printCustomDirectiveOperandsAndTypes(
523     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
524     OperandRange varOperands, Type operandType, Type optOperandType,
525     TypeRange varOperandTypes) {
526   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
527   printCustomDirectiveResults(printer, op, operandType, optOperandType,
528                               varOperandTypes);
529 }
530 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
531                                         Region &region,
532                                         MutableArrayRef<Region> varRegions) {
533   printer.printRegion(region);
534   if (!varRegions.empty()) {
535     printer << ", ";
536     for (Region &region : varRegions)
537       printer.printRegion(region);
538   }
539 }
540 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
541                                            Block *successor,
542                                            SuccessorRange varSuccessors) {
543   printer << successor;
544   if (!varSuccessors.empty())
545     printer << ", " << varSuccessors.front();
546 }
547 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
548                                            Attribute attribute,
549                                            Attribute optAttribute) {
550   printer << attribute;
551   if (optAttribute)
552     printer << ", " << optAttribute;
553 }
554 
555 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
556                                          DictionaryAttr attrs) {
557   printer.printOptionalAttrDict(attrs.getValue());
558 }
559 
560 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
561                                                    Operation *op,
562                                                    Value optOperand) {
563   printer << (optOperand ? "1" : "0");
564 }
565 
566 //===----------------------------------------------------------------------===//
567 // Test IsolatedRegionOp - parse passthrough region arguments.
568 //===----------------------------------------------------------------------===//
569 
570 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
571                                          OperationState &result) {
572   OpAsmParser::OperandType argInfo;
573   Type argType = parser.getBuilder().getIndexType();
574 
575   // Parse the input operand.
576   if (parser.parseOperand(argInfo) ||
577       parser.resolveOperand(argInfo, argType, result.operands))
578     return failure();
579 
580   // Parse the body region, and reuse the operand info as the argument info.
581   Region *body = result.addRegion();
582   return parser.parseRegion(*body, argInfo, argType,
583                             /*enableNameShadowing=*/true);
584 }
585 
586 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
587   p << "test.isolated_region ";
588   p.printOperand(op.getOperand());
589   p.shadowRegionArgs(op.region(), op.getOperand());
590   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
591 }
592 
593 //===----------------------------------------------------------------------===//
594 // Test SSACFGRegionOp
595 //===----------------------------------------------------------------------===//
596 
597 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
598   return RegionKind::SSACFG;
599 }
600 
601 //===----------------------------------------------------------------------===//
602 // Test GraphRegionOp
603 //===----------------------------------------------------------------------===//
604 
605 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
606                                       OperationState &result) {
607   // Parse the body region, and reuse the operand info as the argument info.
608   Region *body = result.addRegion();
609   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
610 }
611 
612 static void print(OpAsmPrinter &p, GraphRegionOp op) {
613   p << "test.graph_region ";
614   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
615 }
616 
617 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
618   return RegionKind::Graph;
619 }
620 
621 //===----------------------------------------------------------------------===//
622 // Test AffineScopeOp
623 //===----------------------------------------------------------------------===//
624 
625 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
626                                       OperationState &result) {
627   // Parse the body region, and reuse the operand info as the argument info.
628   Region *body = result.addRegion();
629   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
630 }
631 
632 static void print(OpAsmPrinter &p, AffineScopeOp op) {
633   p << "test.affine_scope ";
634   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
635 }
636 
637 //===----------------------------------------------------------------------===//
638 // Test parser.
639 //===----------------------------------------------------------------------===//
640 
641 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
642                                               OperationState &result) {
643   if (parser.parseOptionalColon())
644     return success();
645   uint64_t numResults;
646   if (parser.parseInteger(numResults))
647     return failure();
648 
649   IndexType type = parser.getBuilder().getIndexType();
650   for (unsigned i = 0; i < numResults; ++i)
651     result.addTypes(type);
652   return success();
653 }
654 
655 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
656   p << ParseIntegerLiteralOp::getOperationName();
657   if (unsigned numResults = op->getNumResults())
658     p << " : " << numResults;
659 }
660 
661 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
662                                               OperationState &result) {
663   StringRef keyword;
664   if (parser.parseKeyword(&keyword))
665     return failure();
666   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
667   return success();
668 }
669 
670 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
671   p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword();
672 }
673 
674 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
676 
677 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
678                                          OperationState &result) {
679   if (parser.parseKeyword("wraps"))
680     return failure();
681 
682   // Parse the wrapped op in a region
683   Region &body = *result.addRegion();
684   body.push_back(new Block);
685   Block &block = body.back();
686   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
687   if (!wrapped_op)
688     return failure();
689 
690   // Create a return terminator in the inner region, pass as operand to the
691   // terminator the returned values from the wrapped operation.
692   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
693   OpBuilder builder(parser.getBuilder().getContext());
694   builder.setInsertionPointToEnd(&block);
695   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
696 
697   // Get the results type for the wrapping op from the terminator operands.
698   Operation &return_op = body.back().back();
699   result.types.append(return_op.operand_type_begin(),
700                       return_op.operand_type_end());
701 
702   // Use the location of the wrapped op for the "test.wrapping_region" op.
703   result.location = wrapped_op->getLoc();
704 
705   return success();
706 }
707 
708 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
709   p << op.getOperationName() << " wraps ";
710   p.printGenericOp(&op.region().front().front());
711 }
712 
713 //===----------------------------------------------------------------------===//
714 // Test PolyForOp - parse list of region arguments.
715 //===----------------------------------------------------------------------===//
716 
717 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
718   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
719   // Parse list of region arguments without a delimiter.
720   if (parser.parseRegionArgumentList(ivsInfo))
721     return failure();
722 
723   // Parse the body region.
724   Region *body = result.addRegion();
725   auto &builder = parser.getBuilder();
726   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
727   return parser.parseRegion(*body, ivsInfo, argTypes);
728 }
729 
730 //===----------------------------------------------------------------------===//
731 // Test removing op with inner ops.
732 //===----------------------------------------------------------------------===//
733 
734 namespace {
735 struct TestRemoveOpWithInnerOps
736     : public OpRewritePattern<TestOpWithRegionPattern> {
737   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
738 
739   void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
740 
741   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
742                                 PatternRewriter &rewriter) const override {
743     rewriter.eraseOp(op);
744     return success();
745   }
746 };
747 } // end anonymous namespace
748 
749 void TestOpWithRegionPattern::getCanonicalizationPatterns(
750     RewritePatternSet &results, MLIRContext *context) {
751   results.add<TestRemoveOpWithInnerOps>(context);
752 }
753 
754 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
755   return operand();
756 }
757 
758 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
759   return getValue();
760 }
761 
762 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
763     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
764   for (Value input : this->operands()) {
765     results.push_back(input);
766   }
767   return success();
768 }
769 
770 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
771   assert(operands.size() == 1);
772   if (operands.front()) {
773     (*this)->setAttr("attr", operands.front());
774     return getResult();
775   }
776   return {};
777 }
778 
779 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
780   return getOperand();
781 }
782 
783 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
784     MLIRContext *, Optional<Location> location, ValueRange operands,
785     DictionaryAttr attributes, RegionRange regions,
786     SmallVectorImpl<Type> &inferredReturnTypes) {
787   if (operands[0].getType() != operands[1].getType()) {
788     return emitOptionalError(location, "operand type mismatch ",
789                              operands[0].getType(), " vs ",
790                              operands[1].getType());
791   }
792   inferredReturnTypes.assign({operands[0].getType()});
793   return success();
794 }
795 
796 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
797     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
798     DictionaryAttr attributes, RegionRange regions,
799     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
800   // Create return type consisting of the last element of the first operand.
801   auto operandType = operands.front().getType();
802   auto sval = operandType.dyn_cast<ShapedType>();
803   if (!sval) {
804     return emitOptionalError(location, "only shaped type operands allowed");
805   }
806   int64_t dim =
807       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
808   auto type = IntegerType::get(context, 17);
809   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
810   return success();
811 }
812 
813 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
814     OpBuilder &builder, ValueRange operands,
815     llvm::SmallVectorImpl<Value> &shapes) {
816   shapes = SmallVector<Value, 1>{
817       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
818   return success();
819 }
820 
821 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
822     OpBuilder &builder, ValueRange operands,
823     llvm::SmallVectorImpl<Value> &shapes) {
824   Location loc = getLoc();
825   shapes.reserve(operands.size());
826   for (Value operand : llvm::reverse(operands)) {
827     auto currShape = llvm::to_vector<4>(llvm::map_range(
828         llvm::seq<int64_t>(
829             0, operand.getType().cast<RankedTensorType>().getRank()),
830         [&](int64_t dim) -> Value {
831           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
832         }));
833     shapes.push_back(builder.create<tensor::FromElementsOp>(
834         getLoc(), builder.getIndexType(), currShape));
835   }
836   return success();
837 }
838 
839 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
840     OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
841   Location loc = getLoc();
842   shapes.reserve(getNumOperands());
843   for (Value operand : llvm::reverse(getOperands())) {
844     auto currShape = llvm::to_vector<4>(llvm::map_range(
845         llvm::seq<int64_t>(
846             0, operand.getType().cast<RankedTensorType>().getRank()),
847         [&](int64_t dim) -> Value {
848           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
849         }));
850     shapes.emplace_back(std::move(currShape));
851   }
852   return success();
853 }
854 
855 //===----------------------------------------------------------------------===//
856 // Test SideEffect interfaces
857 //===----------------------------------------------------------------------===//
858 
859 namespace {
860 /// A test resource for side effects.
861 struct TestResource : public SideEffects::Resource::Base<TestResource> {
862   StringRef getName() final { return "<Test>"; }
863 };
864 } // end anonymous namespace
865 
866 static void testSideEffectOpGetEffect(
867     Operation *op,
868     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
869         &effects) {
870   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
871   if (!effectsAttr)
872     return;
873 
874   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
875 }
876 
877 void SideEffectOp::getEffects(
878     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
879   // Check for an effects attribute on the op instance.
880   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
881   if (!effectsAttr)
882     return;
883 
884   // If there is one, it is an array of dictionary attributes that hold
885   // information on the effects of this operation.
886   for (Attribute element : effectsAttr) {
887     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
888 
889     // Get the specific memory effect.
890     MemoryEffects::Effect *effect =
891         StringSwitch<MemoryEffects::Effect *>(
892             effectElement.get("effect").cast<StringAttr>().getValue())
893             .Case("allocate", MemoryEffects::Allocate::get())
894             .Case("free", MemoryEffects::Free::get())
895             .Case("read", MemoryEffects::Read::get())
896             .Case("write", MemoryEffects::Write::get());
897 
898     // Check for a non-default resource to use.
899     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
900     if (effectElement.get("test_resource"))
901       resource = TestResource::get();
902 
903     // Check for a result to affect.
904     if (effectElement.get("on_result"))
905       effects.emplace_back(effect, getResult(), resource);
906     else if (Attribute ref = effectElement.get("on_reference"))
907       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
908     else
909       effects.emplace_back(effect, resource);
910   }
911 }
912 
913 void SideEffectOp::getEffects(
914     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
915   testSideEffectOpGetEffect(getOperation(), effects);
916 }
917 
918 //===----------------------------------------------------------------------===//
919 // StringAttrPrettyNameOp
920 //===----------------------------------------------------------------------===//
921 
922 // This op has fancy handling of its SSA result name.
923 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
924                                                OperationState &result) {
925   // Add the result types.
926   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
927     result.addTypes(parser.getBuilder().getIntegerType(32));
928 
929   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
930     return failure();
931 
932   // If the attribute dictionary contains no 'names' attribute, infer it from
933   // the SSA name (if specified).
934   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
935     return attr.first == "names";
936   });
937 
938   // If there was no name specified, check to see if there was a useful name
939   // specified in the asm file.
940   if (hadNames || parser.getNumResults() == 0)
941     return success();
942 
943   SmallVector<StringRef, 4> names;
944   auto *context = result.getContext();
945 
946   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
947     auto resultName = parser.getResultName(i);
948     StringRef nameStr;
949     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
950       nameStr = resultName.first;
951 
952     names.push_back(nameStr);
953   }
954 
955   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
956   result.attributes.push_back({Identifier::get("names", context), namesAttr});
957   return success();
958 }
959 
960 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
961   p << "test.string_attr_pretty_name";
962 
963   // Note that we only need to print the "name" attribute if the asmprinter
964   // result name disagrees with it.  This can happen in strange cases, e.g.
965   // when there are conflicts.
966   bool namesDisagree = op.names().size() != op.getNumResults();
967 
968   SmallString<32> resultNameStr;
969   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
970     resultNameStr.clear();
971     llvm::raw_svector_ostream tmpStream(resultNameStr);
972     p.printOperand(op.getResult(i), tmpStream);
973 
974     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
975     if (!expectedName ||
976         tmpStream.str().drop_front() != expectedName.getValue()) {
977       namesDisagree = true;
978     }
979   }
980 
981   if (namesDisagree)
982     p.printOptionalAttrDictWithKeyword(op->getAttrs());
983   else
984     p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
985 }
986 
987 // We set the SSA name in the asm syntax to the contents of the name
988 // attribute.
989 void StringAttrPrettyNameOp::getAsmResultNames(
990     function_ref<void(Value, StringRef)> setNameFn) {
991 
992   auto value = names();
993   for (size_t i = 0, e = value.size(); i != e; ++i)
994     if (auto str = value[i].dyn_cast<StringAttr>())
995       if (!str.getValue().empty())
996         setNameFn(getResult(i), str.getValue());
997 }
998 
999 //===----------------------------------------------------------------------===//
1000 // RegionIfOp
1001 //===----------------------------------------------------------------------===//
1002 
1003 static void print(OpAsmPrinter &p, RegionIfOp op) {
1004   p << RegionIfOp::getOperationName() << " ";
1005   p.printOperands(op.getOperands());
1006   p << ": " << op.getOperandTypes();
1007   p.printArrowTypeList(op.getResultTypes());
1008   p << " then";
1009   p.printRegion(op.thenRegion(),
1010                 /*printEntryBlockArgs=*/true,
1011                 /*printBlockTerminators=*/true);
1012   p << " else";
1013   p.printRegion(op.elseRegion(),
1014                 /*printEntryBlockArgs=*/true,
1015                 /*printBlockTerminators=*/true);
1016   p << " join";
1017   p.printRegion(op.joinRegion(),
1018                 /*printEntryBlockArgs=*/true,
1019                 /*printBlockTerminators=*/true);
1020 }
1021 
1022 static ParseResult parseRegionIfOp(OpAsmParser &parser,
1023                                    OperationState &result) {
1024   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
1025   SmallVector<Type, 2> operandTypes;
1026 
1027   result.regions.reserve(3);
1028   Region *thenRegion = result.addRegion();
1029   Region *elseRegion = result.addRegion();
1030   Region *joinRegion = result.addRegion();
1031 
1032   // Parse operand, type and arrow type lists.
1033   if (parser.parseOperandList(operandInfos) ||
1034       parser.parseColonTypeList(operandTypes) ||
1035       parser.parseArrowTypeList(result.types))
1036     return failure();
1037 
1038   // Parse all attached regions.
1039   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1040       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1041       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1042     return failure();
1043 
1044   return parser.resolveOperands(operandInfos, operandTypes,
1045                                 parser.getCurrentLocation(), result.operands);
1046 }
1047 
1048 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
1049   assert(index < 2 && "invalid region index");
1050   return getOperands();
1051 }
1052 
1053 void RegionIfOp::getSuccessorRegions(
1054     Optional<unsigned> index, ArrayRef<Attribute> operands,
1055     SmallVectorImpl<RegionSuccessor> &regions) {
1056   // We always branch to the join region.
1057   if (index.hasValue()) {
1058     if (index.getValue() < 2)
1059       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
1060     else
1061       regions.push_back(RegionSuccessor(getResults()));
1062     return;
1063   }
1064 
1065   // The then and else regions are the entry regions of this op.
1066   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
1067   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
1068 }
1069 
1070 //===----------------------------------------------------------------------===//
1071 // SingleNoTerminatorCustomAsmOp
1072 //===----------------------------------------------------------------------===//
1073 
1074 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser,
1075                                                       OperationState &state) {
1076   Region *body = state.addRegion();
1077   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1078     return failure();
1079   return success();
1080 }
1081 
1082 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) {
1083   printer << op.getOperationName();
1084   printer.printRegion(
1085       op.getRegion(), /*printEntryBlockArgs=*/false,
1086       // This op has a single block without terminators. But explicitly mark
1087       // as not printing block terminators for testing.
1088       /*printBlockTerminators=*/false);
1089 }
1090 
1091 #include "TestOpEnums.cpp.inc"
1092 #include "TestOpInterfaces.cpp.inc"
1093 #include "TestOpStructs.cpp.inc"
1094 #include "TestTypeInterfaces.cpp.inc"
1095 
1096 #define GET_OP_CLASSES
1097 #include "TestOps.cpp.inc"
1098