1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/DLTI/DLTI.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Reducer/ReductionPatternInterface.h"
21 #include "mlir/Transforms/FoldUtils.h"
22 #include "mlir/Transforms/InliningUtils.h"
23 #include "llvm/ADT/StringSwitch.h"
24 
25 // Include this before the using namespace lines below to
26 // test that we don't have namespace dependencies.
27 #include "TestOpsDialect.cpp.inc"
28 
29 using namespace mlir;
30 using namespace test;
31 
32 void test::registerTestDialect(DialectRegistry &registry) {
33   registry.insert<TestDialect>();
34 }
35 
36 //===----------------------------------------------------------------------===//
37 // TestDialect Interfaces
38 //===----------------------------------------------------------------------===//
39 
40 namespace {
41 
42 /// Testing the correctness of some traits.
43 static_assert(
44     llvm::is_detected<OpTrait::has_implicit_terminator_t,
45                       SingleBlockImplicitTerminatorOp>::value,
46     "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
47 static_assert(OpTrait::hasSingleBlockImplicitTerminator<
48                   SingleBlockImplicitTerminatorOp>::value,
49               "hasSingleBlockImplicitTerminator does not match "
50               "SingleBlockImplicitTerminatorOp");
51 
52 // Test support for interacting with the AsmPrinter.
53 struct TestOpAsmInterface : public OpAsmDialectInterface {
54   using OpAsmDialectInterface::OpAsmDialectInterface;
55 
56   AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
57     StringAttr strAttr = attr.dyn_cast<StringAttr>();
58     if (!strAttr)
59       return AliasResult::NoAlias;
60 
61     // Check the contents of the string attribute to see what the test alias
62     // should be named.
63     Optional<StringRef> aliasName =
64         StringSwitch<Optional<StringRef>>(strAttr.getValue())
65             .Case("alias_test:dot_in_name", StringRef("test.alias"))
66             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
67             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
68             .Case("alias_test:sanitize_conflict_a",
69                   StringRef("test_alias_conflict0"))
70             .Case("alias_test:sanitize_conflict_b",
71                   StringRef("test_alias_conflict0_"))
72             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
73             .Default(llvm::None);
74     if (!aliasName)
75       return AliasResult::NoAlias;
76 
77     os << *aliasName;
78     return AliasResult::FinalAlias;
79   }
80 
81   AliasResult getAlias(Type type, raw_ostream &os) const final {
82     if (auto tupleType = type.dyn_cast<TupleType>()) {
83       if (tupleType.size() > 0 &&
84           llvm::all_of(tupleType.getTypes(), [](Type elemType) {
85             return elemType.isa<SimpleAType>();
86           })) {
87         os << "test_tuple";
88         return AliasResult::FinalAlias;
89       }
90     }
91     return AliasResult::NoAlias;
92   }
93 
94   void getAsmResultNames(Operation *op,
95                          OpAsmSetValueNameFn setNameFn) const final {
96     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
97       setNameFn(asmOp, "result");
98   }
99 
100   void getAsmBlockArgumentNames(Block *block,
101                                 OpAsmSetValueNameFn setNameFn) const final {
102     auto op = block->getParentOp();
103     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
104     if (!arrayAttr)
105       return;
106     auto args = block->getArguments();
107     auto e = std::min(arrayAttr.size(), args.size());
108     for (unsigned i = 0; i < e; ++i) {
109       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
110         setNameFn(args[i], strAttr.getValue());
111     }
112   }
113 };
114 
115 struct TestDialectFoldInterface : public DialectFoldInterface {
116   using DialectFoldInterface::DialectFoldInterface;
117 
118   /// Registered hook to check if the given region, which is attached to an
119   /// operation that is *not* isolated from above, should be used when
120   /// materializing constants.
121   bool shouldMaterializeInto(Region *region) const final {
122     // If this is a one region operation, then insert into it.
123     return isa<OneRegionOp>(region->getParentOp());
124   }
125 };
126 
127 /// This class defines the interface for handling inlining with standard
128 /// operations.
129 struct TestInlinerInterface : public DialectInlinerInterface {
130   using DialectInlinerInterface::DialectInlinerInterface;
131 
132   //===--------------------------------------------------------------------===//
133   // Analysis Hooks
134   //===--------------------------------------------------------------------===//
135 
136   bool isLegalToInline(Operation *call, Operation *callable,
137                        bool wouldBeCloned) const final {
138     // Don't allow inlining calls that are marked `noinline`.
139     return !call->hasAttr("noinline");
140   }
141   bool isLegalToInline(Region *, Region *, bool,
142                        BlockAndValueMapping &) const final {
143     // Inlining into test dialect regions is legal.
144     return true;
145   }
146   bool isLegalToInline(Operation *, Region *, bool,
147                        BlockAndValueMapping &) const final {
148     return true;
149   }
150 
151   bool shouldAnalyzeRecursively(Operation *op) const final {
152     // Analyze recursively if this is not a functional region operation, it
153     // froms a separate functional scope.
154     return !isa<FunctionalRegionOp>(op);
155   }
156 
157   //===--------------------------------------------------------------------===//
158   // Transformation Hooks
159   //===--------------------------------------------------------------------===//
160 
161   /// Handle the given inlined terminator by replacing it with a new operation
162   /// as necessary.
163   void handleTerminator(Operation *op,
164                         ArrayRef<Value> valuesToRepl) const final {
165     // Only handle "test.return" here.
166     auto returnOp = dyn_cast<TestReturnOp>(op);
167     if (!returnOp)
168       return;
169 
170     // Replace the values directly with the return operands.
171     assert(returnOp.getNumOperands() == valuesToRepl.size());
172     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
173       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
174   }
175 
176   /// Attempt to materialize a conversion for a type mismatch between a call
177   /// from this dialect, and a callable region. This method should generate an
178   /// operation that takes 'input' as the only operand, and produces a single
179   /// result of 'resultType'. If a conversion can not be generated, nullptr
180   /// should be returned.
181   Operation *materializeCallConversion(OpBuilder &builder, Value input,
182                                        Type resultType,
183                                        Location conversionLoc) const final {
184     // Only allow conversion for i16/i32 types.
185     if (!(resultType.isSignlessInteger(16) ||
186           resultType.isSignlessInteger(32)) ||
187         !(input.getType().isSignlessInteger(16) ||
188           input.getType().isSignlessInteger(32)))
189       return nullptr;
190     return builder.create<TestCastOp>(conversionLoc, resultType, input);
191   }
192 
193   void processInlinedCallBlocks(
194       Operation *call,
195       iterator_range<Region::iterator> inlinedBlocks) const final {
196     if (!isa<ConversionCallOp>(call))
197       return;
198 
199     // Set attributed on all ops in the inlined blocks.
200     for (Block &block : inlinedBlocks) {
201       block.walk([&](Operation *op) {
202         op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
203       });
204     }
205   }
206 };
207 
208 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
209 public:
210   TestReductionPatternInterface(Dialect *dialect)
211       : DialectReductionPatternInterface(dialect) {}
212 
213   virtual void
214   populateReductionPatterns(RewritePatternSet &patterns) const final {
215     populateTestReductionPatterns(patterns);
216   }
217 };
218 
219 } // end anonymous namespace
220 
221 //===----------------------------------------------------------------------===//
222 // TestDialect
223 //===----------------------------------------------------------------------===//
224 
225 static void testSideEffectOpGetEffect(
226     Operation *op,
227     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
228 
229 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
230 struct TestOpEffectInterfaceFallback
231     : public TestEffectOpInterface::FallbackModel<
232           TestOpEffectInterfaceFallback> {
233   static bool classof(Operation *op) {
234     bool isSupportedOp =
235         op->getName().getStringRef() == "test.unregistered_side_effect_op";
236     assert(isSupportedOp && "Unexpected dispatch");
237     return isSupportedOp;
238   }
239 
240   void
241   getEffects(Operation *op,
242              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
243                  &effects) const {
244     testSideEffectOpGetEffect(op, effects);
245   }
246 };
247 
248 void TestDialect::initialize() {
249   registerAttributes();
250   registerTypes();
251   addOperations<
252 #define GET_OP_LIST
253 #include "TestOps.cpp.inc"
254       >();
255   addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
256                 TestInlinerInterface, TestReductionPatternInterface>();
257   allowUnknownOperations();
258 
259   // Instantiate our fallback op interface that we'll use on specific
260   // unregistered op.
261   fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
262 }
263 TestDialect::~TestDialect() {
264   delete static_cast<TestOpEffectInterfaceFallback *>(
265       fallbackEffectOpInterfaces);
266 }
267 
268 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
269                                             Type type, Location loc) {
270   return builder.create<TestOpConstant>(loc, type, value);
271 }
272 
273 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
274                                                OperationName opName) {
275   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
276       typeID == TypeID::get<TestEffectOpInterface>())
277     return fallbackEffectOpInterfaces;
278   return nullptr;
279 }
280 
281 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
282                                                     NamedAttribute namedAttr) {
283   if (namedAttr.first == "test.invalid_attr")
284     return op->emitError() << "invalid to use 'test.invalid_attr'";
285   return success();
286 }
287 
288 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
289                                                     unsigned regionIndex,
290                                                     unsigned argIndex,
291                                                     NamedAttribute namedAttr) {
292   if (namedAttr.first == "test.invalid_attr")
293     return op->emitError() << "invalid to use 'test.invalid_attr'";
294   return success();
295 }
296 
297 LogicalResult
298 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
299                                          unsigned resultIndex,
300                                          NamedAttribute namedAttr) {
301   if (namedAttr.first == "test.invalid_attr")
302     return op->emitError() << "invalid to use 'test.invalid_attr'";
303   return success();
304 }
305 
306 Optional<Dialect::ParseOpHook>
307 TestDialect::getParseOperationHook(StringRef opName) const {
308   if (opName == "test.dialect_custom_printer") {
309     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
310       return parser.parseKeyword("custom_format");
311     }};
312   }
313   return None;
314 }
315 
316 LogicalResult TestDialect::printOperation(Operation *op,
317                                           OpAsmPrinter &printer) const {
318   StringRef opName = op->getName().getStringRef();
319   if (opName == "test.dialect_custom_printer") {
320     printer.getStream() << opName << " custom_format";
321     return success();
322   }
323   return failure();
324 }
325 
326 //===----------------------------------------------------------------------===//
327 // TestBranchOp
328 //===----------------------------------------------------------------------===//
329 
330 Optional<MutableOperandRange>
331 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
332   assert(index == 0 && "invalid successor index");
333   return targetOperandsMutable();
334 }
335 
336 //===----------------------------------------------------------------------===//
337 // TestDialectCanonicalizerOp
338 //===----------------------------------------------------------------------===//
339 
340 static LogicalResult
341 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
342                                PatternRewriter &rewriter) {
343   rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(),
344                                           rewriter.getI32IntegerAttr(42));
345   return success();
346 }
347 
348 void TestDialect::getCanonicalizationPatterns(
349     RewritePatternSet &results) const {
350   results.add(&dialectCanonicalizationPattern);
351 }
352 
353 //===----------------------------------------------------------------------===//
354 // TestFoldToCallOp
355 //===----------------------------------------------------------------------===//
356 
357 namespace {
358 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
359   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
360 
361   LogicalResult matchAndRewrite(FoldToCallOp op,
362                                 PatternRewriter &rewriter) const override {
363     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
364                                         ValueRange());
365     return success();
366   }
367 };
368 } // end anonymous namespace
369 
370 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
371                                                MLIRContext *context) {
372   results.add<FoldToCallOpPattern>(context);
373 }
374 
375 //===----------------------------------------------------------------------===//
376 // Test Format* operations
377 //===----------------------------------------------------------------------===//
378 
379 //===----------------------------------------------------------------------===//
380 // Parsing
381 
382 static ParseResult parseCustomDirectiveOperands(
383     OpAsmParser &parser, OpAsmParser::OperandType &operand,
384     Optional<OpAsmParser::OperandType> &optOperand,
385     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
386   if (parser.parseOperand(operand))
387     return failure();
388   if (succeeded(parser.parseOptionalComma())) {
389     optOperand.emplace();
390     if (parser.parseOperand(*optOperand))
391       return failure();
392   }
393   if (parser.parseArrow() || parser.parseLParen() ||
394       parser.parseOperandList(varOperands) || parser.parseRParen())
395     return failure();
396   return success();
397 }
398 static ParseResult
399 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
400                             Type &optOperandType,
401                             SmallVectorImpl<Type> &varOperandTypes) {
402   if (parser.parseColon())
403     return failure();
404 
405   if (parser.parseType(operandType))
406     return failure();
407   if (succeeded(parser.parseOptionalComma())) {
408     if (parser.parseType(optOperandType))
409       return failure();
410   }
411   if (parser.parseArrow() || parser.parseLParen() ||
412       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
413     return failure();
414   return success();
415 }
416 static ParseResult
417 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
418                                  Type optOperandType,
419                                  const SmallVectorImpl<Type> &varOperandTypes) {
420   if (parser.parseKeyword("type_refs_capture"))
421     return failure();
422 
423   Type operandType2, optOperandType2;
424   SmallVector<Type, 1> varOperandTypes2;
425   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
426                                   varOperandTypes2))
427     return failure();
428 
429   if (operandType != operandType2 || optOperandType != optOperandType2 ||
430       varOperandTypes != varOperandTypes2)
431     return failure();
432 
433   return success();
434 }
435 static ParseResult parseCustomDirectiveOperandsAndTypes(
436     OpAsmParser &parser, OpAsmParser::OperandType &operand,
437     Optional<OpAsmParser::OperandType> &optOperand,
438     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
439     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
440   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
441       parseCustomDirectiveResults(parser, operandType, optOperandType,
442                                   varOperandTypes))
443     return failure();
444   return success();
445 }
446 static ParseResult parseCustomDirectiveRegions(
447     OpAsmParser &parser, Region &region,
448     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
449   if (parser.parseRegion(region))
450     return failure();
451   if (failed(parser.parseOptionalComma()))
452     return success();
453   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
454   if (parser.parseRegion(*varRegion))
455     return failure();
456   varRegions.emplace_back(std::move(varRegion));
457   return success();
458 }
459 static ParseResult
460 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
461                                SmallVectorImpl<Block *> &varSuccessors) {
462   if (parser.parseSuccessor(successor))
463     return failure();
464   if (failed(parser.parseOptionalComma()))
465     return success();
466   Block *varSuccessor;
467   if (parser.parseSuccessor(varSuccessor))
468     return failure();
469   varSuccessors.append(2, varSuccessor);
470   return success();
471 }
472 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
473                                                   IntegerAttr &attr,
474                                                   IntegerAttr &optAttr) {
475   if (parser.parseAttribute(attr))
476     return failure();
477   if (succeeded(parser.parseOptionalComma())) {
478     if (parser.parseAttribute(optAttr))
479       return failure();
480   }
481   return success();
482 }
483 
484 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
485                                                 NamedAttrList &attrs) {
486   return parser.parseOptionalAttrDict(attrs);
487 }
488 static ParseResult parseCustomDirectiveOptionalOperandRef(
489     OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
490   int64_t operandCount = 0;
491   if (parser.parseInteger(operandCount))
492     return failure();
493   bool expectedOptionalOperand = operandCount == 0;
494   return success(expectedOptionalOperand != optOperand.hasValue());
495 }
496 
497 //===----------------------------------------------------------------------===//
498 // Printing
499 
500 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
501                                          Value operand, Value optOperand,
502                                          OperandRange varOperands) {
503   printer << operand;
504   if (optOperand)
505     printer << ", " << optOperand;
506   printer << " -> (" << varOperands << ")";
507 }
508 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
509                                         Type operandType, Type optOperandType,
510                                         TypeRange varOperandTypes) {
511   printer << " : " << operandType;
512   if (optOperandType)
513     printer << ", " << optOperandType;
514   printer << " -> (" << varOperandTypes << ")";
515 }
516 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
517                                              Operation *op, Type operandType,
518                                              Type optOperandType,
519                                              TypeRange varOperandTypes) {
520   printer << " type_refs_capture ";
521   printCustomDirectiveResults(printer, op, operandType, optOperandType,
522                               varOperandTypes);
523 }
524 static void printCustomDirectiveOperandsAndTypes(
525     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
526     OperandRange varOperands, Type operandType, Type optOperandType,
527     TypeRange varOperandTypes) {
528   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
529   printCustomDirectiveResults(printer, op, operandType, optOperandType,
530                               varOperandTypes);
531 }
532 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
533                                         Region &region,
534                                         MutableArrayRef<Region> varRegions) {
535   printer.printRegion(region);
536   if (!varRegions.empty()) {
537     printer << ", ";
538     for (Region &region : varRegions)
539       printer.printRegion(region);
540   }
541 }
542 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
543                                            Block *successor,
544                                            SuccessorRange varSuccessors) {
545   printer << successor;
546   if (!varSuccessors.empty())
547     printer << ", " << varSuccessors.front();
548 }
549 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
550                                            Attribute attribute,
551                                            Attribute optAttribute) {
552   printer << attribute;
553   if (optAttribute)
554     printer << ", " << optAttribute;
555 }
556 
557 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
558                                          DictionaryAttr attrs) {
559   printer.printOptionalAttrDict(attrs.getValue());
560 }
561 
562 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
563                                                    Operation *op,
564                                                    Value optOperand) {
565   printer << (optOperand ? "1" : "0");
566 }
567 
568 //===----------------------------------------------------------------------===//
569 // Test IsolatedRegionOp - parse passthrough region arguments.
570 //===----------------------------------------------------------------------===//
571 
572 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
573                                          OperationState &result) {
574   OpAsmParser::OperandType argInfo;
575   Type argType = parser.getBuilder().getIndexType();
576 
577   // Parse the input operand.
578   if (parser.parseOperand(argInfo) ||
579       parser.resolveOperand(argInfo, argType, result.operands))
580     return failure();
581 
582   // Parse the body region, and reuse the operand info as the argument info.
583   Region *body = result.addRegion();
584   return parser.parseRegion(*body, argInfo, argType,
585                             /*enableNameShadowing=*/true);
586 }
587 
588 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
589   p << "test.isolated_region ";
590   p.printOperand(op.getOperand());
591   p.shadowRegionArgs(op.region(), op.getOperand());
592   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
593 }
594 
595 //===----------------------------------------------------------------------===//
596 // Test SSACFGRegionOp
597 //===----------------------------------------------------------------------===//
598 
599 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
600   return RegionKind::SSACFG;
601 }
602 
603 //===----------------------------------------------------------------------===//
604 // Test GraphRegionOp
605 //===----------------------------------------------------------------------===//
606 
607 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
608                                       OperationState &result) {
609   // Parse the body region, and reuse the operand info as the argument info.
610   Region *body = result.addRegion();
611   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
612 }
613 
614 static void print(OpAsmPrinter &p, GraphRegionOp op) {
615   p << "test.graph_region ";
616   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
617 }
618 
619 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
620   return RegionKind::Graph;
621 }
622 
623 //===----------------------------------------------------------------------===//
624 // Test AffineScopeOp
625 //===----------------------------------------------------------------------===//
626 
627 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
628                                       OperationState &result) {
629   // Parse the body region, and reuse the operand info as the argument info.
630   Region *body = result.addRegion();
631   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
632 }
633 
634 static void print(OpAsmPrinter &p, AffineScopeOp op) {
635   p << "test.affine_scope ";
636   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
637 }
638 
639 //===----------------------------------------------------------------------===//
640 // Test parser.
641 //===----------------------------------------------------------------------===//
642 
643 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
644                                               OperationState &result) {
645   if (parser.parseOptionalColon())
646     return success();
647   uint64_t numResults;
648   if (parser.parseInteger(numResults))
649     return failure();
650 
651   IndexType type = parser.getBuilder().getIndexType();
652   for (unsigned i = 0; i < numResults; ++i)
653     result.addTypes(type);
654   return success();
655 }
656 
657 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
658   p << ParseIntegerLiteralOp::getOperationName();
659   if (unsigned numResults = op->getNumResults())
660     p << " : " << numResults;
661 }
662 
663 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
664                                               OperationState &result) {
665   StringRef keyword;
666   if (parser.parseKeyword(&keyword))
667     return failure();
668   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
669   return success();
670 }
671 
672 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
673   p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword();
674 }
675 
676 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
678 
679 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
680                                          OperationState &result) {
681   if (parser.parseKeyword("wraps"))
682     return failure();
683 
684   // Parse the wrapped op in a region
685   Region &body = *result.addRegion();
686   body.push_back(new Block);
687   Block &block = body.back();
688   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
689   if (!wrapped_op)
690     return failure();
691 
692   // Create a return terminator in the inner region, pass as operand to the
693   // terminator the returned values from the wrapped operation.
694   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
695   OpBuilder builder(parser.getBuilder().getContext());
696   builder.setInsertionPointToEnd(&block);
697   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
698 
699   // Get the results type for the wrapping op from the terminator operands.
700   Operation &return_op = body.back().back();
701   result.types.append(return_op.operand_type_begin(),
702                       return_op.operand_type_end());
703 
704   // Use the location of the wrapped op for the "test.wrapping_region" op.
705   result.location = wrapped_op->getLoc();
706 
707   return success();
708 }
709 
710 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
711   p << op.getOperationName() << " wraps ";
712   p.printGenericOp(&op.region().front().front());
713 }
714 
715 //===----------------------------------------------------------------------===//
716 // Test PolyForOp - parse list of region arguments.
717 //===----------------------------------------------------------------------===//
718 
719 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
720   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
721   // Parse list of region arguments without a delimiter.
722   if (parser.parseRegionArgumentList(ivsInfo))
723     return failure();
724 
725   // Parse the body region.
726   Region *body = result.addRegion();
727   auto &builder = parser.getBuilder();
728   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
729   return parser.parseRegion(*body, ivsInfo, argTypes);
730 }
731 
732 //===----------------------------------------------------------------------===//
733 // Test removing op with inner ops.
734 //===----------------------------------------------------------------------===//
735 
736 namespace {
737 struct TestRemoveOpWithInnerOps
738     : public OpRewritePattern<TestOpWithRegionPattern> {
739   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
740 
741   void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
742 
743   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
744                                 PatternRewriter &rewriter) const override {
745     rewriter.eraseOp(op);
746     return success();
747   }
748 };
749 } // end anonymous namespace
750 
751 void TestOpWithRegionPattern::getCanonicalizationPatterns(
752     RewritePatternSet &results, MLIRContext *context) {
753   results.add<TestRemoveOpWithInnerOps>(context);
754 }
755 
756 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
757   return operand();
758 }
759 
760 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
761   return getValue();
762 }
763 
764 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
765     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
766   for (Value input : this->operands()) {
767     results.push_back(input);
768   }
769   return success();
770 }
771 
772 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
773   assert(operands.size() == 1);
774   if (operands.front()) {
775     (*this)->setAttr("attr", operands.front());
776     return getResult();
777   }
778   return {};
779 }
780 
781 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
782   return getOperand();
783 }
784 
785 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
786     MLIRContext *, Optional<Location> location, ValueRange operands,
787     DictionaryAttr attributes, RegionRange regions,
788     SmallVectorImpl<Type> &inferredReturnTypes) {
789   if (operands[0].getType() != operands[1].getType()) {
790     return emitOptionalError(location, "operand type mismatch ",
791                              operands[0].getType(), " vs ",
792                              operands[1].getType());
793   }
794   inferredReturnTypes.assign({operands[0].getType()});
795   return success();
796 }
797 
798 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
799     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
800     DictionaryAttr attributes, RegionRange regions,
801     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
802   // Create return type consisting of the last element of the first operand.
803   auto operandType = operands.front().getType();
804   auto sval = operandType.dyn_cast<ShapedType>();
805   if (!sval) {
806     return emitOptionalError(location, "only shaped type operands allowed");
807   }
808   int64_t dim =
809       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
810   auto type = IntegerType::get(context, 17);
811   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
812   return success();
813 }
814 
815 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
816     OpBuilder &builder, ValueRange operands,
817     llvm::SmallVectorImpl<Value> &shapes) {
818   shapes = SmallVector<Value, 1>{
819       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
820   return success();
821 }
822 
823 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
824     OpBuilder &builder, ValueRange operands,
825     llvm::SmallVectorImpl<Value> &shapes) {
826   Location loc = getLoc();
827   shapes.reserve(operands.size());
828   for (Value operand : llvm::reverse(operands)) {
829     auto currShape = llvm::to_vector<4>(llvm::map_range(
830         llvm::seq<int64_t>(
831             0, operand.getType().cast<RankedTensorType>().getRank()),
832         [&](int64_t dim) -> Value {
833           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
834         }));
835     shapes.push_back(builder.create<tensor::FromElementsOp>(
836         getLoc(), builder.getIndexType(), currShape));
837   }
838   return success();
839 }
840 
841 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
842     OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
843   Location loc = getLoc();
844   shapes.reserve(getNumOperands());
845   for (Value operand : llvm::reverse(getOperands())) {
846     auto currShape = llvm::to_vector<4>(llvm::map_range(
847         llvm::seq<int64_t>(
848             0, operand.getType().cast<RankedTensorType>().getRank()),
849         [&](int64_t dim) -> Value {
850           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
851         }));
852     shapes.emplace_back(std::move(currShape));
853   }
854   return success();
855 }
856 
857 //===----------------------------------------------------------------------===//
858 // Test SideEffect interfaces
859 //===----------------------------------------------------------------------===//
860 
namespace {
/// A test resource for side effects.
struct TestResource : public SideEffects::Resource::Base<TestResource> {
  /// Name reported for this resource.
  StringRef getName() final { return "<Test>"; }
};
} // end anonymous namespace
867 
868 static void testSideEffectOpGetEffect(
869     Operation *op,
870     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
871         &effects) {
872   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
873   if (!effectsAttr)
874     return;
875 
876   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
877 }
878 
879 void SideEffectOp::getEffects(
880     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
881   // Check for an effects attribute on the op instance.
882   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
883   if (!effectsAttr)
884     return;
885 
886   // If there is one, it is an array of dictionary attributes that hold
887   // information on the effects of this operation.
888   for (Attribute element : effectsAttr) {
889     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
890 
891     // Get the specific memory effect.
892     MemoryEffects::Effect *effect =
893         StringSwitch<MemoryEffects::Effect *>(
894             effectElement.get("effect").cast<StringAttr>().getValue())
895             .Case("allocate", MemoryEffects::Allocate::get())
896             .Case("free", MemoryEffects::Free::get())
897             .Case("read", MemoryEffects::Read::get())
898             .Case("write", MemoryEffects::Write::get());
899 
900     // Check for a non-default resource to use.
901     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
902     if (effectElement.get("test_resource"))
903       resource = TestResource::get();
904 
905     // Check for a result to affect.
906     if (effectElement.get("on_result"))
907       effects.emplace_back(effect, getResult(), resource);
908     else if (Attribute ref = effectElement.get("on_reference"))
909       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
910     else
911       effects.emplace_back(effect, resource);
912   }
913 }
914 
/// Collects the `TestEffects` effect instances of this op by forwarding to
/// `testSideEffectOpGetEffect` on the underlying operation.
void SideEffectOp::getEffects(
    SmallVectorImpl<TestEffects::EffectInstance> &effects) {
  testSideEffectOpGetEffect(getOperation(), effects);
}
919 
920 //===----------------------------------------------------------------------===//
921 // StringAttrPrettyNameOp
922 //===----------------------------------------------------------------------===//
923 
924 // This op has fancy handling of its SSA result name.
925 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
926                                                OperationState &result) {
927   // Add the result types.
928   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
929     result.addTypes(parser.getBuilder().getIntegerType(32));
930 
931   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
932     return failure();
933 
934   // If the attribute dictionary contains no 'names' attribute, infer it from
935   // the SSA name (if specified).
936   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
937     return attr.first == "names";
938   });
939 
940   // If there was no name specified, check to see if there was a useful name
941   // specified in the asm file.
942   if (hadNames || parser.getNumResults() == 0)
943     return success();
944 
945   SmallVector<StringRef, 4> names;
946   auto *context = result.getContext();
947 
948   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
949     auto resultName = parser.getResultName(i);
950     StringRef nameStr;
951     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
952       nameStr = resultName.first;
953 
954     names.push_back(nameStr);
955   }
956 
957   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
958   result.attributes.push_back({Identifier::get("names", context), namesAttr});
959   return success();
960 }
961 
962 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
963   p << "test.string_attr_pretty_name";
964 
965   // Note that we only need to print the "name" attribute if the asmprinter
966   // result name disagrees with it.  This can happen in strange cases, e.g.
967   // when there are conflicts.
968   bool namesDisagree = op.names().size() != op.getNumResults();
969 
970   SmallString<32> resultNameStr;
971   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
972     resultNameStr.clear();
973     llvm::raw_svector_ostream tmpStream(resultNameStr);
974     p.printOperand(op.getResult(i), tmpStream);
975 
976     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
977     if (!expectedName ||
978         tmpStream.str().drop_front() != expectedName.getValue()) {
979       namesDisagree = true;
980     }
981   }
982 
983   if (namesDisagree)
984     p.printOptionalAttrDictWithKeyword(op->getAttrs());
985   else
986     p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
987 }
988 
989 // We set the SSA name in the asm syntax to the contents of the name
990 // attribute.
991 void StringAttrPrettyNameOp::getAsmResultNames(
992     function_ref<void(Value, StringRef)> setNameFn) {
993 
994   auto value = names();
995   for (size_t i = 0, e = value.size(); i != e; ++i)
996     if (auto str = value[i].dyn_cast<StringAttr>())
997       if (!str.getValue().empty())
998         setNameFn(getResult(i), str.getValue());
999 }
1000 
1001 //===----------------------------------------------------------------------===//
1002 // RegionIfOp
1003 //===----------------------------------------------------------------------===//
1004 
1005 static void print(OpAsmPrinter &p, RegionIfOp op) {
1006   p << RegionIfOp::getOperationName() << " ";
1007   p.printOperands(op.getOperands());
1008   p << ": " << op.getOperandTypes();
1009   p.printArrowTypeList(op.getResultTypes());
1010   p << " then";
1011   p.printRegion(op.thenRegion(),
1012                 /*printEntryBlockArgs=*/true,
1013                 /*printBlockTerminators=*/true);
1014   p << " else";
1015   p.printRegion(op.elseRegion(),
1016                 /*printEntryBlockArgs=*/true,
1017                 /*printBlockTerminators=*/true);
1018   p << " join";
1019   p.printRegion(op.joinRegion(),
1020                 /*printEntryBlockArgs=*/true,
1021                 /*printBlockTerminators=*/true);
1022 }
1023 
1024 static ParseResult parseRegionIfOp(OpAsmParser &parser,
1025                                    OperationState &result) {
1026   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
1027   SmallVector<Type, 2> operandTypes;
1028 
1029   result.regions.reserve(3);
1030   Region *thenRegion = result.addRegion();
1031   Region *elseRegion = result.addRegion();
1032   Region *joinRegion = result.addRegion();
1033 
1034   // Parse operand, type and arrow type lists.
1035   if (parser.parseOperandList(operandInfos) ||
1036       parser.parseColonTypeList(operandTypes) ||
1037       parser.parseArrowTypeList(result.types))
1038     return failure();
1039 
1040   // Parse all attached regions.
1041   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1042       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1043       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1044     return failure();
1045 
1046   return parser.resolveOperands(operandInfos, operandTypes,
1047                                 parser.getCurrentLocation(), result.operands);
1048 }
1049 
/// Returns the operands forwarded when entering region `index` from the op
/// itself: all operands of this op. Only region indices 0 and 1 are accepted
/// (asserted).
OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
  assert(index < 2 && "invalid region index");
  return getOperands();
}
1054 
1055 void RegionIfOp::getSuccessorRegions(
1056     Optional<unsigned> index, ArrayRef<Attribute> operands,
1057     SmallVectorImpl<RegionSuccessor> &regions) {
1058   // We always branch to the join region.
1059   if (index.hasValue()) {
1060     if (index.getValue() < 2)
1061       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
1062     else
1063       regions.push_back(RegionSuccessor(getResults()));
1064     return;
1065   }
1066 
1067   // The then and else regions are the entry regions of this op.
1068   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
1069   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
1070 }
1071 
1072 //===----------------------------------------------------------------------===//
1073 // SingleNoTerminatorCustomAsmOp
1074 //===----------------------------------------------------------------------===//
1075 
1076 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser,
1077                                                       OperationState &state) {
1078   Region *body = state.addRegion();
1079   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1080     return failure();
1081   return success();
1082 }
1083 
1084 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) {
1085   printer << op.getOperationName();
1086   printer.printRegion(
1087       op.getRegion(), /*printEntryBlockArgs=*/false,
1088       // This op has a single block without terminators. But explicitly mark
1089       // as not printing block terminators for testing.
1090       /*printBlockTerminators=*/false);
1091 }
1092 
1093 #include "TestOpEnums.cpp.inc"
1094 #include "TestOpInterfaces.cpp.inc"
1095 #include "TestOpStructs.cpp.inc"
1096 #include "TestTypeInterfaces.cpp.inc"
1097 
1098 #define GET_OP_CLASSES
1099 #include "TestOps.cpp.inc"
1100