1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/DLTI/DLTI.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/Reducer/ReductionPatternInterface.h"
22 #include "mlir/Transforms/FoldUtils.h"
23 #include "mlir/Transforms/InliningUtils.h"
24 #include "llvm/ADT/StringSwitch.h"
25 
26 // Include this before the using namespace lines below to
27 // test that we don't have namespace dependencies.
28 #include "TestOpsDialect.cpp.inc"
29 
30 using namespace mlir;
31 using namespace test;
32 
33 void test::registerTestDialect(DialectRegistry &registry) {
34   registry.insert<TestDialect>();
35 }
36 
37 //===----------------------------------------------------------------------===//
38 // TestDialect Interfaces
39 //===----------------------------------------------------------------------===//
40 
41 namespace {
42 
43 /// Testing the correctness of some traits.
44 static_assert(
45     llvm::is_detected<OpTrait::has_implicit_terminator_t,
46                       SingleBlockImplicitTerminatorOp>::value,
47     "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
48 static_assert(OpTrait::hasSingleBlockImplicitTerminator<
49                   SingleBlockImplicitTerminatorOp>::value,
50               "hasSingleBlockImplicitTerminator does not match "
51               "SingleBlockImplicitTerminatorOp");
52 
53 // Test support for interacting with the AsmPrinter.
54 struct TestOpAsmInterface : public OpAsmDialectInterface {
55   using OpAsmDialectInterface::OpAsmDialectInterface;
56 
57   AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
58     StringAttr strAttr = attr.dyn_cast<StringAttr>();
59     if (!strAttr)
60       return AliasResult::NoAlias;
61 
62     // Check the contents of the string attribute to see what the test alias
63     // should be named.
64     Optional<StringRef> aliasName =
65         StringSwitch<Optional<StringRef>>(strAttr.getValue())
66             .Case("alias_test:dot_in_name", StringRef("test.alias"))
67             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
68             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
69             .Case("alias_test:sanitize_conflict_a",
70                   StringRef("test_alias_conflict0"))
71             .Case("alias_test:sanitize_conflict_b",
72                   StringRef("test_alias_conflict0_"))
73             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
74             .Default(llvm::None);
75     if (!aliasName)
76       return AliasResult::NoAlias;
77 
78     os << *aliasName;
79     return AliasResult::FinalAlias;
80   }
81 
82   AliasResult getAlias(Type type, raw_ostream &os) const final {
83     if (auto tupleType = type.dyn_cast<TupleType>()) {
84       if (tupleType.size() > 0 &&
85           llvm::all_of(tupleType.getTypes(), [](Type elemType) {
86             return elemType.isa<SimpleAType>();
87           })) {
88         os << "test_tuple";
89         return AliasResult::FinalAlias;
90       }
91     }
92     if (auto intType = type.dyn_cast<TestIntegerType>()) {
93       if (intType.getSignedness() ==
94               TestIntegerType::SignednessSemantics::Unsigned &&
95           intType.getWidth() == 8) {
96         os << "test_ui8";
97         return AliasResult::FinalAlias;
98       }
99     }
100     return AliasResult::NoAlias;
101   }
102 
103   void getAsmResultNames(Operation *op,
104                          OpAsmSetValueNameFn setNameFn) const final {
105     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
106       setNameFn(asmOp, "result");
107   }
108 
109   void getAsmBlockArgumentNames(Block *block,
110                                 OpAsmSetValueNameFn setNameFn) const final {
111     auto op = block->getParentOp();
112     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
113     if (!arrayAttr)
114       return;
115     auto args = block->getArguments();
116     auto e = std::min(arrayAttr.size(), args.size());
117     for (unsigned i = 0; i < e; ++i) {
118       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
119         setNameFn(args[i], strAttr.getValue());
120     }
121   }
122 };
123 
124 struct TestDialectFoldInterface : public DialectFoldInterface {
125   using DialectFoldInterface::DialectFoldInterface;
126 
127   /// Registered hook to check if the given region, which is attached to an
128   /// operation that is *not* isolated from above, should be used when
129   /// materializing constants.
130   bool shouldMaterializeInto(Region *region) const final {
131     // If this is a one region operation, then insert into it.
132     return isa<OneRegionOp>(region->getParentOp());
133   }
134 };
135 
136 /// This class defines the interface for handling inlining with standard
137 /// operations.
138 struct TestInlinerInterface : public DialectInlinerInterface {
139   using DialectInlinerInterface::DialectInlinerInterface;
140 
141   //===--------------------------------------------------------------------===//
142   // Analysis Hooks
143   //===--------------------------------------------------------------------===//
144 
145   bool isLegalToInline(Operation *call, Operation *callable,
146                        bool wouldBeCloned) const final {
147     // Don't allow inlining calls that are marked `noinline`.
148     return !call->hasAttr("noinline");
149   }
150   bool isLegalToInline(Region *, Region *, bool,
151                        BlockAndValueMapping &) const final {
152     // Inlining into test dialect regions is legal.
153     return true;
154   }
155   bool isLegalToInline(Operation *, Region *, bool,
156                        BlockAndValueMapping &) const final {
157     return true;
158   }
159 
160   bool shouldAnalyzeRecursively(Operation *op) const final {
161     // Analyze recursively if this is not a functional region operation, it
162     // froms a separate functional scope.
163     return !isa<FunctionalRegionOp>(op);
164   }
165 
166   //===--------------------------------------------------------------------===//
167   // Transformation Hooks
168   //===--------------------------------------------------------------------===//
169 
170   /// Handle the given inlined terminator by replacing it with a new operation
171   /// as necessary.
172   void handleTerminator(Operation *op,
173                         ArrayRef<Value> valuesToRepl) const final {
174     // Only handle "test.return" here.
175     auto returnOp = dyn_cast<TestReturnOp>(op);
176     if (!returnOp)
177       return;
178 
179     // Replace the values directly with the return operands.
180     assert(returnOp.getNumOperands() == valuesToRepl.size());
181     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
182       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
183   }
184 
185   /// Attempt to materialize a conversion for a type mismatch between a call
186   /// from this dialect, and a callable region. This method should generate an
187   /// operation that takes 'input' as the only operand, and produces a single
188   /// result of 'resultType'. If a conversion can not be generated, nullptr
189   /// should be returned.
190   Operation *materializeCallConversion(OpBuilder &builder, Value input,
191                                        Type resultType,
192                                        Location conversionLoc) const final {
193     // Only allow conversion for i16/i32 types.
194     if (!(resultType.isSignlessInteger(16) ||
195           resultType.isSignlessInteger(32)) ||
196         !(input.getType().isSignlessInteger(16) ||
197           input.getType().isSignlessInteger(32)))
198       return nullptr;
199     return builder.create<TestCastOp>(conversionLoc, resultType, input);
200   }
201 
202   void processInlinedCallBlocks(
203       Operation *call,
204       iterator_range<Region::iterator> inlinedBlocks) const final {
205     if (!isa<ConversionCallOp>(call))
206       return;
207 
208     // Set attributed on all ops in the inlined blocks.
209     for (Block &block : inlinedBlocks) {
210       block.walk([&](Operation *op) {
211         op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
212       });
213     }
214   }
215 };
216 
217 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
218 public:
219   TestReductionPatternInterface(Dialect *dialect)
220       : DialectReductionPatternInterface(dialect) {}
221 
222   void populateReductionPatterns(RewritePatternSet &patterns) const final {
223     populateTestReductionPatterns(patterns);
224   }
225 };
226 
227 } // end anonymous namespace
228 
229 //===----------------------------------------------------------------------===//
230 // TestDialect
231 //===----------------------------------------------------------------------===//
232 
233 static void testSideEffectOpGetEffect(
234     Operation *op,
235     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
236 
237 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
238 struct TestOpEffectInterfaceFallback
239     : public TestEffectOpInterface::FallbackModel<
240           TestOpEffectInterfaceFallback> {
241   static bool classof(Operation *op) {
242     bool isSupportedOp =
243         op->getName().getStringRef() == "test.unregistered_side_effect_op";
244     assert(isSupportedOp && "Unexpected dispatch");
245     return isSupportedOp;
246   }
247 
248   void
249   getEffects(Operation *op,
250              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
251                  &effects) const {
252     testSideEffectOpGetEffect(op, effects);
253   }
254 };
255 
256 void TestDialect::initialize() {
257   registerAttributes();
258   registerTypes();
259   addOperations<
260 #define GET_OP_LIST
261 #include "TestOps.cpp.inc"
262       >();
263   addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
264                 TestInlinerInterface, TestReductionPatternInterface>();
265   allowUnknownOperations();
266 
267   // Instantiate our fallback op interface that we'll use on specific
268   // unregistered op.
269   fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
270 }
271 TestDialect::~TestDialect() {
272   delete static_cast<TestOpEffectInterfaceFallback *>(
273       fallbackEffectOpInterfaces);
274 }
275 
276 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
277                                             Type type, Location loc) {
278   return builder.create<TestOpConstant>(loc, type, value);
279 }
280 
281 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
282                                                OperationName opName) {
283   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
284       typeID == TypeID::get<TestEffectOpInterface>())
285     return fallbackEffectOpInterfaces;
286   return nullptr;
287 }
288 
289 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
290                                                     NamedAttribute namedAttr) {
291   if (namedAttr.first == "test.invalid_attr")
292     return op->emitError() << "invalid to use 'test.invalid_attr'";
293   return success();
294 }
295 
296 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
297                                                     unsigned regionIndex,
298                                                     unsigned argIndex,
299                                                     NamedAttribute namedAttr) {
300   if (namedAttr.first == "test.invalid_attr")
301     return op->emitError() << "invalid to use 'test.invalid_attr'";
302   return success();
303 }
304 
305 LogicalResult
306 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
307                                          unsigned resultIndex,
308                                          NamedAttribute namedAttr) {
309   if (namedAttr.first == "test.invalid_attr")
310     return op->emitError() << "invalid to use 'test.invalid_attr'";
311   return success();
312 }
313 
314 Optional<Dialect::ParseOpHook>
315 TestDialect::getParseOperationHook(StringRef opName) const {
316   if (opName == "test.dialect_custom_printer") {
317     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
318       return parser.parseKeyword("custom_format");
319     }};
320   }
321   return None;
322 }
323 
324 llvm::unique_function<void(Operation *, OpAsmPrinter &)>
325 TestDialect::getOperationPrinter(Operation *op) const {
326   StringRef opName = op->getName().getStringRef();
327   if (opName == "test.dialect_custom_printer") {
328     return [](Operation *op, OpAsmPrinter &printer) {
329       printer.getStream() << " custom_format";
330     };
331   }
332   return {};
333 }
334 
335 //===----------------------------------------------------------------------===//
336 // TestBranchOp
337 //===----------------------------------------------------------------------===//
338 
339 Optional<MutableOperandRange>
340 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
341   assert(index == 0 && "invalid successor index");
342   return getTargetOperandsMutable();
343 }
344 
345 //===----------------------------------------------------------------------===//
346 // TestDialectCanonicalizerOp
347 //===----------------------------------------------------------------------===//
348 
349 static LogicalResult
350 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
351                                PatternRewriter &rewriter) {
352   rewriter.replaceOpWithNewOp<arith::ConstantOp>(
353       op, rewriter.getI32IntegerAttr(42));
354   return success();
355 }
356 
357 void TestDialect::getCanonicalizationPatterns(
358     RewritePatternSet &results) const {
359   results.add(&dialectCanonicalizationPattern);
360 }
361 
362 //===----------------------------------------------------------------------===//
363 // TestFoldToCallOp
364 //===----------------------------------------------------------------------===//
365 
366 namespace {
367 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
368   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
369 
370   LogicalResult matchAndRewrite(FoldToCallOp op,
371                                 PatternRewriter &rewriter) const override {
372     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.getCalleeAttr(),
373                                         ValueRange());
374     return success();
375   }
376 };
377 } // end anonymous namespace
378 
379 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
380                                                MLIRContext *context) {
381   results.add<FoldToCallOpPattern>(context);
382 }
383 
384 //===----------------------------------------------------------------------===//
385 // Test Format* operations
386 //===----------------------------------------------------------------------===//
387 
388 //===----------------------------------------------------------------------===//
389 // Parsing
390 
391 static ParseResult parseCustomDirectiveOperands(
392     OpAsmParser &parser, OpAsmParser::OperandType &operand,
393     Optional<OpAsmParser::OperandType> &optOperand,
394     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
395   if (parser.parseOperand(operand))
396     return failure();
397   if (succeeded(parser.parseOptionalComma())) {
398     optOperand.emplace();
399     if (parser.parseOperand(*optOperand))
400       return failure();
401   }
402   if (parser.parseArrow() || parser.parseLParen() ||
403       parser.parseOperandList(varOperands) || parser.parseRParen())
404     return failure();
405   return success();
406 }
407 static ParseResult
408 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
409                             Type &optOperandType,
410                             SmallVectorImpl<Type> &varOperandTypes) {
411   if (parser.parseColon())
412     return failure();
413 
414   if (parser.parseType(operandType))
415     return failure();
416   if (succeeded(parser.parseOptionalComma())) {
417     if (parser.parseType(optOperandType))
418       return failure();
419   }
420   if (parser.parseArrow() || parser.parseLParen() ||
421       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
422     return failure();
423   return success();
424 }
425 static ParseResult
426 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
427                                  Type optOperandType,
428                                  const SmallVectorImpl<Type> &varOperandTypes) {
429   if (parser.parseKeyword("type_refs_capture"))
430     return failure();
431 
432   Type operandType2, optOperandType2;
433   SmallVector<Type, 1> varOperandTypes2;
434   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
435                                   varOperandTypes2))
436     return failure();
437 
438   if (operandType != operandType2 || optOperandType != optOperandType2 ||
439       varOperandTypes != varOperandTypes2)
440     return failure();
441 
442   return success();
443 }
444 static ParseResult parseCustomDirectiveOperandsAndTypes(
445     OpAsmParser &parser, OpAsmParser::OperandType &operand,
446     Optional<OpAsmParser::OperandType> &optOperand,
447     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
448     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
449   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
450       parseCustomDirectiveResults(parser, operandType, optOperandType,
451                                   varOperandTypes))
452     return failure();
453   return success();
454 }
455 static ParseResult parseCustomDirectiveRegions(
456     OpAsmParser &parser, Region &region,
457     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
458   if (parser.parseRegion(region))
459     return failure();
460   if (failed(parser.parseOptionalComma()))
461     return success();
462   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
463   if (parser.parseRegion(*varRegion))
464     return failure();
465   varRegions.emplace_back(std::move(varRegion));
466   return success();
467 }
468 static ParseResult
469 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
470                                SmallVectorImpl<Block *> &varSuccessors) {
471   if (parser.parseSuccessor(successor))
472     return failure();
473   if (failed(parser.parseOptionalComma()))
474     return success();
475   Block *varSuccessor;
476   if (parser.parseSuccessor(varSuccessor))
477     return failure();
478   varSuccessors.append(2, varSuccessor);
479   return success();
480 }
481 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
482                                                   IntegerAttr &attr,
483                                                   IntegerAttr &optAttr) {
484   if (parser.parseAttribute(attr))
485     return failure();
486   if (succeeded(parser.parseOptionalComma())) {
487     if (parser.parseAttribute(optAttr))
488       return failure();
489   }
490   return success();
491 }
492 
493 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
494                                                 NamedAttrList &attrs) {
495   return parser.parseOptionalAttrDict(attrs);
496 }
497 static ParseResult parseCustomDirectiveOptionalOperandRef(
498     OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
499   int64_t operandCount = 0;
500   if (parser.parseInteger(operandCount))
501     return failure();
502   bool expectedOptionalOperand = operandCount == 0;
503   return success(expectedOptionalOperand != optOperand.hasValue());
504 }
505 
506 //===----------------------------------------------------------------------===//
507 // Printing
508 
509 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
510                                          Value operand, Value optOperand,
511                                          OperandRange varOperands) {
512   printer << operand;
513   if (optOperand)
514     printer << ", " << optOperand;
515   printer << " -> (" << varOperands << ")";
516 }
517 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
518                                         Type operandType, Type optOperandType,
519                                         TypeRange varOperandTypes) {
520   printer << " : " << operandType;
521   if (optOperandType)
522     printer << ", " << optOperandType;
523   printer << " -> (" << varOperandTypes << ")";
524 }
525 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
526                                              Operation *op, Type operandType,
527                                              Type optOperandType,
528                                              TypeRange varOperandTypes) {
529   printer << " type_refs_capture ";
530   printCustomDirectiveResults(printer, op, operandType, optOperandType,
531                               varOperandTypes);
532 }
533 static void printCustomDirectiveOperandsAndTypes(
534     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
535     OperandRange varOperands, Type operandType, Type optOperandType,
536     TypeRange varOperandTypes) {
537   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
538   printCustomDirectiveResults(printer, op, operandType, optOperandType,
539                               varOperandTypes);
540 }
541 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
542                                         Region &region,
543                                         MutableArrayRef<Region> varRegions) {
544   printer.printRegion(region);
545   if (!varRegions.empty()) {
546     printer << ", ";
547     for (Region &region : varRegions)
548       printer.printRegion(region);
549   }
550 }
551 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
552                                            Block *successor,
553                                            SuccessorRange varSuccessors) {
554   printer << successor;
555   if (!varSuccessors.empty())
556     printer << ", " << varSuccessors.front();
557 }
558 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
559                                            Attribute attribute,
560                                            Attribute optAttribute) {
561   printer << attribute;
562   if (optAttribute)
563     printer << ", " << optAttribute;
564 }
565 
566 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
567                                          DictionaryAttr attrs) {
568   printer.printOptionalAttrDict(attrs.getValue());
569 }
570 
571 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
572                                                    Operation *op,
573                                                    Value optOperand) {
574   printer << (optOperand ? "1" : "0");
575 }
576 
577 //===----------------------------------------------------------------------===//
578 // Test IsolatedRegionOp - parse passthrough region arguments.
579 //===----------------------------------------------------------------------===//
580 
581 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
582                                          OperationState &result) {
583   OpAsmParser::OperandType argInfo;
584   Type argType = parser.getBuilder().getIndexType();
585 
586   // Parse the input operand.
587   if (parser.parseOperand(argInfo) ||
588       parser.resolveOperand(argInfo, argType, result.operands))
589     return failure();
590 
591   // Parse the body region, and reuse the operand info as the argument info.
592   Region *body = result.addRegion();
593   return parser.parseRegion(*body, argInfo, argType,
594                             /*enableNameShadowing=*/true);
595 }
596 
597 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
598   p << "test.isolated_region ";
599   p.printOperand(op.getOperand());
600   p.shadowRegionArgs(op.getRegion(), op.getOperand());
601   p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
602 }
603 
604 //===----------------------------------------------------------------------===//
605 // Test SSACFGRegionOp
606 //===----------------------------------------------------------------------===//
607 
608 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
609   return RegionKind::SSACFG;
610 }
611 
612 //===----------------------------------------------------------------------===//
613 // Test GraphRegionOp
614 //===----------------------------------------------------------------------===//
615 
616 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
617                                       OperationState &result) {
618   // Parse the body region, and reuse the operand info as the argument info.
619   Region *body = result.addRegion();
620   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
621 }
622 
623 static void print(OpAsmPrinter &p, GraphRegionOp op) {
624   p << "test.graph_region ";
625   p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
626 }
627 
628 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
629   return RegionKind::Graph;
630 }
631 
632 //===----------------------------------------------------------------------===//
633 // Test AffineScopeOp
634 //===----------------------------------------------------------------------===//
635 
636 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
637                                       OperationState &result) {
638   // Parse the body region, and reuse the operand info as the argument info.
639   Region *body = result.addRegion();
640   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
641 }
642 
643 static void print(OpAsmPrinter &p, AffineScopeOp op) {
644   p << "test.affine_scope ";
645   p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
646 }
647 
648 //===----------------------------------------------------------------------===//
649 // Test parser.
650 //===----------------------------------------------------------------------===//
651 
652 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
653                                               OperationState &result) {
654   if (parser.parseOptionalColon())
655     return success();
656   uint64_t numResults;
657   if (parser.parseInteger(numResults))
658     return failure();
659 
660   IndexType type = parser.getBuilder().getIndexType();
661   for (unsigned i = 0; i < numResults; ++i)
662     result.addTypes(type);
663   return success();
664 }
665 
666 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
667   if (unsigned numResults = op->getNumResults())
668     p << " : " << numResults;
669 }
670 
671 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
672                                               OperationState &result) {
673   StringRef keyword;
674   if (parser.parseKeyword(&keyword))
675     return failure();
676   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
677   return success();
678 }
679 
680 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
681   p << " " << op.getKeyword();
682 }
683 
684 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
686 
687 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
688                                          OperationState &result) {
689   if (parser.parseKeyword("wraps"))
690     return failure();
691 
692   // Parse the wrapped op in a region
693   Region &body = *result.addRegion();
694   body.push_back(new Block);
695   Block &block = body.back();
696   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
697   if (!wrapped_op)
698     return failure();
699 
700   // Create a return terminator in the inner region, pass as operand to the
701   // terminator the returned values from the wrapped operation.
702   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
703   OpBuilder builder(parser.getContext());
704   builder.setInsertionPointToEnd(&block);
705   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
706 
707   // Get the results type for the wrapping op from the terminator operands.
708   Operation &return_op = body.back().back();
709   result.types.append(return_op.operand_type_begin(),
710                       return_op.operand_type_end());
711 
712   // Use the location of the wrapped op for the "test.wrapping_region" op.
713   result.location = wrapped_op->getLoc();
714 
715   return success();
716 }
717 
718 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
719   p << " wraps ";
720   p.printGenericOp(&op.getRegion().front().front());
721 }
722 
723 //===----------------------------------------------------------------------===//
724 // Test PolyForOp - parse list of region arguments.
725 //===----------------------------------------------------------------------===//
726 
727 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
728   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
729   // Parse list of region arguments without a delimiter.
730   if (parser.parseRegionArgumentList(ivsInfo))
731     return failure();
732 
733   // Parse the body region.
734   Region *body = result.addRegion();
735   auto &builder = parser.getBuilder();
736   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
737   return parser.parseRegion(*body, ivsInfo, argTypes);
738 }
739 
740 //===----------------------------------------------------------------------===//
741 // Test removing op with inner ops.
742 //===----------------------------------------------------------------------===//
743 
744 namespace {
745 struct TestRemoveOpWithInnerOps
746     : public OpRewritePattern<TestOpWithRegionPattern> {
747   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
748 
749   void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
750 
751   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
752                                 PatternRewriter &rewriter) const override {
753     rewriter.eraseOp(op);
754     return success();
755   }
756 };
757 } // end anonymous namespace
758 
759 void TestOpWithRegionPattern::getCanonicalizationPatterns(
760     RewritePatternSet &results, MLIRContext *context) {
761   results.add<TestRemoveOpWithInnerOps>(context);
762 }
763 
764 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
765   return getOperand();
766 }
767 
768 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
769   return getValue();
770 }
771 
772 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
773     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
774   for (Value input : this->getOperands()) {
775     results.push_back(input);
776   }
777   return success();
778 }
779 
780 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
781   assert(operands.size() == 1);
782   if (operands.front()) {
783     (*this)->setAttr("attr", operands.front());
784     return getResult();
785   }
786   return {};
787 }
788 
789 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
790   return getOperand();
791 }
792 
793 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
794     MLIRContext *, Optional<Location> location, ValueRange operands,
795     DictionaryAttr attributes, RegionRange regions,
796     SmallVectorImpl<Type> &inferredReturnTypes) {
797   if (operands[0].getType() != operands[1].getType()) {
798     return emitOptionalError(location, "operand type mismatch ",
799                              operands[0].getType(), " vs ",
800                              operands[1].getType());
801   }
802   inferredReturnTypes.assign({operands[0].getType()});
803   return success();
804 }
805 
806 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
807     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
808     DictionaryAttr attributes, RegionRange regions,
809     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
810   // Create return type consisting of the last element of the first operand.
811   auto operandType = operands.front().getType();
812   auto sval = operandType.dyn_cast<ShapedType>();
813   if (!sval) {
814     return emitOptionalError(location, "only shaped type operands allowed");
815   }
816   int64_t dim =
817       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
818   auto type = IntegerType::get(context, 17);
819   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
820   return success();
821 }
822 
823 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
824     OpBuilder &builder, ValueRange operands,
825     llvm::SmallVectorImpl<Value> &shapes) {
826   shapes = SmallVector<Value, 1>{
827       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
828   return success();
829 }
830 
831 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
832     OpBuilder &builder, ValueRange operands,
833     llvm::SmallVectorImpl<Value> &shapes) {
834   Location loc = getLoc();
835   shapes.reserve(operands.size());
836   for (Value operand : llvm::reverse(operands)) {
837     auto currShape = llvm::to_vector<4>(llvm::map_range(
838         llvm::seq<int64_t>(
839             0, operand.getType().cast<RankedTensorType>().getRank()),
840         [&](int64_t dim) -> Value {
841           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
842         }));
843     shapes.push_back(builder.create<tensor::FromElementsOp>(
844         getLoc(), builder.getIndexType(), currShape));
845   }
846   return success();
847 }
848 
849 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
850     OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
851   Location loc = getLoc();
852   shapes.reserve(getNumOperands());
853   for (Value operand : llvm::reverse(getOperands())) {
854     auto currShape = llvm::to_vector<4>(llvm::map_range(
855         llvm::seq<int64_t>(
856             0, operand.getType().cast<RankedTensorType>().getRank()),
857         [&](int64_t dim) -> Value {
858           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
859         }));
860     shapes.emplace_back(std::move(currShape));
861   }
862   return success();
863 }
864 
865 //===----------------------------------------------------------------------===//
866 // Test SideEffect interfaces
867 //===----------------------------------------------------------------------===//
868 
869 namespace {
870 /// A test resource for side effects.
871 struct TestResource : public SideEffects::Resource::Base<TestResource> {
872   StringRef getName() final { return "<Test>"; }
873 };
874 } // end anonymous namespace
875 
876 static void testSideEffectOpGetEffect(
877     Operation *op,
878     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
879         &effects) {
880   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
881   if (!effectsAttr)
882     return;
883 
884   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
885 }
886 
887 void SideEffectOp::getEffects(
888     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
889   // Check for an effects attribute on the op instance.
890   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
891   if (!effectsAttr)
892     return;
893 
894   // If there is one, it is an array of dictionary attributes that hold
895   // information on the effects of this operation.
896   for (Attribute element : effectsAttr) {
897     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
898 
899     // Get the specific memory effect.
900     MemoryEffects::Effect *effect =
901         StringSwitch<MemoryEffects::Effect *>(
902             effectElement.get("effect").cast<StringAttr>().getValue())
903             .Case("allocate", MemoryEffects::Allocate::get())
904             .Case("free", MemoryEffects::Free::get())
905             .Case("read", MemoryEffects::Read::get())
906             .Case("write", MemoryEffects::Write::get());
907 
908     // Check for a non-default resource to use.
909     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
910     if (effectElement.get("test_resource"))
911       resource = TestResource::get();
912 
913     // Check for a result to affect.
914     if (effectElement.get("on_result"))
915       effects.emplace_back(effect, getResult(), resource);
916     else if (Attribute ref = effectElement.get("on_reference"))
917       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
918     else
919       effects.emplace_back(effect, resource);
920   }
921 }
922 
923 void SideEffectOp::getEffects(
924     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
925   testSideEffectOpGetEffect(getOperation(), effects);
926 }
927 
928 //===----------------------------------------------------------------------===//
929 // StringAttrPrettyNameOp
930 //===----------------------------------------------------------------------===//
931 
932 // This op has fancy handling of its SSA result name.
933 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
934                                                OperationState &result) {
935   // Add the result types.
936   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
937     result.addTypes(parser.getBuilder().getIntegerType(32));
938 
939   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
940     return failure();
941 
942   // If the attribute dictionary contains no 'names' attribute, infer it from
943   // the SSA name (if specified).
944   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
945     return attr.first == "names";
946   });
947 
948   // If there was no name specified, check to see if there was a useful name
949   // specified in the asm file.
950   if (hadNames || parser.getNumResults() == 0)
951     return success();
952 
953   SmallVector<StringRef, 4> names;
954   auto *context = result.getContext();
955 
956   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
957     auto resultName = parser.getResultName(i);
958     StringRef nameStr;
959     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
960       nameStr = resultName.first;
961 
962     names.push_back(nameStr);
963   }
964 
965   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
966   result.attributes.push_back({Identifier::get("names", context), namesAttr});
967   return success();
968 }
969 
970 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
971   // Note that we only need to print the "name" attribute if the asmprinter
972   // result name disagrees with it.  This can happen in strange cases, e.g.
973   // when there are conflicts.
974   bool namesDisagree = op.getNames().size() != op.getNumResults();
975 
976   SmallString<32> resultNameStr;
977   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
978     resultNameStr.clear();
979     llvm::raw_svector_ostream tmpStream(resultNameStr);
980     p.printOperand(op.getResult(i), tmpStream);
981 
982     auto expectedName = op.getNames()[i].dyn_cast<StringAttr>();
983     if (!expectedName ||
984         tmpStream.str().drop_front() != expectedName.getValue()) {
985       namesDisagree = true;
986     }
987   }
988 
989   if (namesDisagree)
990     p.printOptionalAttrDictWithKeyword(op->getAttrs());
991   else
992     p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
993 }
994 
995 // We set the SSA name in the asm syntax to the contents of the name
996 // attribute.
997 void StringAttrPrettyNameOp::getAsmResultNames(
998     function_ref<void(Value, StringRef)> setNameFn) {
999 
1000   auto value = getNames();
1001   for (size_t i = 0, e = value.size(); i != e; ++i)
1002     if (auto str = value[i].dyn_cast<StringAttr>())
1003       if (!str.getValue().empty())
1004         setNameFn(getResult(i), str.getValue());
1005 }
1006 
1007 //===----------------------------------------------------------------------===//
1008 // RegionIfOp
1009 //===----------------------------------------------------------------------===//
1010 
1011 static void print(OpAsmPrinter &p, RegionIfOp op) {
1012   p << " ";
1013   p.printOperands(op.getOperands());
1014   p << ": " << op.getOperandTypes();
1015   p.printArrowTypeList(op.getResultTypes());
1016   p << " then";
1017   p.printRegion(op.getThenRegion(),
1018                 /*printEntryBlockArgs=*/true,
1019                 /*printBlockTerminators=*/true);
1020   p << " else";
1021   p.printRegion(op.getElseRegion(),
1022                 /*printEntryBlockArgs=*/true,
1023                 /*printBlockTerminators=*/true);
1024   p << " join";
1025   p.printRegion(op.getJoinRegion(),
1026                 /*printEntryBlockArgs=*/true,
1027                 /*printBlockTerminators=*/true);
1028 }
1029 
1030 static ParseResult parseRegionIfOp(OpAsmParser &parser,
1031                                    OperationState &result) {
1032   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
1033   SmallVector<Type, 2> operandTypes;
1034 
1035   result.regions.reserve(3);
1036   Region *thenRegion = result.addRegion();
1037   Region *elseRegion = result.addRegion();
1038   Region *joinRegion = result.addRegion();
1039 
1040   // Parse operand, type and arrow type lists.
1041   if (parser.parseOperandList(operandInfos) ||
1042       parser.parseColonTypeList(operandTypes) ||
1043       parser.parseArrowTypeList(result.types))
1044     return failure();
1045 
1046   // Parse all attached regions.
1047   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1048       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1049       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1050     return failure();
1051 
1052   return parser.resolveOperands(operandInfos, operandTypes,
1053                                 parser.getCurrentLocation(), result.operands);
1054 }
1055 
1056 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
1057   assert(index < 2 && "invalid region index");
1058   return getOperands();
1059 }
1060 
1061 void RegionIfOp::getSuccessorRegions(
1062     Optional<unsigned> index, ArrayRef<Attribute> operands,
1063     SmallVectorImpl<RegionSuccessor> &regions) {
1064   // We always branch to the join region.
1065   if (index.hasValue()) {
1066     if (index.getValue() < 2)
1067       regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
1068     else
1069       regions.push_back(RegionSuccessor(getResults()));
1070     return;
1071   }
1072 
1073   // The then and else regions are the entry regions of this op.
1074   regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
1075   regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
1076 }
1077 
1078 //===----------------------------------------------------------------------===//
1079 // SingleNoTerminatorCustomAsmOp
1080 //===----------------------------------------------------------------------===//
1081 
1082 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser,
1083                                                       OperationState &state) {
1084   Region *body = state.addRegion();
1085   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1086     return failure();
1087   return success();
1088 }
1089 
1090 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) {
1091   printer.printRegion(
1092       op.getRegion(), /*printEntryBlockArgs=*/false,
1093       // This op has a single block without terminators. But explicitly mark
1094       // as not printing block terminators for testing.
1095       /*printBlockTerminators=*/false);
1096 }
1097 
1098 #include "TestOpEnums.cpp.inc"
1099 #include "TestOpInterfaces.cpp.inc"
1100 #include "TestOpStructs.cpp.inc"
1101 #include "TestTypeInterfaces.cpp.inc"
1102 
1103 #define GET_OP_CLASSES
1104 #include "TestOps.cpp.inc"
1105