1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/DLTI/DLTI.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/Reducer/ReductionPatternInterface.h"
22 #include "mlir/Transforms/FoldUtils.h"
23 #include "mlir/Transforms/InliningUtils.h"
24 #include "llvm/ADT/StringSwitch.h"
25 
26 // Include this before the using namespace lines below to
27 // test that we don't have namespace dependencies.
28 #include "TestOpsDialect.cpp.inc"
29 
30 using namespace mlir;
31 using namespace test;
32 
33 void test::registerTestDialect(DialectRegistry &registry) {
34   registry.insert<TestDialect>();
35 }
36 
37 //===----------------------------------------------------------------------===//
38 // TestDialect Interfaces
39 //===----------------------------------------------------------------------===//
40 
41 namespace {
42 
/// Testing the correctness of some traits.
/// Both checks are evaluated at compile time against
/// SingleBlockImplicitTerminatorOp: the `has_implicit_terminator_t` detector
/// must be detected for it, and `hasSingleBlockImplicitTerminator` must
/// evaluate to true for it.
static_assert(
    llvm::is_detected<OpTrait::has_implicit_terminator_t,
                      SingleBlockImplicitTerminatorOp>::value,
    "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
static_assert(OpTrait::hasSingleBlockImplicitTerminator<
                  SingleBlockImplicitTerminatorOp>::value,
              "hasSingleBlockImplicitTerminator does not match "
              "SingleBlockImplicitTerminatorOp");
52 
53 // Test support for interacting with the AsmPrinter.
54 struct TestOpAsmInterface : public OpAsmDialectInterface {
55   using OpAsmDialectInterface::OpAsmDialectInterface;
56 
57   AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
58     StringAttr strAttr = attr.dyn_cast<StringAttr>();
59     if (!strAttr)
60       return AliasResult::NoAlias;
61 
62     // Check the contents of the string attribute to see what the test alias
63     // should be named.
64     Optional<StringRef> aliasName =
65         StringSwitch<Optional<StringRef>>(strAttr.getValue())
66             .Case("alias_test:dot_in_name", StringRef("test.alias"))
67             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
68             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
69             .Case("alias_test:sanitize_conflict_a",
70                   StringRef("test_alias_conflict0"))
71             .Case("alias_test:sanitize_conflict_b",
72                   StringRef("test_alias_conflict0_"))
73             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
74             .Default(llvm::None);
75     if (!aliasName)
76       return AliasResult::NoAlias;
77 
78     os << *aliasName;
79     return AliasResult::FinalAlias;
80   }
81 
82   AliasResult getAlias(Type type, raw_ostream &os) const final {
83     if (auto tupleType = type.dyn_cast<TupleType>()) {
84       if (tupleType.size() > 0 &&
85           llvm::all_of(tupleType.getTypes(), [](Type elemType) {
86             return elemType.isa<SimpleAType>();
87           })) {
88         os << "test_tuple";
89         return AliasResult::FinalAlias;
90       }
91     }
92     if (auto intType = type.dyn_cast<TestIntegerType>()) {
93       if (intType.getSignedness() ==
94               TestIntegerType::SignednessSemantics::Unsigned &&
95           intType.getWidth() == 8) {
96         os << "test_ui8";
97         return AliasResult::FinalAlias;
98       }
99     }
100     return AliasResult::NoAlias;
101   }
102 
103   void getAsmResultNames(Operation *op,
104                          OpAsmSetValueNameFn setNameFn) const final {
105     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
106       setNameFn(asmOp, "result");
107   }
108 
109   void getAsmBlockArgumentNames(Block *block,
110                                 OpAsmSetValueNameFn setNameFn) const final {
111     auto op = block->getParentOp();
112     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
113     if (!arrayAttr)
114       return;
115     auto args = block->getArguments();
116     auto e = std::min(arrayAttr.size(), args.size());
117     for (unsigned i = 0; i < e; ++i) {
118       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
119         setNameFn(args[i], strAttr.getValue());
120     }
121   }
122 };
123 
124 struct TestDialectFoldInterface : public DialectFoldInterface {
125   using DialectFoldInterface::DialectFoldInterface;
126 
127   /// Registered hook to check if the given region, which is attached to an
128   /// operation that is *not* isolated from above, should be used when
129   /// materializing constants.
130   bool shouldMaterializeInto(Region *region) const final {
131     // If this is a one region operation, then insert into it.
132     return isa<OneRegionOp>(region->getParentOp());
133   }
134 };
135 
136 /// This class defines the interface for handling inlining with standard
137 /// operations.
138 struct TestInlinerInterface : public DialectInlinerInterface {
139   using DialectInlinerInterface::DialectInlinerInterface;
140 
141   //===--------------------------------------------------------------------===//
142   // Analysis Hooks
143   //===--------------------------------------------------------------------===//
144 
145   bool isLegalToInline(Operation *call, Operation *callable,
146                        bool wouldBeCloned) const final {
147     // Don't allow inlining calls that are marked `noinline`.
148     return !call->hasAttr("noinline");
149   }
150   bool isLegalToInline(Region *, Region *, bool,
151                        BlockAndValueMapping &) const final {
152     // Inlining into test dialect regions is legal.
153     return true;
154   }
155   bool isLegalToInline(Operation *, Region *, bool,
156                        BlockAndValueMapping &) const final {
157     return true;
158   }
159 
160   bool shouldAnalyzeRecursively(Operation *op) const final {
161     // Analyze recursively if this is not a functional region operation, it
162     // froms a separate functional scope.
163     return !isa<FunctionalRegionOp>(op);
164   }
165 
166   //===--------------------------------------------------------------------===//
167   // Transformation Hooks
168   //===--------------------------------------------------------------------===//
169 
170   /// Handle the given inlined terminator by replacing it with a new operation
171   /// as necessary.
172   void handleTerminator(Operation *op,
173                         ArrayRef<Value> valuesToRepl) const final {
174     // Only handle "test.return" here.
175     auto returnOp = dyn_cast<TestReturnOp>(op);
176     if (!returnOp)
177       return;
178 
179     // Replace the values directly with the return operands.
180     assert(returnOp.getNumOperands() == valuesToRepl.size());
181     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
182       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
183   }
184 
185   /// Attempt to materialize a conversion for a type mismatch between a call
186   /// from this dialect, and a callable region. This method should generate an
187   /// operation that takes 'input' as the only operand, and produces a single
188   /// result of 'resultType'. If a conversion can not be generated, nullptr
189   /// should be returned.
190   Operation *materializeCallConversion(OpBuilder &builder, Value input,
191                                        Type resultType,
192                                        Location conversionLoc) const final {
193     // Only allow conversion for i16/i32 types.
194     if (!(resultType.isSignlessInteger(16) ||
195           resultType.isSignlessInteger(32)) ||
196         !(input.getType().isSignlessInteger(16) ||
197           input.getType().isSignlessInteger(32)))
198       return nullptr;
199     return builder.create<TestCastOp>(conversionLoc, resultType, input);
200   }
201 
202   void processInlinedCallBlocks(
203       Operation *call,
204       iterator_range<Region::iterator> inlinedBlocks) const final {
205     if (!isa<ConversionCallOp>(call))
206       return;
207 
208     // Set attributed on all ops in the inlined blocks.
209     for (Block &block : inlinedBlocks) {
210       block.walk([&](Operation *op) {
211         op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
212       });
213     }
214   }
215 };
216 
217 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
218 public:
219   TestReductionPatternInterface(Dialect *dialect)
220       : DialectReductionPatternInterface(dialect) {}
221 
222   void populateReductionPatterns(RewritePatternSet &patterns) const final {
223     populateTestReductionPatterns(patterns);
224   }
225 };
226 
227 } // namespace
228 
229 //===----------------------------------------------------------------------===//
230 // TestDialect
231 //===----------------------------------------------------------------------===//
232 
/// Forward declaration: appends to `effects` the test effects of `op`. The
/// definition is not in this part of the file; it is declared here so that
/// TestOpEffectInterfaceFallback::getEffects can call it.
static void testSideEffectOpGetEffect(
    Operation *op,
    SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
236 
237 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
238 struct TestOpEffectInterfaceFallback
239     : public TestEffectOpInterface::FallbackModel<
240           TestOpEffectInterfaceFallback> {
241   static bool classof(Operation *op) {
242     bool isSupportedOp =
243         op->getName().getStringRef() == "test.unregistered_side_effect_op";
244     assert(isSupportedOp && "Unexpected dispatch");
245     return isSupportedOp;
246   }
247 
248   void
249   getEffects(Operation *op,
250              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
251                  &effects) const {
252     testSideEffectOpGetEffect(op, effects);
253   }
254 };
255 
/// Sets up the dialect: registers its attributes and types, adds the
/// operations listed in the generated TestOps.cpp.inc, attaches the dialect
/// interfaces, allows unknown operations, and allocates the fallback effect
/// interface.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  // The op list is the template argument pack, pulled in from the generated
  // include with GET_OP_LIST defined.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface, TestReductionPatternInterface>();
  allowUnknownOperations();

  // Instantiate our fallback op interface that we'll use on specific
  // unregistered op.
  // The pointer is released in ~TestDialect().
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
271 TestDialect::~TestDialect() {
272   delete static_cast<TestOpEffectInterfaceFallback *>(
273       fallbackEffectOpInterfaces);
274 }
275 
276 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
277                                             Type type, Location loc) {
278   return builder.create<TestOpConstant>(loc, type, value);
279 }
280 
281 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
282                                                OperationName opName) {
283   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
284       typeID == TypeID::get<TestEffectOpInterface>())
285     return fallbackEffectOpInterfaces;
286   return nullptr;
287 }
288 
289 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
290                                                     NamedAttribute namedAttr) {
291   if (namedAttr.getName() == "test.invalid_attr")
292     return op->emitError() << "invalid to use 'test.invalid_attr'";
293   return success();
294 }
295 
296 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
297                                                     unsigned regionIndex,
298                                                     unsigned argIndex,
299                                                     NamedAttribute namedAttr) {
300   if (namedAttr.getName() == "test.invalid_attr")
301     return op->emitError() << "invalid to use 'test.invalid_attr'";
302   return success();
303 }
304 
305 LogicalResult
306 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
307                                          unsigned resultIndex,
308                                          NamedAttribute namedAttr) {
309   if (namedAttr.getName() == "test.invalid_attr")
310     return op->emitError() << "invalid to use 'test.invalid_attr'";
311   return success();
312 }
313 
314 Optional<Dialect::ParseOpHook>
315 TestDialect::getParseOperationHook(StringRef opName) const {
316   if (opName == "test.dialect_custom_printer") {
317     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
318       return parser.parseKeyword("custom_format");
319     }};
320   }
321   if (opName == "test.dialect_custom_format_fallback") {
322     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
323       return parser.parseKeyword("custom_format_fallback");
324     }};
325   }
326   return None;
327 }
328 
329 llvm::unique_function<void(Operation *, OpAsmPrinter &)>
330 TestDialect::getOperationPrinter(Operation *op) const {
331   StringRef opName = op->getName().getStringRef();
332   if (opName == "test.dialect_custom_printer") {
333     return [](Operation *op, OpAsmPrinter &printer) {
334       printer.getStream() << " custom_format";
335     };
336   }
337   if (opName == "test.dialect_custom_format_fallback") {
338     return [](Operation *op, OpAsmPrinter &printer) {
339       printer.getStream() << " custom_format_fallback";
340     };
341   }
342   return {};
343 }
344 
345 //===----------------------------------------------------------------------===//
346 // TestBranchOp
347 //===----------------------------------------------------------------------===//
348 
349 Optional<MutableOperandRange>
350 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
351   assert(index == 0 && "invalid successor index");
352   return getTargetOperandsMutable();
353 }
354 
355 //===----------------------------------------------------------------------===//
356 // TestDialectCanonicalizerOp
357 //===----------------------------------------------------------------------===//
358 
359 static LogicalResult
360 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
361                                PatternRewriter &rewriter) {
362   rewriter.replaceOpWithNewOp<arith::ConstantOp>(
363       op, rewriter.getI32IntegerAttr(42));
364   return success();
365 }
366 
367 void TestDialect::getCanonicalizationPatterns(
368     RewritePatternSet &results) const {
369   results.add(&dialectCanonicalizationPattern);
370 }
371 
372 //===----------------------------------------------------------------------===//
373 // TestFoldToCallOp
374 //===----------------------------------------------------------------------===//
375 
376 namespace {
377 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
378   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
379 
380   LogicalResult matchAndRewrite(FoldToCallOp op,
381                                 PatternRewriter &rewriter) const override {
382     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.getCalleeAttr(),
383                                         ValueRange());
384     return success();
385   }
386 };
387 } // namespace
388 
389 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
390                                                MLIRContext *context) {
391   results.add<FoldToCallOpPattern>(context);
392 }
393 
394 //===----------------------------------------------------------------------===//
395 // Test Format* operations
396 //===----------------------------------------------------------------------===//
397 
398 //===----------------------------------------------------------------------===//
399 // Parsing
400 
401 static ParseResult parseCustomDirectiveOperands(
402     OpAsmParser &parser, OpAsmParser::OperandType &operand,
403     Optional<OpAsmParser::OperandType> &optOperand,
404     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
405   if (parser.parseOperand(operand))
406     return failure();
407   if (succeeded(parser.parseOptionalComma())) {
408     optOperand.emplace();
409     if (parser.parseOperand(*optOperand))
410       return failure();
411   }
412   if (parser.parseArrow() || parser.parseLParen() ||
413       parser.parseOperandList(varOperands) || parser.parseRParen())
414     return failure();
415   return success();
416 }
417 static ParseResult
418 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
419                             Type &optOperandType,
420                             SmallVectorImpl<Type> &varOperandTypes) {
421   if (parser.parseColon())
422     return failure();
423 
424   if (parser.parseType(operandType))
425     return failure();
426   if (succeeded(parser.parseOptionalComma())) {
427     if (parser.parseType(optOperandType))
428       return failure();
429   }
430   if (parser.parseArrow() || parser.parseLParen() ||
431       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
432     return failure();
433   return success();
434 }
435 static ParseResult
436 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
437                                  Type optOperandType,
438                                  const SmallVectorImpl<Type> &varOperandTypes) {
439   if (parser.parseKeyword("type_refs_capture"))
440     return failure();
441 
442   Type operandType2, optOperandType2;
443   SmallVector<Type, 1> varOperandTypes2;
444   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
445                                   varOperandTypes2))
446     return failure();
447 
448   if (operandType != operandType2 || optOperandType != optOperandType2 ||
449       varOperandTypes != varOperandTypes2)
450     return failure();
451 
452   return success();
453 }
454 static ParseResult parseCustomDirectiveOperandsAndTypes(
455     OpAsmParser &parser, OpAsmParser::OperandType &operand,
456     Optional<OpAsmParser::OperandType> &optOperand,
457     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
458     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
459   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
460       parseCustomDirectiveResults(parser, operandType, optOperandType,
461                                   varOperandTypes))
462     return failure();
463   return success();
464 }
465 static ParseResult parseCustomDirectiveRegions(
466     OpAsmParser &parser, Region &region,
467     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
468   if (parser.parseRegion(region))
469     return failure();
470   if (failed(parser.parseOptionalComma()))
471     return success();
472   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
473   if (parser.parseRegion(*varRegion))
474     return failure();
475   varRegions.emplace_back(std::move(varRegion));
476   return success();
477 }
478 static ParseResult
479 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
480                                SmallVectorImpl<Block *> &varSuccessors) {
481   if (parser.parseSuccessor(successor))
482     return failure();
483   if (failed(parser.parseOptionalComma()))
484     return success();
485   Block *varSuccessor;
486   if (parser.parseSuccessor(varSuccessor))
487     return failure();
488   varSuccessors.append(2, varSuccessor);
489   return success();
490 }
491 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
492                                                   IntegerAttr &attr,
493                                                   IntegerAttr &optAttr) {
494   if (parser.parseAttribute(attr))
495     return failure();
496   if (succeeded(parser.parseOptionalComma())) {
497     if (parser.parseAttribute(optAttr))
498       return failure();
499   }
500   return success();
501 }
502 
503 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
504                                                 NamedAttrList &attrs) {
505   return parser.parseOptionalAttrDict(attrs);
506 }
507 static ParseResult parseCustomDirectiveOptionalOperandRef(
508     OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
509   int64_t operandCount = 0;
510   if (parser.parseInteger(operandCount))
511     return failure();
512   bool expectedOptionalOperand = operandCount == 0;
513   return success(expectedOptionalOperand != optOperand.hasValue());
514 }
515 
516 //===----------------------------------------------------------------------===//
517 // Printing
518 
519 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
520                                          Value operand, Value optOperand,
521                                          OperandRange varOperands) {
522   printer << operand;
523   if (optOperand)
524     printer << ", " << optOperand;
525   printer << " -> (" << varOperands << ")";
526 }
527 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
528                                         Type operandType, Type optOperandType,
529                                         TypeRange varOperandTypes) {
530   printer << " : " << operandType;
531   if (optOperandType)
532     printer << ", " << optOperandType;
533   printer << " -> (" << varOperandTypes << ")";
534 }
535 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
536                                              Operation *op, Type operandType,
537                                              Type optOperandType,
538                                              TypeRange varOperandTypes) {
539   printer << " type_refs_capture ";
540   printCustomDirectiveResults(printer, op, operandType, optOperandType,
541                               varOperandTypes);
542 }
543 static void printCustomDirectiveOperandsAndTypes(
544     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
545     OperandRange varOperands, Type operandType, Type optOperandType,
546     TypeRange varOperandTypes) {
547   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
548   printCustomDirectiveResults(printer, op, operandType, optOperandType,
549                               varOperandTypes);
550 }
551 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
552                                         Region &region,
553                                         MutableArrayRef<Region> varRegions) {
554   printer.printRegion(region);
555   if (!varRegions.empty()) {
556     printer << ", ";
557     for (Region &region : varRegions)
558       printer.printRegion(region);
559   }
560 }
561 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
562                                            Block *successor,
563                                            SuccessorRange varSuccessors) {
564   printer << successor;
565   if (!varSuccessors.empty())
566     printer << ", " << varSuccessors.front();
567 }
568 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
569                                            Attribute attribute,
570                                            Attribute optAttribute) {
571   printer << attribute;
572   if (optAttribute)
573     printer << ", " << optAttribute;
574 }
575 
576 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
577                                          DictionaryAttr attrs) {
578   printer.printOptionalAttrDict(attrs.getValue());
579 }
580 
581 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
582                                                    Operation *op,
583                                                    Value optOperand) {
584   printer << (optOperand ? "1" : "0");
585 }
586 
587 //===----------------------------------------------------------------------===//
588 // Test IsolatedRegionOp - parse passthrough region arguments.
589 //===----------------------------------------------------------------------===//
590 
591 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
592                                          OperationState &result) {
593   OpAsmParser::OperandType argInfo;
594   Type argType = parser.getBuilder().getIndexType();
595 
596   // Parse the input operand.
597   if (parser.parseOperand(argInfo) ||
598       parser.resolveOperand(argInfo, argType, result.operands))
599     return failure();
600 
601   // Parse the body region, and reuse the operand info as the argument info.
602   Region *body = result.addRegion();
603   return parser.parseRegion(*body, argInfo, argType,
604                             /*enableNameShadowing=*/true);
605 }
606 
607 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
608   p << "test.isolated_region ";
609   p.printOperand(op.getOperand());
610   p.shadowRegionArgs(op.getRegion(), op.getOperand());
611   p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
612 }
613 
614 //===----------------------------------------------------------------------===//
615 // Test SSACFGRegionOp
616 //===----------------------------------------------------------------------===//
617 
618 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
619   return RegionKind::SSACFG;
620 }
621 
622 //===----------------------------------------------------------------------===//
623 // Test GraphRegionOp
624 //===----------------------------------------------------------------------===//
625 
626 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
627                                       OperationState &result) {
628   // Parse the body region, and reuse the operand info as the argument info.
629   Region *body = result.addRegion();
630   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
631 }
632 
633 static void print(OpAsmPrinter &p, GraphRegionOp op) {
634   p << "test.graph_region ";
635   p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
636 }
637 
638 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
639   return RegionKind::Graph;
640 }
641 
642 //===----------------------------------------------------------------------===//
643 // Test AffineScopeOp
644 //===----------------------------------------------------------------------===//
645 
646 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
647                                       OperationState &result) {
648   // Parse the body region, and reuse the operand info as the argument info.
649   Region *body = result.addRegion();
650   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
651 }
652 
653 static void print(OpAsmPrinter &p, AffineScopeOp op) {
654   p << "test.affine_scope ";
655   p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
656 }
657 
658 //===----------------------------------------------------------------------===//
659 // Test parser.
660 //===----------------------------------------------------------------------===//
661 
662 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
663                                               OperationState &result) {
664   if (parser.parseOptionalColon())
665     return success();
666   uint64_t numResults;
667   if (parser.parseInteger(numResults))
668     return failure();
669 
670   IndexType type = parser.getBuilder().getIndexType();
671   for (unsigned i = 0; i < numResults; ++i)
672     result.addTypes(type);
673   return success();
674 }
675 
676 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
677   if (unsigned numResults = op->getNumResults())
678     p << " : " << numResults;
679 }
680 
681 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
682                                               OperationState &result) {
683   StringRef keyword;
684   if (parser.parseKeyword(&keyword))
685     return failure();
686   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
687   return success();
688 }
689 
690 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
691   p << " " << op.getKeyword();
692 }
693 
694 //===----------------------------------------------------------------------===//
695 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`.
696 
697 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
698                                          OperationState &result) {
699   if (parser.parseKeyword("wraps"))
700     return failure();
701 
702   // Parse the wrapped op in a region
703   Region &body = *result.addRegion();
704   body.push_back(new Block);
705   Block &block = body.back();
706   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
707   if (!wrapped_op)
708     return failure();
709 
710   // Create a return terminator in the inner region, pass as operand to the
711   // terminator the returned values from the wrapped operation.
712   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
713   OpBuilder builder(parser.getContext());
714   builder.setInsertionPointToEnd(&block);
715   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
716 
717   // Get the results type for the wrapping op from the terminator operands.
718   Operation &return_op = body.back().back();
719   result.types.append(return_op.operand_type_begin(),
720                       return_op.operand_type_end());
721 
722   // Use the location of the wrapped op for the "test.wrapping_region" op.
723   result.location = wrapped_op->getLoc();
724 
725   return success();
726 }
727 
728 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
729   p << " wraps ";
730   p.printGenericOp(&op.getRegion().front().front());
731 }
732 
733 //===----------------------------------------------------------------------===//
734 // Test PrettyPrintedRegionOp -  exercising the following parser APIs
735 //   parseGenericOperationAfterOpName
736 //   parseCustomOperationName
737 //===----------------------------------------------------------------------===//
738 
739 static ParseResult parsePrettyPrintedRegionOp(OpAsmParser &parser,
740                                               OperationState &result) {
741 
742   llvm::SMLoc loc = parser.getCurrentLocation();
743   Location currLocation = parser.getEncodedSourceLoc(loc);
744 
745   // Parse the operands.
746   SmallVector<OpAsmParser::OperandType, 2> operands;
747   if (parser.parseOperandList(operands))
748     return failure();
749 
750   // Check if we are parsing the pretty-printed version
751   //  test.pretty_printed_region start <inner-op> end : <functional-type>
752   // Else fallback to parsing the "non pretty-printed" version.
753   if (!succeeded(parser.parseOptionalKeyword("start")))
754     return parser.parseGenericOperationAfterOpName(
755         result, llvm::makeArrayRef(operands));
756 
757   FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName();
758   if (failed(parseOpNameInfo))
759     return failure();
760 
761   StringRef innerOpName = parseOpNameInfo->getStringRef();
762 
763   FunctionType opFntype;
764   Optional<Location> explicitLoc;
765   if (parser.parseKeyword("end") || parser.parseColon() ||
766       parser.parseType(opFntype) ||
767       parser.parseOptionalLocationSpecifier(explicitLoc))
768     return failure();
769 
770   // If location of the op is explicitly provided, then use it; Else use
771   // the parser's current location.
772   Location opLoc = explicitLoc.getValueOr(currLocation);
773 
774   // Derive the SSA-values for op's operands.
775   if (parser.resolveOperands(operands, opFntype.getInputs(), loc,
776                              result.operands))
777     return failure();
778 
779   // Add a region for op.
780   Region &region = *result.addRegion();
781 
782   // Create a basic-block inside op's region.
783   Block &block = region.emplaceBlock();
784 
785   // Create and insert an "inner-op" operation in the block.
786   // Just for testing purposes, we can assume that inner op is a binary op with
787   // result and operand types all same as the test-op's first operand.
788   Type innerOpType = opFntype.getInput(0);
789   Value lhs = block.addArgument(innerOpType, opLoc);
790   Value rhs = block.addArgument(innerOpType, opLoc);
791 
792   OpBuilder builder(parser.getBuilder().getContext());
793   builder.setInsertionPointToStart(&block);
794 
795   OperationState innerOpState(opLoc, innerOpName);
796   innerOpState.operands.push_back(lhs);
797   innerOpState.operands.push_back(rhs);
798   innerOpState.addTypes(innerOpType);
799 
800   Operation *innerOp = builder.createOperation(innerOpState);
801 
802   // Insert a return statement in the block returning the inner-op's result.
803   builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());
804 
805   // Populate the op operation-state with result-type and location.
806   result.addTypes(opFntype.getResults());
807   result.location = innerOp->getLoc();
808 
809   return success();
810 }
811 
812 static void print(OpAsmPrinter &p, PrettyPrintedRegionOp op) {
813   p << ' ';
814   p.printOperands(op.getOperands());
815 
816   Operation &innerOp = op.getRegion().front().front();
817   // Assuming that region has a single non-terminator inner-op, if the inner-op
818   // meets some criteria (which in this case is a simple one  based on the name
819   // of inner-op), then we can print the entire region in a succinct way.
820   // Here we assume that the prototype of "special.op" can be trivially derived
821   // while parsing it back.
822   if (innerOp.getName().getStringRef().equals("special.op")) {
823     p << " start special.op end";
824   } else {
825     p << " (";
826     p.printRegion(op.getRegion());
827     p << ")";
828   }
829 
830   p << " : ";
831   p.printFunctionalType(op);
832 }
833 
834 //===----------------------------------------------------------------------===//
835 // Test PolyForOp - parse list of region arguments.
836 //===----------------------------------------------------------------------===//
837 
838 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
839   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
840   // Parse list of region arguments without a delimiter.
841   if (parser.parseRegionArgumentList(ivsInfo))
842     return failure();
843 
844   // Parse the body region.
845   Region *body = result.addRegion();
846   auto &builder = parser.getBuilder();
847   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
848   return parser.parseRegion(*body, ivsInfo, argTypes);
849 }
850 
851 //===----------------------------------------------------------------------===//
852 // Test removing op with inner ops.
853 //===----------------------------------------------------------------------===//
854 
855 namespace {
856 struct TestRemoveOpWithInnerOps
857     : public OpRewritePattern<TestOpWithRegionPattern> {
858   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
859 
860   void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
861 
862   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
863                                 PatternRewriter &rewriter) const override {
864     rewriter.eraseOp(op);
865     return success();
866   }
867 };
868 } // namespace
869 
/// Populates `results` with the canonicalization patterns of
/// TestOpWithRegionPattern: a single TestRemoveOpWithInnerOps pattern
/// constructed with `context`.
void TestOpWithRegionPattern::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<TestRemoveOpWithInnerOps>(context);
}
874 
875 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
876   return getOperand();
877 }
878 
879 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
880   return getValue();
881 }
882 
883 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
884     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
885   for (Value input : this->getOperands()) {
886     results.push_back(input);
887   }
888   return success();
889 }
890 
891 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
892   assert(operands.size() == 1);
893   if (operands.front()) {
894     (*this)->setAttr("attr", operands.front());
895     return getResult();
896   }
897   return {};
898 }
899 
900 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
901   return getOperand();
902 }
903 
904 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
905     MLIRContext *, Optional<Location> location, ValueRange operands,
906     DictionaryAttr attributes, RegionRange regions,
907     SmallVectorImpl<Type> &inferredReturnTypes) {
908   if (operands[0].getType() != operands[1].getType()) {
909     return emitOptionalError(location, "operand type mismatch ",
910                              operands[0].getType(), " vs ",
911                              operands[1].getType());
912   }
913   inferredReturnTypes.assign({operands[0].getType()});
914   return success();
915 }
916 
917 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
918     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
919     DictionaryAttr attributes, RegionRange regions,
920     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
921   // Create return type consisting of the last element of the first operand.
922   auto operandType = operands.front().getType();
923   auto sval = operandType.dyn_cast<ShapedType>();
924   if (!sval) {
925     return emitOptionalError(location, "only shaped type operands allowed");
926   }
927   int64_t dim =
928       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
929   auto type = IntegerType::get(context, 17);
930   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
931   return success();
932 }
933 
934 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
935     OpBuilder &builder, ValueRange operands,
936     llvm::SmallVectorImpl<Value> &shapes) {
937   shapes = SmallVector<Value, 1>{
938       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
939   return success();
940 }
941 
942 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
943     OpBuilder &builder, ValueRange operands,
944     llvm::SmallVectorImpl<Value> &shapes) {
945   Location loc = getLoc();
946   shapes.reserve(operands.size());
947   for (Value operand : llvm::reverse(operands)) {
948     auto currShape = llvm::to_vector<4>(llvm::map_range(
949         llvm::seq<int64_t>(
950             0, operand.getType().cast<RankedTensorType>().getRank()),
951         [&](int64_t dim) -> Value {
952           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
953         }));
954     shapes.push_back(builder.create<tensor::FromElementsOp>(
955         getLoc(), builder.getIndexType(), currShape));
956   }
957   return success();
958 }
959 
960 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
961     OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
962   Location loc = getLoc();
963   shapes.reserve(getNumOperands());
964   for (Value operand : llvm::reverse(getOperands())) {
965     auto currShape = llvm::to_vector<4>(llvm::map_range(
966         llvm::seq<int64_t>(
967             0, operand.getType().cast<RankedTensorType>().getRank()),
968         [&](int64_t dim) -> Value {
969           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
970         }));
971     shapes.emplace_back(std::move(currShape));
972   }
973   return success();
974 }
975 
976 //===----------------------------------------------------------------------===//
977 // Test SideEffect interfaces
978 //===----------------------------------------------------------------------===//
979 
namespace {
/// A test resource for side effects.
/// SideEffectOp::getEffects attaches it to an effect, in place of the default
/// resource, when the effect's dictionary has a "test_resource" entry.
struct TestResource : public SideEffects::Resource::Base<TestResource> {
  /// Returns the name of this resource: "<Test>".
  StringRef getName() final { return "<Test>"; }
};
} // namespace
986 
987 static void testSideEffectOpGetEffect(
988     Operation *op,
989     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
990         &effects) {
991   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
992   if (!effectsAttr)
993     return;
994 
995   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
996 }
997 
998 void SideEffectOp::getEffects(
999     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1000   // Check for an effects attribute on the op instance.
1001   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
1002   if (!effectsAttr)
1003     return;
1004 
1005   // If there is one, it is an array of dictionary attributes that hold
1006   // information on the effects of this operation.
1007   for (Attribute element : effectsAttr) {
1008     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
1009 
1010     // Get the specific memory effect.
1011     MemoryEffects::Effect *effect =
1012         StringSwitch<MemoryEffects::Effect *>(
1013             effectElement.get("effect").cast<StringAttr>().getValue())
1014             .Case("allocate", MemoryEffects::Allocate::get())
1015             .Case("free", MemoryEffects::Free::get())
1016             .Case("read", MemoryEffects::Read::get())
1017             .Case("write", MemoryEffects::Write::get());
1018 
1019     // Check for a non-default resource to use.
1020     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
1021     if (effectElement.get("test_resource"))
1022       resource = TestResource::get();
1023 
1024     // Check for a result to affect.
1025     if (effectElement.get("on_result"))
1026       effects.emplace_back(effect, getResult(), resource);
1027     else if (Attribute ref = effectElement.get("on_reference"))
1028       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
1029     else
1030       effects.emplace_back(effect, resource);
1031   }
1032 }
1033 
1034 void SideEffectOp::getEffects(
1035     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
1036   testSideEffectOpGetEffect(getOperation(), effects);
1037 }
1038 
1039 //===----------------------------------------------------------------------===//
1040 // StringAttrPrettyNameOp
1041 //===----------------------------------------------------------------------===//
1042 
1043 // This op has fancy handling of its SSA result name.
1044 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
1045                                                OperationState &result) {
1046   // Add the result types.
1047   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
1048     result.addTypes(parser.getBuilder().getIntegerType(32));
1049 
1050   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1051     return failure();
1052 
1053   // If the attribute dictionary contains no 'names' attribute, infer it from
1054   // the SSA name (if specified).
1055   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
1056     return attr.getName() == "names";
1057   });
1058 
1059   // If there was no name specified, check to see if there was a useful name
1060   // specified in the asm file.
1061   if (hadNames || parser.getNumResults() == 0)
1062     return success();
1063 
1064   SmallVector<StringRef, 4> names;
1065   auto *context = result.getContext();
1066 
1067   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
1068     auto resultName = parser.getResultName(i);
1069     StringRef nameStr;
1070     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
1071       nameStr = resultName.first;
1072 
1073     names.push_back(nameStr);
1074   }
1075 
1076   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
1077   result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
1078   return success();
1079 }
1080 
1081 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
1082   // Note that we only need to print the "name" attribute if the asmprinter
1083   // result name disagrees with it.  This can happen in strange cases, e.g.
1084   // when there are conflicts.
1085   bool namesDisagree = op.getNames().size() != op.getNumResults();
1086 
1087   SmallString<32> resultNameStr;
1088   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
1089     resultNameStr.clear();
1090     llvm::raw_svector_ostream tmpStream(resultNameStr);
1091     p.printOperand(op.getResult(i), tmpStream);
1092 
1093     auto expectedName = op.getNames()[i].dyn_cast<StringAttr>();
1094     if (!expectedName ||
1095         tmpStream.str().drop_front() != expectedName.getValue()) {
1096       namesDisagree = true;
1097     }
1098   }
1099 
1100   if (namesDisagree)
1101     p.printOptionalAttrDictWithKeyword(op->getAttrs());
1102   else
1103     p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
1104 }
1105 
1106 // We set the SSA name in the asm syntax to the contents of the name
1107 // attribute.
1108 void StringAttrPrettyNameOp::getAsmResultNames(
1109     function_ref<void(Value, StringRef)> setNameFn) {
1110 
1111   auto value = getNames();
1112   for (size_t i = 0, e = value.size(); i != e; ++i)
1113     if (auto str = value[i].dyn_cast<StringAttr>())
1114       if (!str.getValue().empty())
1115         setNameFn(getResult(i), str.getValue());
1116 }
1117 
1118 //===----------------------------------------------------------------------===//
1119 // RegionIfOp
1120 //===----------------------------------------------------------------------===//
1121 
1122 static void print(OpAsmPrinter &p, RegionIfOp op) {
1123   p << " ";
1124   p.printOperands(op.getOperands());
1125   p << ": " << op.getOperandTypes();
1126   p.printArrowTypeList(op.getResultTypes());
1127   p << " then";
1128   p.printRegion(op.getThenRegion(),
1129                 /*printEntryBlockArgs=*/true,
1130                 /*printBlockTerminators=*/true);
1131   p << " else";
1132   p.printRegion(op.getElseRegion(),
1133                 /*printEntryBlockArgs=*/true,
1134                 /*printBlockTerminators=*/true);
1135   p << " join";
1136   p.printRegion(op.getJoinRegion(),
1137                 /*printEntryBlockArgs=*/true,
1138                 /*printBlockTerminators=*/true);
1139 }
1140 
1141 static ParseResult parseRegionIfOp(OpAsmParser &parser,
1142                                    OperationState &result) {
1143   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
1144   SmallVector<Type, 2> operandTypes;
1145 
1146   result.regions.reserve(3);
1147   Region *thenRegion = result.addRegion();
1148   Region *elseRegion = result.addRegion();
1149   Region *joinRegion = result.addRegion();
1150 
1151   // Parse operand, type and arrow type lists.
1152   if (parser.parseOperandList(operandInfos) ||
1153       parser.parseColonTypeList(operandTypes) ||
1154       parser.parseArrowTypeList(result.types))
1155     return failure();
1156 
1157   // Parse all attached regions.
1158   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1159       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1160       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1161     return failure();
1162 
1163   return parser.resolveOperands(operandInfos, operandTypes,
1164                                 parser.getCurrentLocation(), result.operands);
1165 }
1166 
1167 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
1168   assert(index < 2 && "invalid region index");
1169   return getOperands();
1170 }
1171 
1172 void RegionIfOp::getSuccessorRegions(
1173     Optional<unsigned> index, ArrayRef<Attribute> operands,
1174     SmallVectorImpl<RegionSuccessor> &regions) {
1175   // We always branch to the join region.
1176   if (index.hasValue()) {
1177     if (index.getValue() < 2)
1178       regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
1179     else
1180       regions.push_back(RegionSuccessor(getResults()));
1181     return;
1182   }
1183 
1184   // The then and else regions are the entry regions of this op.
1185   regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
1186   regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
1187 }
1188 
1189 //===----------------------------------------------------------------------===//
1190 // SingleNoTerminatorCustomAsmOp
1191 //===----------------------------------------------------------------------===//
1192 
1193 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser,
1194                                                       OperationState &state) {
1195   Region *body = state.addRegion();
1196   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1197     return failure();
1198   return success();
1199 }
1200 
1201 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) {
1202   printer.printRegion(
1203       op.getRegion(), /*printEntryBlockArgs=*/false,
1204       // This op has a single block without terminators. But explicitly mark
1205       // as not printing block terminators for testing.
1206       /*printBlockTerminators=*/false);
1207 }
1208 
1209 #include "TestOpEnums.cpp.inc"
1210 #include "TestOpInterfaces.cpp.inc"
1211 #include "TestOpStructs.cpp.inc"
1212 #include "TestTypeInterfaces.cpp.inc"
1213 
1214 #define GET_OP_CLASSES
1215 #include "TestOps.cpp.inc"
1216