1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/DLTI/DLTI.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/Reducer/ReductionPatternInterface.h"
22 #include "mlir/Transforms/FoldUtils.h"
23 #include "mlir/Transforms/InliningUtils.h"
24 #include "llvm/ADT/StringSwitch.h"
25 
26 // Include this before the using namespace lines below to
27 // test that we don't have namespace dependencies.
28 #include "TestOpsDialect.cpp.inc"
29 
30 using namespace mlir;
31 using namespace test;
32 
33 void test::registerTestDialect(DialectRegistry &registry) {
34   registry.insert<TestDialect>();
35 }
36 
37 //===----------------------------------------------------------------------===//
38 // TestDialect Interfaces
39 //===----------------------------------------------------------------------===//
40 
41 namespace {
42 
43 /// Testing the correctness of some traits.
44 static_assert(
45     llvm::is_detected<OpTrait::has_implicit_terminator_t,
46                       SingleBlockImplicitTerminatorOp>::value,
47     "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
48 static_assert(OpTrait::hasSingleBlockImplicitTerminator<
49                   SingleBlockImplicitTerminatorOp>::value,
50               "hasSingleBlockImplicitTerminator does not match "
51               "SingleBlockImplicitTerminatorOp");
52 
53 // Test support for interacting with the AsmPrinter.
54 struct TestOpAsmInterface : public OpAsmDialectInterface {
55   using OpAsmDialectInterface::OpAsmDialectInterface;
56 
57   AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
58     StringAttr strAttr = attr.dyn_cast<StringAttr>();
59     if (!strAttr)
60       return AliasResult::NoAlias;
61 
62     // Check the contents of the string attribute to see what the test alias
63     // should be named.
64     Optional<StringRef> aliasName =
65         StringSwitch<Optional<StringRef>>(strAttr.getValue())
66             .Case("alias_test:dot_in_name", StringRef("test.alias"))
67             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
68             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
69             .Case("alias_test:sanitize_conflict_a",
70                   StringRef("test_alias_conflict0"))
71             .Case("alias_test:sanitize_conflict_b",
72                   StringRef("test_alias_conflict0_"))
73             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
74             .Default(llvm::None);
75     if (!aliasName)
76       return AliasResult::NoAlias;
77 
78     os << *aliasName;
79     return AliasResult::FinalAlias;
80   }
81 
82   AliasResult getAlias(Type type, raw_ostream &os) const final {
83     if (auto tupleType = type.dyn_cast<TupleType>()) {
84       if (tupleType.size() > 0 &&
85           llvm::all_of(tupleType.getTypes(), [](Type elemType) {
86             return elemType.isa<SimpleAType>();
87           })) {
88         os << "test_tuple";
89         return AliasResult::FinalAlias;
90       }
91     }
92     if (auto intType = type.dyn_cast<TestIntegerType>()) {
93       if (intType.getSignedness() ==
94               TestIntegerType::SignednessSemantics::Unsigned &&
95           intType.getWidth() == 8) {
96         os << "test_ui8";
97         return AliasResult::FinalAlias;
98       }
99     }
100     return AliasResult::NoAlias;
101   }
102 
103   void getAsmResultNames(Operation *op,
104                          OpAsmSetValueNameFn setNameFn) const final {
105     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
106       setNameFn(asmOp, "result");
107   }
108 };
109 
110 struct TestDialectFoldInterface : public DialectFoldInterface {
111   using DialectFoldInterface::DialectFoldInterface;
112 
113   /// Registered hook to check if the given region, which is attached to an
114   /// operation that is *not* isolated from above, should be used when
115   /// materializing constants.
116   bool shouldMaterializeInto(Region *region) const final {
117     // If this is a one region operation, then insert into it.
118     return isa<OneRegionOp>(region->getParentOp());
119   }
120 };
121 
122 /// This class defines the interface for handling inlining with standard
123 /// operations.
124 struct TestInlinerInterface : public DialectInlinerInterface {
125   using DialectInlinerInterface::DialectInlinerInterface;
126 
127   //===--------------------------------------------------------------------===//
128   // Analysis Hooks
129   //===--------------------------------------------------------------------===//
130 
131   bool isLegalToInline(Operation *call, Operation *callable,
132                        bool wouldBeCloned) const final {
133     // Don't allow inlining calls that are marked `noinline`.
134     return !call->hasAttr("noinline");
135   }
136   bool isLegalToInline(Region *, Region *, bool,
137                        BlockAndValueMapping &) const final {
138     // Inlining into test dialect regions is legal.
139     return true;
140   }
141   bool isLegalToInline(Operation *, Region *, bool,
142                        BlockAndValueMapping &) const final {
143     return true;
144   }
145 
146   bool shouldAnalyzeRecursively(Operation *op) const final {
147     // Analyze recursively if this is not a functional region operation, it
148     // froms a separate functional scope.
149     return !isa<FunctionalRegionOp>(op);
150   }
151 
152   //===--------------------------------------------------------------------===//
153   // Transformation Hooks
154   //===--------------------------------------------------------------------===//
155 
156   /// Handle the given inlined terminator by replacing it with a new operation
157   /// as necessary.
158   void handleTerminator(Operation *op,
159                         ArrayRef<Value> valuesToRepl) const final {
160     // Only handle "test.return" here.
161     auto returnOp = dyn_cast<TestReturnOp>(op);
162     if (!returnOp)
163       return;
164 
165     // Replace the values directly with the return operands.
166     assert(returnOp.getNumOperands() == valuesToRepl.size());
167     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
168       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
169   }
170 
171   /// Attempt to materialize a conversion for a type mismatch between a call
172   /// from this dialect, and a callable region. This method should generate an
173   /// operation that takes 'input' as the only operand, and produces a single
174   /// result of 'resultType'. If a conversion can not be generated, nullptr
175   /// should be returned.
176   Operation *materializeCallConversion(OpBuilder &builder, Value input,
177                                        Type resultType,
178                                        Location conversionLoc) const final {
179     // Only allow conversion for i16/i32 types.
180     if (!(resultType.isSignlessInteger(16) ||
181           resultType.isSignlessInteger(32)) ||
182         !(input.getType().isSignlessInteger(16) ||
183           input.getType().isSignlessInteger(32)))
184       return nullptr;
185     return builder.create<TestCastOp>(conversionLoc, resultType, input);
186   }
187 
188   void processInlinedCallBlocks(
189       Operation *call,
190       iterator_range<Region::iterator> inlinedBlocks) const final {
191     if (!isa<ConversionCallOp>(call))
192       return;
193 
194     // Set attributed on all ops in the inlined blocks.
195     for (Block &block : inlinedBlocks) {
196       block.walk([&](Operation *op) {
197         op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
198       });
199     }
200   }
201 };
202 
203 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
204 public:
205   TestReductionPatternInterface(Dialect *dialect)
206       : DialectReductionPatternInterface(dialect) {}
207 
208   void populateReductionPatterns(RewritePatternSet &patterns) const final {
209     populateTestReductionPatterns(patterns);
210   }
211 };
212 
213 } // namespace
214 
215 //===----------------------------------------------------------------------===//
216 // TestDialect
217 //===----------------------------------------------------------------------===//
218 
219 static void testSideEffectOpGetEffect(
220     Operation *op,
221     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
222 
223 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
224 struct TestOpEffectInterfaceFallback
225     : public TestEffectOpInterface::FallbackModel<
226           TestOpEffectInterfaceFallback> {
227   static bool classof(Operation *op) {
228     bool isSupportedOp =
229         op->getName().getStringRef() == "test.unregistered_side_effect_op";
230     assert(isSupportedOp && "Unexpected dispatch");
231     return isSupportedOp;
232   }
233 
234   void
235   getEffects(Operation *op,
236              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
237                  &effects) const {
238     testSideEffectOpGetEffect(op, effects);
239   }
240 };
241 
/// Populates the test dialect: attributes, types, every generated op, the
/// dialect interfaces defined above, and the effect-interface fallback object.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  // The op list comes from the generated TestOps.cpp.inc (GET_OP_LIST
  // expansion), spliced directly into the template argument list.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface, TestReductionPatternInterface>();
  // Needed so unregistered ops such as `test.unregistered_side_effect_op`
  // (served by getRegisteredInterfaceForOp) are accepted.
  allowUnknownOperations();

  // Instantiate our fallback op interface that we'll use on specific
  // unregistered op.
  // Ownership: this object is deleted in ~TestDialect().
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
257 TestDialect::~TestDialect() {
258   delete static_cast<TestOpEffectInterfaceFallback *>(
259       fallbackEffectOpInterfaces);
260 }
261 
262 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
263                                             Type type, Location loc) {
264   return builder.create<TestOpConstant>(loc, type, value);
265 }
266 
267 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
268                                                OperationName opName) {
269   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
270       typeID == TypeID::get<TestEffectOpInterface>())
271     return fallbackEffectOpInterfaces;
272   return nullptr;
273 }
274 
275 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
276                                                     NamedAttribute namedAttr) {
277   if (namedAttr.getName() == "test.invalid_attr")
278     return op->emitError() << "invalid to use 'test.invalid_attr'";
279   return success();
280 }
281 
282 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
283                                                     unsigned regionIndex,
284                                                     unsigned argIndex,
285                                                     NamedAttribute namedAttr) {
286   if (namedAttr.getName() == "test.invalid_attr")
287     return op->emitError() << "invalid to use 'test.invalid_attr'";
288   return success();
289 }
290 
291 LogicalResult
292 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
293                                          unsigned resultIndex,
294                                          NamedAttribute namedAttr) {
295   if (namedAttr.getName() == "test.invalid_attr")
296     return op->emitError() << "invalid to use 'test.invalid_attr'";
297   return success();
298 }
299 
300 Optional<Dialect::ParseOpHook>
301 TestDialect::getParseOperationHook(StringRef opName) const {
302   if (opName == "test.dialect_custom_printer") {
303     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
304       return parser.parseKeyword("custom_format");
305     }};
306   }
307   if (opName == "test.dialect_custom_format_fallback") {
308     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
309       return parser.parseKeyword("custom_format_fallback");
310     }};
311   }
312   return None;
313 }
314 
315 llvm::unique_function<void(Operation *, OpAsmPrinter &)>
316 TestDialect::getOperationPrinter(Operation *op) const {
317   StringRef opName = op->getName().getStringRef();
318   if (opName == "test.dialect_custom_printer") {
319     return [](Operation *op, OpAsmPrinter &printer) {
320       printer.getStream() << " custom_format";
321     };
322   }
323   if (opName == "test.dialect_custom_format_fallback") {
324     return [](Operation *op, OpAsmPrinter &printer) {
325       printer.getStream() << " custom_format_fallback";
326     };
327   }
328   return {};
329 }
330 
331 //===----------------------------------------------------------------------===//
332 // TestBranchOp
333 //===----------------------------------------------------------------------===//
334 
335 Optional<MutableOperandRange>
336 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
337   assert(index == 0 && "invalid successor index");
338   return getTargetOperandsMutable();
339 }
340 
341 //===----------------------------------------------------------------------===//
342 // TestDialectCanonicalizerOp
343 //===----------------------------------------------------------------------===//
344 
345 static LogicalResult
346 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
347                                PatternRewriter &rewriter) {
348   rewriter.replaceOpWithNewOp<arith::ConstantOp>(
349       op, rewriter.getI32IntegerAttr(42));
350   return success();
351 }
352 
353 void TestDialect::getCanonicalizationPatterns(
354     RewritePatternSet &results) const {
355   results.add(&dialectCanonicalizationPattern);
356 }
357 
358 //===----------------------------------------------------------------------===//
359 // TestFoldToCallOp
360 //===----------------------------------------------------------------------===//
361 
362 namespace {
363 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
364   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
365 
366   LogicalResult matchAndRewrite(FoldToCallOp op,
367                                 PatternRewriter &rewriter) const override {
368     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.getCalleeAttr(),
369                                         ValueRange());
370     return success();
371   }
372 };
373 } // namespace
374 
375 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
376                                                MLIRContext *context) {
377   results.add<FoldToCallOpPattern>(context);
378 }
379 
380 //===----------------------------------------------------------------------===//
381 // Test Format* operations
382 //===----------------------------------------------------------------------===//
383 
384 //===----------------------------------------------------------------------===//
385 // Parsing
386 
387 static ParseResult parseCustomDirectiveOperands(
388     OpAsmParser &parser, OpAsmParser::OperandType &operand,
389     Optional<OpAsmParser::OperandType> &optOperand,
390     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
391   if (parser.parseOperand(operand))
392     return failure();
393   if (succeeded(parser.parseOptionalComma())) {
394     optOperand.emplace();
395     if (parser.parseOperand(*optOperand))
396       return failure();
397   }
398   if (parser.parseArrow() || parser.parseLParen() ||
399       parser.parseOperandList(varOperands) || parser.parseRParen())
400     return failure();
401   return success();
402 }
403 static ParseResult
404 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
405                             Type &optOperandType,
406                             SmallVectorImpl<Type> &varOperandTypes) {
407   if (parser.parseColon())
408     return failure();
409 
410   if (parser.parseType(operandType))
411     return failure();
412   if (succeeded(parser.parseOptionalComma())) {
413     if (parser.parseType(optOperandType))
414       return failure();
415   }
416   if (parser.parseArrow() || parser.parseLParen() ||
417       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
418     return failure();
419   return success();
420 }
421 static ParseResult
422 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
423                                  Type optOperandType,
424                                  const SmallVectorImpl<Type> &varOperandTypes) {
425   if (parser.parseKeyword("type_refs_capture"))
426     return failure();
427 
428   Type operandType2, optOperandType2;
429   SmallVector<Type, 1> varOperandTypes2;
430   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
431                                   varOperandTypes2))
432     return failure();
433 
434   if (operandType != operandType2 || optOperandType != optOperandType2 ||
435       varOperandTypes != varOperandTypes2)
436     return failure();
437 
438   return success();
439 }
440 static ParseResult parseCustomDirectiveOperandsAndTypes(
441     OpAsmParser &parser, OpAsmParser::OperandType &operand,
442     Optional<OpAsmParser::OperandType> &optOperand,
443     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
444     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
445   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
446       parseCustomDirectiveResults(parser, operandType, optOperandType,
447                                   varOperandTypes))
448     return failure();
449   return success();
450 }
451 static ParseResult parseCustomDirectiveRegions(
452     OpAsmParser &parser, Region &region,
453     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
454   if (parser.parseRegion(region))
455     return failure();
456   if (failed(parser.parseOptionalComma()))
457     return success();
458   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
459   if (parser.parseRegion(*varRegion))
460     return failure();
461   varRegions.emplace_back(std::move(varRegion));
462   return success();
463 }
464 static ParseResult
465 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
466                                SmallVectorImpl<Block *> &varSuccessors) {
467   if (parser.parseSuccessor(successor))
468     return failure();
469   if (failed(parser.parseOptionalComma()))
470     return success();
471   Block *varSuccessor;
472   if (parser.parseSuccessor(varSuccessor))
473     return failure();
474   varSuccessors.append(2, varSuccessor);
475   return success();
476 }
477 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
478                                                   IntegerAttr &attr,
479                                                   IntegerAttr &optAttr) {
480   if (parser.parseAttribute(attr))
481     return failure();
482   if (succeeded(parser.parseOptionalComma())) {
483     if (parser.parseAttribute(optAttr))
484       return failure();
485   }
486   return success();
487 }
488 
489 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
490                                                 NamedAttrList &attrs) {
491   return parser.parseOptionalAttrDict(attrs);
492 }
493 static ParseResult parseCustomDirectiveOptionalOperandRef(
494     OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
495   int64_t operandCount = 0;
496   if (parser.parseInteger(operandCount))
497     return failure();
498   bool expectedOptionalOperand = operandCount == 0;
499   return success(expectedOptionalOperand != optOperand.hasValue());
500 }
501 
502 //===----------------------------------------------------------------------===//
503 // Printing
504 
505 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
506                                          Value operand, Value optOperand,
507                                          OperandRange varOperands) {
508   printer << operand;
509   if (optOperand)
510     printer << ", " << optOperand;
511   printer << " -> (" << varOperands << ")";
512 }
513 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
514                                         Type operandType, Type optOperandType,
515                                         TypeRange varOperandTypes) {
516   printer << " : " << operandType;
517   if (optOperandType)
518     printer << ", " << optOperandType;
519   printer << " -> (" << varOperandTypes << ")";
520 }
521 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
522                                              Operation *op, Type operandType,
523                                              Type optOperandType,
524                                              TypeRange varOperandTypes) {
525   printer << " type_refs_capture ";
526   printCustomDirectiveResults(printer, op, operandType, optOperandType,
527                               varOperandTypes);
528 }
529 static void printCustomDirectiveOperandsAndTypes(
530     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
531     OperandRange varOperands, Type operandType, Type optOperandType,
532     TypeRange varOperandTypes) {
533   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
534   printCustomDirectiveResults(printer, op, operandType, optOperandType,
535                               varOperandTypes);
536 }
537 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
538                                         Region &region,
539                                         MutableArrayRef<Region> varRegions) {
540   printer.printRegion(region);
541   if (!varRegions.empty()) {
542     printer << ", ";
543     for (Region &region : varRegions)
544       printer.printRegion(region);
545   }
546 }
547 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
548                                            Block *successor,
549                                            SuccessorRange varSuccessors) {
550   printer << successor;
551   if (!varSuccessors.empty())
552     printer << ", " << varSuccessors.front();
553 }
554 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
555                                            Attribute attribute,
556                                            Attribute optAttribute) {
557   printer << attribute;
558   if (optAttribute)
559     printer << ", " << optAttribute;
560 }
561 
562 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
563                                          DictionaryAttr attrs) {
564   printer.printOptionalAttrDict(attrs.getValue());
565 }
566 
567 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
568                                                    Operation *op,
569                                                    Value optOperand) {
570   printer << (optOperand ? "1" : "0");
571 }
572 
573 //===----------------------------------------------------------------------===//
574 // Test IsolatedRegionOp - parse passthrough region arguments.
575 //===----------------------------------------------------------------------===//
576 
577 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
578                                          OperationState &result) {
579   OpAsmParser::OperandType argInfo;
580   Type argType = parser.getBuilder().getIndexType();
581 
582   // Parse the input operand.
583   if (parser.parseOperand(argInfo) ||
584       parser.resolveOperand(argInfo, argType, result.operands))
585     return failure();
586 
587   // Parse the body region, and reuse the operand info as the argument info.
588   Region *body = result.addRegion();
589   return parser.parseRegion(*body, argInfo, argType,
590                             /*enableNameShadowing=*/true);
591 }
592 
593 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
594   p << "test.isolated_region ";
595   p.printOperand(op.getOperand());
596   p.shadowRegionArgs(op.getRegion(), op.getOperand());
597   p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
598 }
599 
600 //===----------------------------------------------------------------------===//
601 // Test SSACFGRegionOp
602 //===----------------------------------------------------------------------===//
603 
604 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
605   return RegionKind::SSACFG;
606 }
607 
608 //===----------------------------------------------------------------------===//
609 // Test GraphRegionOp
610 //===----------------------------------------------------------------------===//
611 
612 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
613                                       OperationState &result) {
614   // Parse the body region, and reuse the operand info as the argument info.
615   Region *body = result.addRegion();
616   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
617 }
618 
619 static void print(OpAsmPrinter &p, GraphRegionOp op) {
620   p << "test.graph_region ";
621   p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
622 }
623 
624 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
625   return RegionKind::Graph;
626 }
627 
628 //===----------------------------------------------------------------------===//
629 // Test AffineScopeOp
630 //===----------------------------------------------------------------------===//
631 
632 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
633                                       OperationState &result) {
634   // Parse the body region, and reuse the operand info as the argument info.
635   Region *body = result.addRegion();
636   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
637 }
638 
639 static void print(OpAsmPrinter &p, AffineScopeOp op) {
640   p << "test.affine_scope ";
641   p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
642 }
643 
644 //===----------------------------------------------------------------------===//
645 // Test parser.
646 //===----------------------------------------------------------------------===//
647 
648 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
649                                               OperationState &result) {
650   if (parser.parseOptionalColon())
651     return success();
652   uint64_t numResults;
653   if (parser.parseInteger(numResults))
654     return failure();
655 
656   IndexType type = parser.getBuilder().getIndexType();
657   for (unsigned i = 0; i < numResults; ++i)
658     result.addTypes(type);
659   return success();
660 }
661 
662 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
663   if (unsigned numResults = op->getNumResults())
664     p << " : " << numResults;
665 }
666 
667 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
668                                               OperationState &result) {
669   StringRef keyword;
670   if (parser.parseKeyword(&keyword))
671     return failure();
672   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
673   return success();
674 }
675 
676 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
677   p << " " << op.getKeyword();
678 }
679 
680 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
682 
683 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
684                                          OperationState &result) {
685   if (parser.parseKeyword("wraps"))
686     return failure();
687 
688   // Parse the wrapped op in a region
689   Region &body = *result.addRegion();
690   body.push_back(new Block);
691   Block &block = body.back();
692   Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
693   if (!wrappedOp)
694     return failure();
695 
696   // Create a return terminator in the inner region, pass as operand to the
697   // terminator the returned values from the wrapped operation.
698   SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
699   OpBuilder builder(parser.getContext());
700   builder.setInsertionPointToEnd(&block);
701   builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);
702 
703   // Get the results type for the wrapping op from the terminator operands.
704   Operation &returnOp = body.back().back();
705   result.types.append(returnOp.operand_type_begin(),
706                       returnOp.operand_type_end());
707 
708   // Use the location of the wrapped op for the "test.wrapping_region" op.
709   result.location = wrappedOp->getLoc();
710 
711   return success();
712 }
713 
714 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
715   p << " wraps ";
716   p.printGenericOp(&op.getRegion().front().front());
717 }
718 
719 //===----------------------------------------------------------------------===//
// Test PrettyPrintedRegionOp - exercising the following parser APIs
721 //   parseGenericOperationAfterOpName
722 //   parseCustomOperationName
723 //===----------------------------------------------------------------------===//
724 
725 static ParseResult parsePrettyPrintedRegionOp(OpAsmParser &parser,
726                                               OperationState &result) {
727 
728   llvm::SMLoc loc = parser.getCurrentLocation();
729   Location currLocation = parser.getEncodedSourceLoc(loc);
730 
731   // Parse the operands.
732   SmallVector<OpAsmParser::OperandType, 2> operands;
733   if (parser.parseOperandList(operands))
734     return failure();
735 
736   // Check if we are parsing the pretty-printed version
737   //  test.pretty_printed_region start <inner-op> end : <functional-type>
738   // Else fallback to parsing the "non pretty-printed" version.
739   if (!succeeded(parser.parseOptionalKeyword("start")))
740     return parser.parseGenericOperationAfterOpName(
741         result, llvm::makeArrayRef(operands));
742 
743   FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName();
744   if (failed(parseOpNameInfo))
745     return failure();
746 
747   StringRef innerOpName = parseOpNameInfo->getStringRef();
748 
749   FunctionType opFntype;
750   Optional<Location> explicitLoc;
751   if (parser.parseKeyword("end") || parser.parseColon() ||
752       parser.parseType(opFntype) ||
753       parser.parseOptionalLocationSpecifier(explicitLoc))
754     return failure();
755 
756   // If location of the op is explicitly provided, then use it; Else use
757   // the parser's current location.
758   Location opLoc = explicitLoc.getValueOr(currLocation);
759 
760   // Derive the SSA-values for op's operands.
761   if (parser.resolveOperands(operands, opFntype.getInputs(), loc,
762                              result.operands))
763     return failure();
764 
765   // Add a region for op.
766   Region &region = *result.addRegion();
767 
768   // Create a basic-block inside op's region.
769   Block &block = region.emplaceBlock();
770 
771   // Create and insert an "inner-op" operation in the block.
772   // Just for testing purposes, we can assume that inner op is a binary op with
773   // result and operand types all same as the test-op's first operand.
774   Type innerOpType = opFntype.getInput(0);
775   Value lhs = block.addArgument(innerOpType, opLoc);
776   Value rhs = block.addArgument(innerOpType, opLoc);
777 
778   OpBuilder builder(parser.getBuilder().getContext());
779   builder.setInsertionPointToStart(&block);
780 
781   OperationState innerOpState(opLoc, innerOpName);
782   innerOpState.operands.push_back(lhs);
783   innerOpState.operands.push_back(rhs);
784   innerOpState.addTypes(innerOpType);
785 
786   Operation *innerOp = builder.createOperation(innerOpState);
787 
788   // Insert a return statement in the block returning the inner-op's result.
789   builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());
790 
791   // Populate the op operation-state with result-type and location.
792   result.addTypes(opFntype.getResults());
793   result.location = innerOp->getLoc();
794 
795   return success();
796 }
797 
798 static void print(OpAsmPrinter &p, PrettyPrintedRegionOp op) {
799   p << ' ';
800   p.printOperands(op.getOperands());
801 
802   Operation &innerOp = op.getRegion().front().front();
803   // Assuming that region has a single non-terminator inner-op, if the inner-op
804   // meets some criteria (which in this case is a simple one  based on the name
805   // of inner-op), then we can print the entire region in a succinct way.
806   // Here we assume that the prototype of "special.op" can be trivially derived
807   // while parsing it back.
808   if (innerOp.getName().getStringRef().equals("special.op")) {
809     p << " start special.op end";
810   } else {
811     p << " (";
812     p.printRegion(op.getRegion());
813     p << ")";
814   }
815 
816   p << " : ";
817   p.printFunctionalType(op);
818 }
819 
820 //===----------------------------------------------------------------------===//
821 // Test PolyForOp - parse list of region arguments.
822 //===----------------------------------------------------------------------===//
823 
824 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
825   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
826   // Parse list of region arguments without a delimiter.
827   if (parser.parseRegionArgumentList(ivsInfo))
828     return failure();
829 
830   // Parse the body region.
831   Region *body = result.addRegion();
832   auto &builder = parser.getBuilder();
833   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
834   return parser.parseRegion(*body, ivsInfo, argTypes);
835 }
836 
837 void PolyForOp::getAsmBlockArgumentNames(Region &region,
838                                          OpAsmSetValueNameFn setNameFn) {
839   auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
840   if (!arrayAttr)
841     return;
842   auto args = getRegion().front().getArguments();
843   auto e = std::min(arrayAttr.size(), args.size());
844   for (unsigned i = 0; i < e; ++i) {
845     if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
846       setNameFn(args[i], strAttr.getValue());
847   }
848 }
849 
850 //===----------------------------------------------------------------------===//
851 // Test removing op with inner ops.
852 //===----------------------------------------------------------------------===//
853 
854 namespace {
855 struct TestRemoveOpWithInnerOps
856     : public OpRewritePattern<TestOpWithRegionPattern> {
857   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
858 
859   void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
860 
861   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
862                                 PatternRewriter &rewriter) const override {
863     rewriter.eraseOp(op);
864     return success();
865   }
866 };
867 } // namespace
868 
/// Registers the op's only canonicalization, `TestRemoveOpWithInnerOps`, which
/// erases the op.
void TestOpWithRegionPattern::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<TestRemoveOpWithInnerOps>(context);
}
873 
874 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
875   return getOperand();
876 }
877 
878 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
879   return getValue();
880 }
881 
882 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
883     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
884   for (Value input : this->getOperands()) {
885     results.push_back(input);
886   }
887   return success();
888 }
889 
890 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
891   assert(operands.size() == 1);
892   if (operands.front()) {
893     (*this)->setAttr("attr", operands.front());
894     return getResult();
895   }
896   return {};
897 }
898 
899 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
900   return getOperand();
901 }
902 
903 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
904     MLIRContext *, Optional<Location> location, ValueRange operands,
905     DictionaryAttr attributes, RegionRange regions,
906     SmallVectorImpl<Type> &inferredReturnTypes) {
907   if (operands[0].getType() != operands[1].getType()) {
908     return emitOptionalError(location, "operand type mismatch ",
909                              operands[0].getType(), " vs ",
910                              operands[1].getType());
911   }
912   inferredReturnTypes.assign({operands[0].getType()});
913   return success();
914 }
915 
916 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
917     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
918     DictionaryAttr attributes, RegionRange regions,
919     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
920   // Create return type consisting of the last element of the first operand.
921   auto operandType = operands.front().getType();
922   auto sval = operandType.dyn_cast<ShapedType>();
923   if (!sval) {
924     return emitOptionalError(location, "only shaped type operands allowed");
925   }
926   int64_t dim =
927       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
928   auto type = IntegerType::get(context, 17);
929   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
930   return success();
931 }
932 
933 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
934     OpBuilder &builder, ValueRange operands,
935     llvm::SmallVectorImpl<Value> &shapes) {
936   shapes = SmallVector<Value, 1>{
937       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
938   return success();
939 }
940 
941 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
942     OpBuilder &builder, ValueRange operands,
943     llvm::SmallVectorImpl<Value> &shapes) {
944   Location loc = getLoc();
945   shapes.reserve(operands.size());
946   for (Value operand : llvm::reverse(operands)) {
947     auto rank = operand.getType().cast<RankedTensorType>().getRank();
948     auto currShape = llvm::to_vector<4>(
949         llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
950           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
951         }));
952     shapes.push_back(builder.create<tensor::FromElementsOp>(
953         getLoc(), RankedTensorType::get({rank}, builder.getIndexType()),
954         currShape));
955   }
956   return success();
957 }
958 
959 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
960     OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
961   Location loc = getLoc();
962   shapes.reserve(getNumOperands());
963   for (Value operand : llvm::reverse(getOperands())) {
964     auto currShape = llvm::to_vector<4>(llvm::map_range(
965         llvm::seq<int64_t>(
966             0, operand.getType().cast<RankedTensorType>().getRank()),
967         [&](int64_t dim) -> Value {
968           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
969         }));
970     shapes.emplace_back(std::move(currShape));
971   }
972   return success();
973 }
974 
975 //===----------------------------------------------------------------------===//
976 // Test SideEffect interfaces
977 //===----------------------------------------------------------------------===//
978 
979 namespace {
980 /// A test resource for side effects.
981 struct TestResource : public SideEffects::Resource::Base<TestResource> {
982   StringRef getName() final { return "<Test>"; }
983 };
984 } // namespace
985 
986 static void testSideEffectOpGetEffect(
987     Operation *op,
988     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
989         &effects) {
990   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
991   if (!effectsAttr)
992     return;
993 
994   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
995 }
996 
997 void SideEffectOp::getEffects(
998     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
999   // Check for an effects attribute on the op instance.
1000   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
1001   if (!effectsAttr)
1002     return;
1003 
1004   // If there is one, it is an array of dictionary attributes that hold
1005   // information on the effects of this operation.
1006   for (Attribute element : effectsAttr) {
1007     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
1008 
1009     // Get the specific memory effect.
1010     MemoryEffects::Effect *effect =
1011         StringSwitch<MemoryEffects::Effect *>(
1012             effectElement.get("effect").cast<StringAttr>().getValue())
1013             .Case("allocate", MemoryEffects::Allocate::get())
1014             .Case("free", MemoryEffects::Free::get())
1015             .Case("read", MemoryEffects::Read::get())
1016             .Case("write", MemoryEffects::Write::get());
1017 
1018     // Check for a non-default resource to use.
1019     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
1020     if (effectElement.get("test_resource"))
1021       resource = TestResource::get();
1022 
1023     // Check for a result to affect.
1024     if (effectElement.get("on_result"))
1025       effects.emplace_back(effect, getResult(), resource);
1026     else if (Attribute ref = effectElement.get("on_reference"))
1027       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
1028     else
1029       effects.emplace_back(effect, resource);
1030   }
1031 }
1032 
1033 void SideEffectOp::getEffects(
1034     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
1035   testSideEffectOpGetEffect(getOperation(), effects);
1036 }
1037 
1038 //===----------------------------------------------------------------------===//
1039 // StringAttrPrettyNameOp
1040 //===----------------------------------------------------------------------===//
1041 
1042 // This op has fancy handling of its SSA result name.
1043 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
1044                                                OperationState &result) {
1045   // Add the result types.
1046   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
1047     result.addTypes(parser.getBuilder().getIntegerType(32));
1048 
1049   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1050     return failure();
1051 
1052   // If the attribute dictionary contains no 'names' attribute, infer it from
1053   // the SSA name (if specified).
1054   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
1055     return attr.getName() == "names";
1056   });
1057 
1058   // If there was no name specified, check to see if there was a useful name
1059   // specified in the asm file.
1060   if (hadNames || parser.getNumResults() == 0)
1061     return success();
1062 
1063   SmallVector<StringRef, 4> names;
1064   auto *context = result.getContext();
1065 
1066   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
1067     auto resultName = parser.getResultName(i);
1068     StringRef nameStr;
1069     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
1070       nameStr = resultName.first;
1071 
1072     names.push_back(nameStr);
1073   }
1074 
1075   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
1076   result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
1077   return success();
1078 }
1079 
1080 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
1081   // Note that we only need to print the "name" attribute if the asmprinter
1082   // result name disagrees with it.  This can happen in strange cases, e.g.
1083   // when there are conflicts.
1084   bool namesDisagree = op.getNames().size() != op.getNumResults();
1085 
1086   SmallString<32> resultNameStr;
1087   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
1088     resultNameStr.clear();
1089     llvm::raw_svector_ostream tmpStream(resultNameStr);
1090     p.printOperand(op.getResult(i), tmpStream);
1091 
1092     auto expectedName = op.getNames()[i].dyn_cast<StringAttr>();
1093     if (!expectedName ||
1094         tmpStream.str().drop_front() != expectedName.getValue()) {
1095       namesDisagree = true;
1096     }
1097   }
1098 
1099   if (namesDisagree)
1100     p.printOptionalAttrDictWithKeyword(op->getAttrs());
1101   else
1102     p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
1103 }
1104 
1105 // We set the SSA name in the asm syntax to the contents of the name
1106 // attribute.
1107 void StringAttrPrettyNameOp::getAsmResultNames(
1108     function_ref<void(Value, StringRef)> setNameFn) {
1109 
1110   auto value = getNames();
1111   for (size_t i = 0, e = value.size(); i != e; ++i)
1112     if (auto str = value[i].dyn_cast<StringAttr>())
1113       if (!str.getValue().empty())
1114         setNameFn(getResult(i), str.getValue());
1115 }
1116 
1117 //===----------------------------------------------------------------------===//
1118 // RegionIfOp
1119 //===----------------------------------------------------------------------===//
1120 
1121 static void print(OpAsmPrinter &p, RegionIfOp op) {
1122   p << " ";
1123   p.printOperands(op.getOperands());
1124   p << ": " << op.getOperandTypes();
1125   p.printArrowTypeList(op.getResultTypes());
1126   p << " then";
1127   p.printRegion(op.getThenRegion(),
1128                 /*printEntryBlockArgs=*/true,
1129                 /*printBlockTerminators=*/true);
1130   p << " else";
1131   p.printRegion(op.getElseRegion(),
1132                 /*printEntryBlockArgs=*/true,
1133                 /*printBlockTerminators=*/true);
1134   p << " join";
1135   p.printRegion(op.getJoinRegion(),
1136                 /*printEntryBlockArgs=*/true,
1137                 /*printBlockTerminators=*/true);
1138 }
1139 
1140 static ParseResult parseRegionIfOp(OpAsmParser &parser,
1141                                    OperationState &result) {
1142   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
1143   SmallVector<Type, 2> operandTypes;
1144 
1145   result.regions.reserve(3);
1146   Region *thenRegion = result.addRegion();
1147   Region *elseRegion = result.addRegion();
1148   Region *joinRegion = result.addRegion();
1149 
1150   // Parse operand, type and arrow type lists.
1151   if (parser.parseOperandList(operandInfos) ||
1152       parser.parseColonTypeList(operandTypes) ||
1153       parser.parseArrowTypeList(result.types))
1154     return failure();
1155 
1156   // Parse all attached regions.
1157   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1158       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1159       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1160     return failure();
1161 
1162   return parser.resolveOperands(operandInfos, operandTypes,
1163                                 parser.getCurrentLocation(), result.operands);
1164 }
1165 
1166 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
1167   assert(index < 2 && "invalid region index");
1168   return getOperands();
1169 }
1170 
1171 void RegionIfOp::getSuccessorRegions(
1172     Optional<unsigned> index, ArrayRef<Attribute> operands,
1173     SmallVectorImpl<RegionSuccessor> &regions) {
1174   // We always branch to the join region.
1175   if (index.hasValue()) {
1176     if (index.getValue() < 2)
1177       regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
1178     else
1179       regions.push_back(RegionSuccessor(getResults()));
1180     return;
1181   }
1182 
1183   // The then and else regions are the entry regions of this op.
1184   regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
1185   regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
1186 }
1187 
1188 //===----------------------------------------------------------------------===//
1189 // SingleNoTerminatorCustomAsmOp
1190 //===----------------------------------------------------------------------===//
1191 
1192 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser,
1193                                                       OperationState &state) {
1194   Region *body = state.addRegion();
1195   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1196     return failure();
1197   return success();
1198 }
1199 
1200 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) {
1201   printer.printRegion(
1202       op.getRegion(), /*printEntryBlockArgs=*/false,
1203       // This op has a single block without terminators. But explicitly mark
1204       // as not printing block terminators for testing.
1205       /*printBlockTerminators=*/false);
1206 }
1207 
1208 #include "TestOpEnums.cpp.inc"
1209 #include "TestOpInterfaces.cpp.inc"
1210 #include "TestOpStructs.cpp.inc"
1211 #include "TestTypeInterfaces.cpp.inc"
1212 
1213 #define GET_OP_CLASSES
1214 #include "TestOps.cpp.inc"
1215