1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/DLTI/DLTI.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/Reducer/ReductionPatternInterface.h"
22 #include "mlir/Transforms/FoldUtils.h"
23 #include "mlir/Transforms/InliningUtils.h"
24 #include "llvm/ADT/StringSwitch.h"
25 
26 // Include this before the using namespace lines below to
27 // test that we don't have namespace dependencies.
28 #include "TestOpsDialect.cpp.inc"
29 
30 using namespace mlir;
31 using namespace test;
32 
33 void test::registerTestDialect(DialectRegistry &registry) {
34   registry.insert<TestDialect>();
35 }
36 
37 //===----------------------------------------------------------------------===//
38 // TestDialect Interfaces
39 //===----------------------------------------------------------------------===//
40 
41 namespace {
42 
43 /// Testing the correctness of some traits.
44 static_assert(
45     llvm::is_detected<OpTrait::has_implicit_terminator_t,
46                       SingleBlockImplicitTerminatorOp>::value,
47     "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
48 static_assert(OpTrait::hasSingleBlockImplicitTerminator<
49                   SingleBlockImplicitTerminatorOp>::value,
50               "hasSingleBlockImplicitTerminator does not match "
51               "SingleBlockImplicitTerminatorOp");
52 
53 // Test support for interacting with the AsmPrinter.
54 struct TestOpAsmInterface : public OpAsmDialectInterface {
55   using OpAsmDialectInterface::OpAsmDialectInterface;
56 
57   AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
58     StringAttr strAttr = attr.dyn_cast<StringAttr>();
59     if (!strAttr)
60       return AliasResult::NoAlias;
61 
62     // Check the contents of the string attribute to see what the test alias
63     // should be named.
64     Optional<StringRef> aliasName =
65         StringSwitch<Optional<StringRef>>(strAttr.getValue())
66             .Case("alias_test:dot_in_name", StringRef("test.alias"))
67             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
68             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
69             .Case("alias_test:sanitize_conflict_a",
70                   StringRef("test_alias_conflict0"))
71             .Case("alias_test:sanitize_conflict_b",
72                   StringRef("test_alias_conflict0_"))
73             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
74             .Default(llvm::None);
75     if (!aliasName)
76       return AliasResult::NoAlias;
77 
78     os << *aliasName;
79     return AliasResult::FinalAlias;
80   }
81 
82   AliasResult getAlias(Type type, raw_ostream &os) const final {
83     if (auto tupleType = type.dyn_cast<TupleType>()) {
84       if (tupleType.size() > 0 &&
85           llvm::all_of(tupleType.getTypes(), [](Type elemType) {
86             return elemType.isa<SimpleAType>();
87           })) {
88         os << "test_tuple";
89         return AliasResult::FinalAlias;
90       }
91     }
92     if (auto intType = type.dyn_cast<TestIntegerType>()) {
93       if (intType.getSignedness() ==
94               TestIntegerType::SignednessSemantics::Unsigned &&
95           intType.getWidth() == 8) {
96         os << "test_ui8";
97         return AliasResult::FinalAlias;
98       }
99     }
100     return AliasResult::NoAlias;
101   }
102 
103   void getAsmResultNames(Operation *op,
104                          OpAsmSetValueNameFn setNameFn) const final {
105     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
106       setNameFn(asmOp, "result");
107   }
108 
109   void getAsmBlockArgumentNames(Block *block,
110                                 OpAsmSetValueNameFn setNameFn) const final {
111     auto op = block->getParentOp();
112     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
113     if (!arrayAttr)
114       return;
115     auto args = block->getArguments();
116     auto e = std::min(arrayAttr.size(), args.size());
117     for (unsigned i = 0; i < e; ++i) {
118       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
119         setNameFn(args[i], strAttr.getValue());
120     }
121   }
122 };
123 
124 struct TestDialectFoldInterface : public DialectFoldInterface {
125   using DialectFoldInterface::DialectFoldInterface;
126 
127   /// Registered hook to check if the given region, which is attached to an
128   /// operation that is *not* isolated from above, should be used when
129   /// materializing constants.
130   bool shouldMaterializeInto(Region *region) const final {
131     // If this is a one region operation, then insert into it.
132     return isa<OneRegionOp>(region->getParentOp());
133   }
134 };
135 
136 /// This class defines the interface for handling inlining with standard
137 /// operations.
138 struct TestInlinerInterface : public DialectInlinerInterface {
139   using DialectInlinerInterface::DialectInlinerInterface;
140 
141   //===--------------------------------------------------------------------===//
142   // Analysis Hooks
143   //===--------------------------------------------------------------------===//
144 
145   bool isLegalToInline(Operation *call, Operation *callable,
146                        bool wouldBeCloned) const final {
147     // Don't allow inlining calls that are marked `noinline`.
148     return !call->hasAttr("noinline");
149   }
150   bool isLegalToInline(Region *, Region *, bool,
151                        BlockAndValueMapping &) const final {
152     // Inlining into test dialect regions is legal.
153     return true;
154   }
155   bool isLegalToInline(Operation *, Region *, bool,
156                        BlockAndValueMapping &) const final {
157     return true;
158   }
159 
160   bool shouldAnalyzeRecursively(Operation *op) const final {
161     // Analyze recursively if this is not a functional region operation, it
162     // froms a separate functional scope.
163     return !isa<FunctionalRegionOp>(op);
164   }
165 
166   //===--------------------------------------------------------------------===//
167   // Transformation Hooks
168   //===--------------------------------------------------------------------===//
169 
170   /// Handle the given inlined terminator by replacing it with a new operation
171   /// as necessary.
172   void handleTerminator(Operation *op,
173                         ArrayRef<Value> valuesToRepl) const final {
174     // Only handle "test.return" here.
175     auto returnOp = dyn_cast<TestReturnOp>(op);
176     if (!returnOp)
177       return;
178 
179     // Replace the values directly with the return operands.
180     assert(returnOp.getNumOperands() == valuesToRepl.size());
181     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
182       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
183   }
184 
185   /// Attempt to materialize a conversion for a type mismatch between a call
186   /// from this dialect, and a callable region. This method should generate an
187   /// operation that takes 'input' as the only operand, and produces a single
188   /// result of 'resultType'. If a conversion can not be generated, nullptr
189   /// should be returned.
190   Operation *materializeCallConversion(OpBuilder &builder, Value input,
191                                        Type resultType,
192                                        Location conversionLoc) const final {
193     // Only allow conversion for i16/i32 types.
194     if (!(resultType.isSignlessInteger(16) ||
195           resultType.isSignlessInteger(32)) ||
196         !(input.getType().isSignlessInteger(16) ||
197           input.getType().isSignlessInteger(32)))
198       return nullptr;
199     return builder.create<TestCastOp>(conversionLoc, resultType, input);
200   }
201 
202   void processInlinedCallBlocks(
203       Operation *call,
204       iterator_range<Region::iterator> inlinedBlocks) const final {
205     if (!isa<ConversionCallOp>(call))
206       return;
207 
208     // Set attributed on all ops in the inlined blocks.
209     for (Block &block : inlinedBlocks) {
210       block.walk([&](Operation *op) {
211         op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
212       });
213     }
214   }
215 };
216 
217 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
218 public:
219   TestReductionPatternInterface(Dialect *dialect)
220       : DialectReductionPatternInterface(dialect) {}
221 
222   void populateReductionPatterns(RewritePatternSet &patterns) const final {
223     populateTestReductionPatterns(patterns);
224   }
225 };
226 
227 } // end anonymous namespace
228 
229 //===----------------------------------------------------------------------===//
230 // TestDialect
231 //===----------------------------------------------------------------------===//
232 
233 static void testSideEffectOpGetEffect(
234     Operation *op,
235     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
236 
237 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
238 struct TestOpEffectInterfaceFallback
239     : public TestEffectOpInterface::FallbackModel<
240           TestOpEffectInterfaceFallback> {
241   static bool classof(Operation *op) {
242     bool isSupportedOp =
243         op->getName().getStringRef() == "test.unregistered_side_effect_op";
244     assert(isSupportedOp && "Unexpected dispatch");
245     return isSupportedOp;
246   }
247 
248   void
249   getEffects(Operation *op,
250              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
251                  &effects) const {
252     testSideEffectOpGetEffect(op, effects);
253   }
254 };
255 
/// Registers the dialect's attributes, types, operations and dialect
/// interfaces, allows unknown operations, and allocates the fallback
/// TestEffectOpInterface model (freed in ~TestDialect).
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  // The TableGen-generated op list is spliced into the template argument list.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface, TestReductionPatternInterface>();
  // Ops in this dialect's namespace that are not registered are still
  // accepted (e.g. "test.unregistered_side_effect_op").
  allowUnknownOperations();

  // Instantiate our fallback op interface that we'll use on specific
  // unregistered op. Ownership: deleted in ~TestDialect().
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
271 TestDialect::~TestDialect() {
272   delete static_cast<TestOpEffectInterfaceFallback *>(
273       fallbackEffectOpInterfaces);
274 }
275 
276 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
277                                             Type type, Location loc) {
278   return builder.create<TestOpConstant>(loc, type, value);
279 }
280 
281 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
282                                                OperationName opName) {
283   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
284       typeID == TypeID::get<TestEffectOpInterface>())
285     return fallbackEffectOpInterfaces;
286   return nullptr;
287 }
288 
289 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
290                                                     NamedAttribute namedAttr) {
291   if (namedAttr.getName() == "test.invalid_attr")
292     return op->emitError() << "invalid to use 'test.invalid_attr'";
293   return success();
294 }
295 
296 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
297                                                     unsigned regionIndex,
298                                                     unsigned argIndex,
299                                                     NamedAttribute namedAttr) {
300   if (namedAttr.getName() == "test.invalid_attr")
301     return op->emitError() << "invalid to use 'test.invalid_attr'";
302   return success();
303 }
304 
305 LogicalResult
306 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
307                                          unsigned resultIndex,
308                                          NamedAttribute namedAttr) {
309   if (namedAttr.getName() == "test.invalid_attr")
310     return op->emitError() << "invalid to use 'test.invalid_attr'";
311   return success();
312 }
313 
314 Optional<Dialect::ParseOpHook>
315 TestDialect::getParseOperationHook(StringRef opName) const {
316   if (opName == "test.dialect_custom_printer") {
317     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
318       return parser.parseKeyword("custom_format");
319     }};
320   }
321   return None;
322 }
323 
324 llvm::unique_function<void(Operation *, OpAsmPrinter &)>
325 TestDialect::getOperationPrinter(Operation *op) const {
326   StringRef opName = op->getName().getStringRef();
327   if (opName == "test.dialect_custom_printer") {
328     return [](Operation *op, OpAsmPrinter &printer) {
329       printer.getStream() << " custom_format";
330     };
331   }
332   return {};
333 }
334 
335 //===----------------------------------------------------------------------===//
336 // TestBranchOp
337 //===----------------------------------------------------------------------===//
338 
339 Optional<MutableOperandRange>
340 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
341   assert(index == 0 && "invalid successor index");
342   return getTargetOperandsMutable();
343 }
344 
345 //===----------------------------------------------------------------------===//
346 // TestDialectCanonicalizerOp
347 //===----------------------------------------------------------------------===//
348 
349 static LogicalResult
350 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
351                                PatternRewriter &rewriter) {
352   rewriter.replaceOpWithNewOp<arith::ConstantOp>(
353       op, rewriter.getI32IntegerAttr(42));
354   return success();
355 }
356 
357 void TestDialect::getCanonicalizationPatterns(
358     RewritePatternSet &results) const {
359   results.add(&dialectCanonicalizationPattern);
360 }
361 
362 //===----------------------------------------------------------------------===//
363 // TestFoldToCallOp
364 //===----------------------------------------------------------------------===//
365 
366 namespace {
367 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
368   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
369 
370   LogicalResult matchAndRewrite(FoldToCallOp op,
371                                 PatternRewriter &rewriter) const override {
372     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.getCalleeAttr(),
373                                         ValueRange());
374     return success();
375   }
376 };
377 } // end anonymous namespace
378 
379 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
380                                                MLIRContext *context) {
381   results.add<FoldToCallOpPattern>(context);
382 }
383 
384 //===----------------------------------------------------------------------===//
385 // Test Format* operations
386 //===----------------------------------------------------------------------===//
387 
388 //===----------------------------------------------------------------------===//
389 // Parsing
390 
391 static ParseResult parseCustomDirectiveOperands(
392     OpAsmParser &parser, OpAsmParser::OperandType &operand,
393     Optional<OpAsmParser::OperandType> &optOperand,
394     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
395   if (parser.parseOperand(operand))
396     return failure();
397   if (succeeded(parser.parseOptionalComma())) {
398     optOperand.emplace();
399     if (parser.parseOperand(*optOperand))
400       return failure();
401   }
402   if (parser.parseArrow() || parser.parseLParen() ||
403       parser.parseOperandList(varOperands) || parser.parseRParen())
404     return failure();
405   return success();
406 }
407 static ParseResult
408 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
409                             Type &optOperandType,
410                             SmallVectorImpl<Type> &varOperandTypes) {
411   if (parser.parseColon())
412     return failure();
413 
414   if (parser.parseType(operandType))
415     return failure();
416   if (succeeded(parser.parseOptionalComma())) {
417     if (parser.parseType(optOperandType))
418       return failure();
419   }
420   if (parser.parseArrow() || parser.parseLParen() ||
421       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
422     return failure();
423   return success();
424 }
425 static ParseResult
426 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
427                                  Type optOperandType,
428                                  const SmallVectorImpl<Type> &varOperandTypes) {
429   if (parser.parseKeyword("type_refs_capture"))
430     return failure();
431 
432   Type operandType2, optOperandType2;
433   SmallVector<Type, 1> varOperandTypes2;
434   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
435                                   varOperandTypes2))
436     return failure();
437 
438   if (operandType != operandType2 || optOperandType != optOperandType2 ||
439       varOperandTypes != varOperandTypes2)
440     return failure();
441 
442   return success();
443 }
444 static ParseResult parseCustomDirectiveOperandsAndTypes(
445     OpAsmParser &parser, OpAsmParser::OperandType &operand,
446     Optional<OpAsmParser::OperandType> &optOperand,
447     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
448     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
449   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
450       parseCustomDirectiveResults(parser, operandType, optOperandType,
451                                   varOperandTypes))
452     return failure();
453   return success();
454 }
455 static ParseResult parseCustomDirectiveRegions(
456     OpAsmParser &parser, Region &region,
457     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
458   if (parser.parseRegion(region))
459     return failure();
460   if (failed(parser.parseOptionalComma()))
461     return success();
462   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
463   if (parser.parseRegion(*varRegion))
464     return failure();
465   varRegions.emplace_back(std::move(varRegion));
466   return success();
467 }
468 static ParseResult
469 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
470                                SmallVectorImpl<Block *> &varSuccessors) {
471   if (parser.parseSuccessor(successor))
472     return failure();
473   if (failed(parser.parseOptionalComma()))
474     return success();
475   Block *varSuccessor;
476   if (parser.parseSuccessor(varSuccessor))
477     return failure();
478   varSuccessors.append(2, varSuccessor);
479   return success();
480 }
481 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
482                                                   IntegerAttr &attr,
483                                                   IntegerAttr &optAttr) {
484   if (parser.parseAttribute(attr))
485     return failure();
486   if (succeeded(parser.parseOptionalComma())) {
487     if (parser.parseAttribute(optAttr))
488       return failure();
489   }
490   return success();
491 }
492 
493 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
494                                                 NamedAttrList &attrs) {
495   return parser.parseOptionalAttrDict(attrs);
496 }
497 static ParseResult parseCustomDirectiveOptionalOperandRef(
498     OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
499   int64_t operandCount = 0;
500   if (parser.parseInteger(operandCount))
501     return failure();
502   bool expectedOptionalOperand = operandCount == 0;
503   return success(expectedOptionalOperand != optOperand.hasValue());
504 }
505 
506 //===----------------------------------------------------------------------===//
507 // Printing
508 
509 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
510                                          Value operand, Value optOperand,
511                                          OperandRange varOperands) {
512   printer << operand;
513   if (optOperand)
514     printer << ", " << optOperand;
515   printer << " -> (" << varOperands << ")";
516 }
517 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
518                                         Type operandType, Type optOperandType,
519                                         TypeRange varOperandTypes) {
520   printer << " : " << operandType;
521   if (optOperandType)
522     printer << ", " << optOperandType;
523   printer << " -> (" << varOperandTypes << ")";
524 }
525 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
526                                              Operation *op, Type operandType,
527                                              Type optOperandType,
528                                              TypeRange varOperandTypes) {
529   printer << " type_refs_capture ";
530   printCustomDirectiveResults(printer, op, operandType, optOperandType,
531                               varOperandTypes);
532 }
533 static void printCustomDirectiveOperandsAndTypes(
534     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
535     OperandRange varOperands, Type operandType, Type optOperandType,
536     TypeRange varOperandTypes) {
537   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
538   printCustomDirectiveResults(printer, op, operandType, optOperandType,
539                               varOperandTypes);
540 }
541 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
542                                         Region &region,
543                                         MutableArrayRef<Region> varRegions) {
544   printer.printRegion(region);
545   if (!varRegions.empty()) {
546     printer << ", ";
547     for (Region &region : varRegions)
548       printer.printRegion(region);
549   }
550 }
551 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
552                                            Block *successor,
553                                            SuccessorRange varSuccessors) {
554   printer << successor;
555   if (!varSuccessors.empty())
556     printer << ", " << varSuccessors.front();
557 }
558 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
559                                            Attribute attribute,
560                                            Attribute optAttribute) {
561   printer << attribute;
562   if (optAttribute)
563     printer << ", " << optAttribute;
564 }
565 
566 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
567                                          DictionaryAttr attrs) {
568   printer.printOptionalAttrDict(attrs.getValue());
569 }
570 
571 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
572                                                    Operation *op,
573                                                    Value optOperand) {
574   printer << (optOperand ? "1" : "0");
575 }
576 
577 //===----------------------------------------------------------------------===//
578 // Test IsolatedRegionOp - parse passthrough region arguments.
579 //===----------------------------------------------------------------------===//
580 
581 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
582                                          OperationState &result) {
583   OpAsmParser::OperandType argInfo;
584   Type argType = parser.getBuilder().getIndexType();
585 
586   // Parse the input operand.
587   if (parser.parseOperand(argInfo) ||
588       parser.resolveOperand(argInfo, argType, result.operands))
589     return failure();
590 
591   // Parse the body region, and reuse the operand info as the argument info.
592   Region *body = result.addRegion();
593   return parser.parseRegion(*body, argInfo, argType,
594                             /*enableNameShadowing=*/true);
595 }
596 
597 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
598   p << "test.isolated_region ";
599   p.printOperand(op.getOperand());
600   p.shadowRegionArgs(op.getRegion(), op.getOperand());
601   p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
602 }
603 
604 //===----------------------------------------------------------------------===//
605 // Test SSACFGRegionOp
606 //===----------------------------------------------------------------------===//
607 
608 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
609   return RegionKind::SSACFG;
610 }
611 
612 //===----------------------------------------------------------------------===//
613 // Test GraphRegionOp
614 //===----------------------------------------------------------------------===//
615 
616 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
617                                       OperationState &result) {
618   // Parse the body region, and reuse the operand info as the argument info.
619   Region *body = result.addRegion();
620   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
621 }
622 
623 static void print(OpAsmPrinter &p, GraphRegionOp op) {
624   p << "test.graph_region ";
625   p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
626 }
627 
628 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
629   return RegionKind::Graph;
630 }
631 
632 //===----------------------------------------------------------------------===//
633 // Test AffineScopeOp
634 //===----------------------------------------------------------------------===//
635 
636 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
637                                       OperationState &result) {
638   // Parse the body region, and reuse the operand info as the argument info.
639   Region *body = result.addRegion();
640   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
641 }
642 
643 static void print(OpAsmPrinter &p, AffineScopeOp op) {
644   p << "test.affine_scope ";
645   p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
646 }
647 
648 //===----------------------------------------------------------------------===//
649 // Test parser.
650 //===----------------------------------------------------------------------===//
651 
652 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
653                                               OperationState &result) {
654   if (parser.parseOptionalColon())
655     return success();
656   uint64_t numResults;
657   if (parser.parseInteger(numResults))
658     return failure();
659 
660   IndexType type = parser.getBuilder().getIndexType();
661   for (unsigned i = 0; i < numResults; ++i)
662     result.addTypes(type);
663   return success();
664 }
665 
666 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
667   if (unsigned numResults = op->getNumResults())
668     p << " : " << numResults;
669 }
670 
671 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
672                                               OperationState &result) {
673   StringRef keyword;
674   if (parser.parseKeyword(&keyword))
675     return failure();
676   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
677   return success();
678 }
679 
680 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
681   p << " " << op.getKeyword();
682 }
683 
684 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
686 
687 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
688                                          OperationState &result) {
689   if (parser.parseKeyword("wraps"))
690     return failure();
691 
692   // Parse the wrapped op in a region
693   Region &body = *result.addRegion();
694   body.push_back(new Block);
695   Block &block = body.back();
696   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
697   if (!wrapped_op)
698     return failure();
699 
700   // Create a return terminator in the inner region, pass as operand to the
701   // terminator the returned values from the wrapped operation.
702   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
703   OpBuilder builder(parser.getContext());
704   builder.setInsertionPointToEnd(&block);
705   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
706 
707   // Get the results type for the wrapping op from the terminator operands.
708   Operation &return_op = body.back().back();
709   result.types.append(return_op.operand_type_begin(),
710                       return_op.operand_type_end());
711 
712   // Use the location of the wrapped op for the "test.wrapping_region" op.
713   result.location = wrapped_op->getLoc();
714 
715   return success();
716 }
717 
718 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
719   p << " wraps ";
720   p.printGenericOp(&op.getRegion().front().front());
721 }
722 
723 //===----------------------------------------------------------------------===//
724 // Test PrettyPrintedRegionOp -  exercising the following parser APIs
725 //   parseGenericOperationAfterOpName
726 //   parseCustomOperationName
727 //===----------------------------------------------------------------------===//
728 
729 static ParseResult parsePrettyPrintedRegionOp(OpAsmParser &parser,
730                                               OperationState &result) {
731 
732   llvm::SMLoc loc = parser.getCurrentLocation();
733   Location currLocation = parser.getEncodedSourceLoc(loc);
734 
735   // Parse the operands.
736   SmallVector<OpAsmParser::OperandType, 2> operands;
737   if (parser.parseOperandList(operands))
738     return failure();
739 
740   // Check if we are parsing the pretty-printed version
741   //  test.pretty_printed_region start <inner-op> end : <functional-type>
742   // Else fallback to parsing the "non pretty-printed" version.
743   if (!succeeded(parser.parseOptionalKeyword("start")))
744     return parser.parseGenericOperationAfterOpName(
745         result, llvm::makeArrayRef(operands));
746 
747   FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName();
748   if (failed(parseOpNameInfo))
749     return failure();
750 
751   StringRef innerOpName = parseOpNameInfo->getStringRef();
752 
753   FunctionType opFntype;
754   Optional<Location> explicitLoc;
755   if (parser.parseKeyword("end") || parser.parseColon() ||
756       parser.parseType(opFntype) ||
757       parser.parseOptionalLocationSpecifier(explicitLoc))
758     return failure();
759 
760   // If location of the op is explicitly provided, then use it; Else use
761   // the parser's current location.
762   Location opLoc = explicitLoc.getValueOr(currLocation);
763 
764   // Derive the SSA-values for op's operands.
765   if (parser.resolveOperands(operands, opFntype.getInputs(), loc,
766                              result.operands))
767     return failure();
768 
769   // Add a region for op.
770   Region &region = *result.addRegion();
771 
772   // Create a basic-block inside op's region.
773   Block &block = region.emplaceBlock();
774 
775   // Create and insert an "inner-op" operation in the block.
776   // Just for testing purposes, we can assume that inner op is a binary op with
777   // result and operand types all same as the test-op's first operand.
778   Type innerOpType = opFntype.getInput(0);
779   Value lhs = block.addArgument(innerOpType, opLoc);
780   Value rhs = block.addArgument(innerOpType, opLoc);
781 
782   OpBuilder builder(parser.getBuilder().getContext());
783   builder.setInsertionPointToStart(&block);
784 
785   OperationState innerOpState(opLoc, innerOpName);
786   innerOpState.operands.push_back(lhs);
787   innerOpState.operands.push_back(rhs);
788   innerOpState.addTypes(innerOpType);
789 
790   Operation *innerOp = builder.createOperation(innerOpState);
791 
792   // Insert a return statement in the block returning the inner-op's result.
793   builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());
794 
795   // Populate the op operation-state with result-type and location.
796   result.addTypes(opFntype.getResults());
797   result.location = innerOp->getLoc();
798 
799   return success();
800 }
801 
802 static void print(OpAsmPrinter &p, PrettyPrintedRegionOp op) {
803   p << ' ';
804   p.printOperands(op.getOperands());
805 
806   Operation &innerOp = op.getRegion().front().front();
807   // Assuming that region has a single non-terminator inner-op, if the inner-op
808   // meets some criteria (which in this case is a simple one  based on the name
809   // of inner-op), then we can print the entire region in a succinct way.
810   // Here we assume that the prototype of "special.op" can be trivially derived
811   // while parsing it back.
812   if (innerOp.getName().getStringRef().equals("special.op")) {
813     p << " start special.op end";
814   } else {
815     p << " (";
816     p.printRegion(op.getRegion());
817     p << ")";
818   }
819 
820   p << " : ";
821   p.printFunctionalType(op);
822 }
823 
824 //===----------------------------------------------------------------------===//
825 // Test PolyForOp - parse list of region arguments.
826 //===----------------------------------------------------------------------===//
827 
828 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
829   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
830   // Parse list of region arguments without a delimiter.
831   if (parser.parseRegionArgumentList(ivsInfo))
832     return failure();
833 
834   // Parse the body region.
835   Region *body = result.addRegion();
836   auto &builder = parser.getBuilder();
837   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
838   return parser.parseRegion(*body, ivsInfo, argTypes);
839 }
840 
841 //===----------------------------------------------------------------------===//
842 // Test removing op with inner ops.
843 //===----------------------------------------------------------------------===//
844 
845 namespace {
846 struct TestRemoveOpWithInnerOps
847     : public OpRewritePattern<TestOpWithRegionPattern> {
848   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
849 
850   void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
851 
852   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
853                                 PatternRewriter &rewriter) const override {
854     rewriter.eraseOp(op);
855     return success();
856   }
857 };
858 } // end anonymous namespace
859 
860 void TestOpWithRegionPattern::getCanonicalizationPatterns(
861     RewritePatternSet &results, MLIRContext *context) {
862   results.add<TestRemoveOpWithInnerOps>(context);
863 }
864 
865 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
866   return getOperand();
867 }
868 
869 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
870   return getValue();
871 }
872 
873 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
874     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
875   for (Value input : this->getOperands()) {
876     results.push_back(input);
877   }
878   return success();
879 }
880 
881 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
882   assert(operands.size() == 1);
883   if (operands.front()) {
884     (*this)->setAttr("attr", operands.front());
885     return getResult();
886   }
887   return {};
888 }
889 
890 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
891   return getOperand();
892 }
893 
894 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
895     MLIRContext *, Optional<Location> location, ValueRange operands,
896     DictionaryAttr attributes, RegionRange regions,
897     SmallVectorImpl<Type> &inferredReturnTypes) {
898   if (operands[0].getType() != operands[1].getType()) {
899     return emitOptionalError(location, "operand type mismatch ",
900                              operands[0].getType(), " vs ",
901                              operands[1].getType());
902   }
903   inferredReturnTypes.assign({operands[0].getType()});
904   return success();
905 }
906 
907 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
908     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
909     DictionaryAttr attributes, RegionRange regions,
910     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
911   // Create return type consisting of the last element of the first operand.
912   auto operandType = operands.front().getType();
913   auto sval = operandType.dyn_cast<ShapedType>();
914   if (!sval) {
915     return emitOptionalError(location, "only shaped type operands allowed");
916   }
917   int64_t dim =
918       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
919   auto type = IntegerType::get(context, 17);
920   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
921   return success();
922 }
923 
924 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
925     OpBuilder &builder, ValueRange operands,
926     llvm::SmallVectorImpl<Value> &shapes) {
927   shapes = SmallVector<Value, 1>{
928       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
929   return success();
930 }
931 
932 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
933     OpBuilder &builder, ValueRange operands,
934     llvm::SmallVectorImpl<Value> &shapes) {
935   Location loc = getLoc();
936   shapes.reserve(operands.size());
937   for (Value operand : llvm::reverse(operands)) {
938     auto currShape = llvm::to_vector<4>(llvm::map_range(
939         llvm::seq<int64_t>(
940             0, operand.getType().cast<RankedTensorType>().getRank()),
941         [&](int64_t dim) -> Value {
942           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
943         }));
944     shapes.push_back(builder.create<tensor::FromElementsOp>(
945         getLoc(), builder.getIndexType(), currShape));
946   }
947   return success();
948 }
949 
950 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
951     OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
952   Location loc = getLoc();
953   shapes.reserve(getNumOperands());
954   for (Value operand : llvm::reverse(getOperands())) {
955     auto currShape = llvm::to_vector<4>(llvm::map_range(
956         llvm::seq<int64_t>(
957             0, operand.getType().cast<RankedTensorType>().getRank()),
958         [&](int64_t dim) -> Value {
959           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
960         }));
961     shapes.emplace_back(std::move(currShape));
962   }
963   return success();
964 }
965 
966 //===----------------------------------------------------------------------===//
967 // Test SideEffect interfaces
968 //===----------------------------------------------------------------------===//
969 
970 namespace {
971 /// A test resource for side effects.
972 struct TestResource : public SideEffects::Resource::Base<TestResource> {
973   StringRef getName() final { return "<Test>"; }
974 };
975 } // end anonymous namespace
976 
977 static void testSideEffectOpGetEffect(
978     Operation *op,
979     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
980         &effects) {
981   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
982   if (!effectsAttr)
983     return;
984 
985   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
986 }
987 
988 void SideEffectOp::getEffects(
989     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
990   // Check for an effects attribute on the op instance.
991   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
992   if (!effectsAttr)
993     return;
994 
995   // If there is one, it is an array of dictionary attributes that hold
996   // information on the effects of this operation.
997   for (Attribute element : effectsAttr) {
998     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
999 
1000     // Get the specific memory effect.
1001     MemoryEffects::Effect *effect =
1002         StringSwitch<MemoryEffects::Effect *>(
1003             effectElement.get("effect").cast<StringAttr>().getValue())
1004             .Case("allocate", MemoryEffects::Allocate::get())
1005             .Case("free", MemoryEffects::Free::get())
1006             .Case("read", MemoryEffects::Read::get())
1007             .Case("write", MemoryEffects::Write::get());
1008 
1009     // Check for a non-default resource to use.
1010     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
1011     if (effectElement.get("test_resource"))
1012       resource = TestResource::get();
1013 
1014     // Check for a result to affect.
1015     if (effectElement.get("on_result"))
1016       effects.emplace_back(effect, getResult(), resource);
1017     else if (Attribute ref = effectElement.get("on_reference"))
1018       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
1019     else
1020       effects.emplace_back(effect, resource);
1021   }
1022 }
1023 
1024 void SideEffectOp::getEffects(
1025     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
1026   testSideEffectOpGetEffect(getOperation(), effects);
1027 }
1028 
1029 //===----------------------------------------------------------------------===//
1030 // StringAttrPrettyNameOp
1031 //===----------------------------------------------------------------------===//
1032 
1033 // This op has fancy handling of its SSA result name.
1034 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
1035                                                OperationState &result) {
1036   // Add the result types.
1037   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
1038     result.addTypes(parser.getBuilder().getIntegerType(32));
1039 
1040   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1041     return failure();
1042 
1043   // If the attribute dictionary contains no 'names' attribute, infer it from
1044   // the SSA name (if specified).
1045   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
1046     return attr.getName() == "names";
1047   });
1048 
1049   // If there was no name specified, check to see if there was a useful name
1050   // specified in the asm file.
1051   if (hadNames || parser.getNumResults() == 0)
1052     return success();
1053 
1054   SmallVector<StringRef, 4> names;
1055   auto *context = result.getContext();
1056 
1057   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
1058     auto resultName = parser.getResultName(i);
1059     StringRef nameStr;
1060     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
1061       nameStr = resultName.first;
1062 
1063     names.push_back(nameStr);
1064   }
1065 
1066   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
1067   result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
1068   return success();
1069 }
1070 
1071 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
1072   // Note that we only need to print the "name" attribute if the asmprinter
1073   // result name disagrees with it.  This can happen in strange cases, e.g.
1074   // when there are conflicts.
1075   bool namesDisagree = op.getNames().size() != op.getNumResults();
1076 
1077   SmallString<32> resultNameStr;
1078   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
1079     resultNameStr.clear();
1080     llvm::raw_svector_ostream tmpStream(resultNameStr);
1081     p.printOperand(op.getResult(i), tmpStream);
1082 
1083     auto expectedName = op.getNames()[i].dyn_cast<StringAttr>();
1084     if (!expectedName ||
1085         tmpStream.str().drop_front() != expectedName.getValue()) {
1086       namesDisagree = true;
1087     }
1088   }
1089 
1090   if (namesDisagree)
1091     p.printOptionalAttrDictWithKeyword(op->getAttrs());
1092   else
1093     p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
1094 }
1095 
1096 // We set the SSA name in the asm syntax to the contents of the name
1097 // attribute.
1098 void StringAttrPrettyNameOp::getAsmResultNames(
1099     function_ref<void(Value, StringRef)> setNameFn) {
1100 
1101   auto value = getNames();
1102   for (size_t i = 0, e = value.size(); i != e; ++i)
1103     if (auto str = value[i].dyn_cast<StringAttr>())
1104       if (!str.getValue().empty())
1105         setNameFn(getResult(i), str.getValue());
1106 }
1107 
1108 //===----------------------------------------------------------------------===//
1109 // RegionIfOp
1110 //===----------------------------------------------------------------------===//
1111 
1112 static void print(OpAsmPrinter &p, RegionIfOp op) {
1113   p << " ";
1114   p.printOperands(op.getOperands());
1115   p << ": " << op.getOperandTypes();
1116   p.printArrowTypeList(op.getResultTypes());
1117   p << " then";
1118   p.printRegion(op.getThenRegion(),
1119                 /*printEntryBlockArgs=*/true,
1120                 /*printBlockTerminators=*/true);
1121   p << " else";
1122   p.printRegion(op.getElseRegion(),
1123                 /*printEntryBlockArgs=*/true,
1124                 /*printBlockTerminators=*/true);
1125   p << " join";
1126   p.printRegion(op.getJoinRegion(),
1127                 /*printEntryBlockArgs=*/true,
1128                 /*printBlockTerminators=*/true);
1129 }
1130 
1131 static ParseResult parseRegionIfOp(OpAsmParser &parser,
1132                                    OperationState &result) {
1133   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
1134   SmallVector<Type, 2> operandTypes;
1135 
1136   result.regions.reserve(3);
1137   Region *thenRegion = result.addRegion();
1138   Region *elseRegion = result.addRegion();
1139   Region *joinRegion = result.addRegion();
1140 
1141   // Parse operand, type and arrow type lists.
1142   if (parser.parseOperandList(operandInfos) ||
1143       parser.parseColonTypeList(operandTypes) ||
1144       parser.parseArrowTypeList(result.types))
1145     return failure();
1146 
1147   // Parse all attached regions.
1148   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1149       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1150       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1151     return failure();
1152 
1153   return parser.resolveOperands(operandInfos, operandTypes,
1154                                 parser.getCurrentLocation(), result.operands);
1155 }
1156 
1157 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
1158   assert(index < 2 && "invalid region index");
1159   return getOperands();
1160 }
1161 
1162 void RegionIfOp::getSuccessorRegions(
1163     Optional<unsigned> index, ArrayRef<Attribute> operands,
1164     SmallVectorImpl<RegionSuccessor> &regions) {
1165   // We always branch to the join region.
1166   if (index.hasValue()) {
1167     if (index.getValue() < 2)
1168       regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
1169     else
1170       regions.push_back(RegionSuccessor(getResults()));
1171     return;
1172   }
1173 
1174   // The then and else regions are the entry regions of this op.
1175   regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
1176   regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
1177 }
1178 
1179 //===----------------------------------------------------------------------===//
1180 // SingleNoTerminatorCustomAsmOp
1181 //===----------------------------------------------------------------------===//
1182 
1183 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser,
1184                                                       OperationState &state) {
1185   Region *body = state.addRegion();
1186   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1187     return failure();
1188   return success();
1189 }
1190 
1191 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) {
1192   printer.printRegion(
1193       op.getRegion(), /*printEntryBlockArgs=*/false,
1194       // This op has a single block without terminators. But explicitly mark
1195       // as not printing block terminators for testing.
1196       /*printBlockTerminators=*/false);
1197 }
1198 
1199 #include "TestOpEnums.cpp.inc"
1200 #include "TestOpInterfaces.cpp.inc"
1201 #include "TestOpStructs.cpp.inc"
1202 #include "TestTypeInterfaces.cpp.inc"
1203 
1204 #define GET_OP_CLASSES
1205 #include "TestOps.cpp.inc"
1206