1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/DLTI/DLTI.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/Reducer/ReductionPatternInterface.h"
22 #include "mlir/Transforms/FoldUtils.h"
23 #include "mlir/Transforms/InliningUtils.h"
24 #include "llvm/ADT/StringSwitch.h"
25 
26 // Include this before the using namespace lines below to
27 // test that we don't have namespace dependencies.
28 #include "TestOpsDialect.cpp.inc"
29 
30 using namespace mlir;
31 using namespace test;
32 
/// Inserts `TestDialect` into the given dialect registry.
void test::registerTestDialect(DialectRegistry &registry) {
  registry.insert<TestDialect>();
}
36 
37 //===----------------------------------------------------------------------===//
38 // TestDialect Interfaces
39 //===----------------------------------------------------------------------===//
40 
41 namespace {
42 
/// Testing the correctness of some traits.
// The `has_implicit_terminator_t` detector must be detected for
// SingleBlockImplicitTerminatorOp.
static_assert(
    llvm::is_detected<OpTrait::has_implicit_terminator_t,
                      SingleBlockImplicitTerminatorOp>::value,
    "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
// The `hasSingleBlockImplicitTerminator` predicate must hold for the same op.
static_assert(OpTrait::hasSingleBlockImplicitTerminator<
                  SingleBlockImplicitTerminatorOp>::value,
              "hasSingleBlockImplicitTerminator does not match "
              "SingleBlockImplicitTerminatorOp");
52 
53 // Test support for interacting with the AsmPrinter.
54 struct TestOpAsmInterface : public OpAsmDialectInterface {
55   using OpAsmDialectInterface::OpAsmDialectInterface;
56 
57   AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
58     StringAttr strAttr = attr.dyn_cast<StringAttr>();
59     if (!strAttr)
60       return AliasResult::NoAlias;
61 
62     // Check the contents of the string attribute to see what the test alias
63     // should be named.
64     Optional<StringRef> aliasName =
65         StringSwitch<Optional<StringRef>>(strAttr.getValue())
66             .Case("alias_test:dot_in_name", StringRef("test.alias"))
67             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
68             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
69             .Case("alias_test:sanitize_conflict_a",
70                   StringRef("test_alias_conflict0"))
71             .Case("alias_test:sanitize_conflict_b",
72                   StringRef("test_alias_conflict0_"))
73             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
74             .Default(llvm::None);
75     if (!aliasName)
76       return AliasResult::NoAlias;
77 
78     os << *aliasName;
79     return AliasResult::FinalAlias;
80   }
81 
82   AliasResult getAlias(Type type, raw_ostream &os) const final {
83     if (auto tupleType = type.dyn_cast<TupleType>()) {
84       if (tupleType.size() > 0 &&
85           llvm::all_of(tupleType.getTypes(), [](Type elemType) {
86             return elemType.isa<SimpleAType>();
87           })) {
88         os << "test_tuple";
89         return AliasResult::FinalAlias;
90       }
91     }
92     if (auto intType = type.dyn_cast<TestIntegerType>()) {
93       if (intType.getSignedness() ==
94               TestIntegerType::SignednessSemantics::Unsigned &&
95           intType.getWidth() == 8) {
96         os << "test_ui8";
97         return AliasResult::FinalAlias;
98       }
99     }
100     return AliasResult::NoAlias;
101   }
102 
103   void getAsmResultNames(Operation *op,
104                          OpAsmSetValueNameFn setNameFn) const final {
105     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
106       setNameFn(asmOp, "result");
107   }
108 };
109 
110 struct TestDialectFoldInterface : public DialectFoldInterface {
111   using DialectFoldInterface::DialectFoldInterface;
112 
113   /// Registered hook to check if the given region, which is attached to an
114   /// operation that is *not* isolated from above, should be used when
115   /// materializing constants.
116   bool shouldMaterializeInto(Region *region) const final {
117     // If this is a one region operation, then insert into it.
118     return isa<OneRegionOp>(region->getParentOp());
119   }
120 };
121 
122 /// This class defines the interface for handling inlining with standard
123 /// operations.
124 struct TestInlinerInterface : public DialectInlinerInterface {
125   using DialectInlinerInterface::DialectInlinerInterface;
126 
127   //===--------------------------------------------------------------------===//
128   // Analysis Hooks
129   //===--------------------------------------------------------------------===//
130 
131   bool isLegalToInline(Operation *call, Operation *callable,
132                        bool wouldBeCloned) const final {
133     // Don't allow inlining calls that are marked `noinline`.
134     return !call->hasAttr("noinline");
135   }
136   bool isLegalToInline(Region *, Region *, bool,
137                        BlockAndValueMapping &) const final {
138     // Inlining into test dialect regions is legal.
139     return true;
140   }
141   bool isLegalToInline(Operation *, Region *, bool,
142                        BlockAndValueMapping &) const final {
143     return true;
144   }
145 
146   bool shouldAnalyzeRecursively(Operation *op) const final {
147     // Analyze recursively if this is not a functional region operation, it
148     // froms a separate functional scope.
149     return !isa<FunctionalRegionOp>(op);
150   }
151 
152   //===--------------------------------------------------------------------===//
153   // Transformation Hooks
154   //===--------------------------------------------------------------------===//
155 
156   /// Handle the given inlined terminator by replacing it with a new operation
157   /// as necessary.
158   void handleTerminator(Operation *op,
159                         ArrayRef<Value> valuesToRepl) const final {
160     // Only handle "test.return" here.
161     auto returnOp = dyn_cast<TestReturnOp>(op);
162     if (!returnOp)
163       return;
164 
165     // Replace the values directly with the return operands.
166     assert(returnOp.getNumOperands() == valuesToRepl.size());
167     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
168       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
169   }
170 
171   /// Attempt to materialize a conversion for a type mismatch between a call
172   /// from this dialect, and a callable region. This method should generate an
173   /// operation that takes 'input' as the only operand, and produces a single
174   /// result of 'resultType'. If a conversion can not be generated, nullptr
175   /// should be returned.
176   Operation *materializeCallConversion(OpBuilder &builder, Value input,
177                                        Type resultType,
178                                        Location conversionLoc) const final {
179     // Only allow conversion for i16/i32 types.
180     if (!(resultType.isSignlessInteger(16) ||
181           resultType.isSignlessInteger(32)) ||
182         !(input.getType().isSignlessInteger(16) ||
183           input.getType().isSignlessInteger(32)))
184       return nullptr;
185     return builder.create<TestCastOp>(conversionLoc, resultType, input);
186   }
187 
188   void processInlinedCallBlocks(
189       Operation *call,
190       iterator_range<Region::iterator> inlinedBlocks) const final {
191     if (!isa<ConversionCallOp>(call))
192       return;
193 
194     // Set attributed on all ops in the inlined blocks.
195     for (Block &block : inlinedBlocks) {
196       block.walk([&](Operation *op) {
197         op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
198       });
199     }
200   }
201 };
202 
/// Dialect interface through which the test dialect contributes its reduction
/// patterns.
struct TestReductionPatternInterface : public DialectReductionPatternInterface {
public:
  /// Forwards the owning dialect to the base interface.
  TestReductionPatternInterface(Dialect *dialect)
      : DialectReductionPatternInterface(dialect) {}

  /// Adds the test reduction patterns to `patterns` by delegating to
  /// `populateTestReductionPatterns`.
  void populateReductionPatterns(RewritePatternSet &patterns) const final {
    populateTestReductionPatterns(patterns);
  }
};
212 
213 } // namespace
214 
215 //===----------------------------------------------------------------------===//
216 // TestDialect
217 //===----------------------------------------------------------------------===//
218 
/// Forward declaration of the helper that fills `effects` for `op`; it is
/// called by the fallback effect-interface model declared right after it.
// NOTE(review): the definition is not visible in this part of the file.
static void testSideEffectOpGetEffect(
    Operation *op,
    SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
222 
223 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
224 struct TestOpEffectInterfaceFallback
225     : public TestEffectOpInterface::FallbackModel<
226           TestOpEffectInterfaceFallback> {
227   static bool classof(Operation *op) {
228     bool isSupportedOp =
229         op->getName().getStringRef() == "test.unregistered_side_effect_op";
230     assert(isSupportedOp && "Unexpected dispatch");
231     return isSupportedOp;
232   }
233 
234   void
235   getEffects(Operation *op,
236              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
237                  &effects) const {
238     testSideEffectOpGetEffect(op, effects);
239   }
240 };
241 
/// Sets up the dialect: registers its attributes, types, operations and
/// dialect interfaces, allows unknown operations, and creates the fallback
/// effect-interface model.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  // The operation list comes from TestOps.cpp.inc, included with GET_OP_LIST
  // defined.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface, TestReductionPatternInterface>();
  allowUnknownOperations();

  // Instantiate our fallback op interface that we'll use on specific
  // unregistered op. The object is released again in the destructor.
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
/// Destroys the fallback effect-interface model created in `initialize()`.
TestDialect::~TestDialect() {
  // Cast back to the concrete type that `initialize()` allocated before
  // deleting.
  delete static_cast<TestOpEffectInterfaceFallback *>(
      fallbackEffectOpInterfaces);
}
261 
/// Materializes a constant as a `TestOpConstant` created at `loc` from the
/// given `type` and `value`, and returns the new operation.
Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  return builder.create<TestOpConstant>(loc, type, value);
}
266 
267 ::mlir::LogicalResult FormatInferType2Op::inferReturnTypes(
268     ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location,
269     ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
270     ::mlir::RegionRange regions,
271     ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
272   inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
273   return ::mlir::success();
274 }
275 
276 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
277                                                OperationName opName) {
278   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
279       typeID == TypeID::get<TestEffectOpInterface>())
280     return fallbackEffectOpInterfaces;
281   return nullptr;
282 }
283 
284 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
285                                                     NamedAttribute namedAttr) {
286   if (namedAttr.getName() == "test.invalid_attr")
287     return op->emitError() << "invalid to use 'test.invalid_attr'";
288   return success();
289 }
290 
291 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
292                                                     unsigned regionIndex,
293                                                     unsigned argIndex,
294                                                     NamedAttribute namedAttr) {
295   if (namedAttr.getName() == "test.invalid_attr")
296     return op->emitError() << "invalid to use 'test.invalid_attr'";
297   return success();
298 }
299 
300 LogicalResult
301 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
302                                          unsigned resultIndex,
303                                          NamedAttribute namedAttr) {
304   if (namedAttr.getName() == "test.invalid_attr")
305     return op->emitError() << "invalid to use 'test.invalid_attr'";
306   return success();
307 }
308 
/// Returns a dialect-level parse hook for the two operations that have a
/// custom textual form but are handled by the dialect rather than by an op
/// class. Each hook only expects a single keyword; `None` is returned for
/// every other operation name.
Optional<Dialect::ParseOpHook>
TestDialect::getParseOperationHook(StringRef opName) const {
  // "test.dialect_custom_printer" is spelled `custom_format`.
  if (opName == "test.dialect_custom_printer") {
    return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
      return parser.parseKeyword("custom_format");
    }};
  }
  // "test.dialect_custom_format_fallback" is spelled `custom_format_fallback`.
  if (opName == "test.dialect_custom_format_fallback") {
    return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
      return parser.parseKeyword("custom_format_fallback");
    }};
  }
  return None;
}
323 
324 llvm::unique_function<void(Operation *, OpAsmPrinter &)>
325 TestDialect::getOperationPrinter(Operation *op) const {
326   StringRef opName = op->getName().getStringRef();
327   if (opName == "test.dialect_custom_printer") {
328     return [](Operation *op, OpAsmPrinter &printer) {
329       printer.getStream() << " custom_format";
330     };
331   }
332   if (opName == "test.dialect_custom_format_fallback") {
333     return [](Operation *op, OpAsmPrinter &printer) {
334       printer.getStream() << " custom_format_fallback";
335     };
336   }
337   return {};
338 }
339 
340 //===----------------------------------------------------------------------===//
341 // TestBranchOp
342 //===----------------------------------------------------------------------===//
343 
/// Returns the mutable range of operands forwarded to the successor. Only
/// successor index 0 is valid; any other index trips the assertion.
Optional<MutableOperandRange>
TestBranchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index == 0 && "invalid successor index");
  return getTargetOperandsMutable();
}
349 
350 //===----------------------------------------------------------------------===//
351 // TestDialectCanonicalizerOp
352 //===----------------------------------------------------------------------===//
353 
354 static LogicalResult
355 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
356                                PatternRewriter &rewriter) {
357   rewriter.replaceOpWithNewOp<arith::ConstantOp>(
358       op, rewriter.getI32IntegerAttr(42));
359   return success();
360 }
361 
/// Registers the dialect-wide canonicalization patterns, i.e. the function
/// pattern `dialectCanonicalizationPattern`.
void TestDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add(&dialectCanonicalizationPattern);
}
366 
367 //===----------------------------------------------------------------------===//
368 // TestFoldToCallOp
369 //===----------------------------------------------------------------------===//
370 
namespace {
/// Rewrite pattern that replaces a `FoldToCallOp` with a `CallOp` to the
/// op's callee, built with no result types and no operands.
struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
  using OpRewritePattern<FoldToCallOp>::OpRewritePattern;

  /// Matches unconditionally and performs the replacement; always succeeds.
  LogicalResult matchAndRewrite(FoldToCallOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.getCalleeAttr(),
                                        ValueRange());
    return success();
  }
};
} // namespace
383 
/// Registers `FoldToCallOpPattern` as the canonicalization pattern of
/// `FoldToCallOp`.
void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<FoldToCallOpPattern>(context);
}
388 
389 //===----------------------------------------------------------------------===//
390 // Test Format* operations
391 //===----------------------------------------------------------------------===//
392 
393 //===----------------------------------------------------------------------===//
394 // Parsing
395 
396 static ParseResult parseCustomDirectiveOperands(
397     OpAsmParser &parser, OpAsmParser::OperandType &operand,
398     Optional<OpAsmParser::OperandType> &optOperand,
399     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
400   if (parser.parseOperand(operand))
401     return failure();
402   if (succeeded(parser.parseOptionalComma())) {
403     optOperand.emplace();
404     if (parser.parseOperand(*optOperand))
405       return failure();
406   }
407   if (parser.parseArrow() || parser.parseLParen() ||
408       parser.parseOperandList(varOperands) || parser.parseRParen())
409     return failure();
410   return success();
411 }
412 static ParseResult
413 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
414                             Type &optOperandType,
415                             SmallVectorImpl<Type> &varOperandTypes) {
416   if (parser.parseColon())
417     return failure();
418 
419   if (parser.parseType(operandType))
420     return failure();
421   if (succeeded(parser.parseOptionalComma())) {
422     if (parser.parseType(optOperandType))
423       return failure();
424   }
425   if (parser.parseArrow() || parser.parseLParen() ||
426       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
427     return failure();
428   return success();
429 }
430 static ParseResult
431 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
432                                  Type optOperandType,
433                                  const SmallVectorImpl<Type> &varOperandTypes) {
434   if (parser.parseKeyword("type_refs_capture"))
435     return failure();
436 
437   Type operandType2, optOperandType2;
438   SmallVector<Type, 1> varOperandTypes2;
439   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
440                                   varOperandTypes2))
441     return failure();
442 
443   if (operandType != operandType2 || optOperandType != optOperandType2 ||
444       varOperandTypes != varOperandTypes2)
445     return failure();
446 
447   return success();
448 }
449 static ParseResult parseCustomDirectiveOperandsAndTypes(
450     OpAsmParser &parser, OpAsmParser::OperandType &operand,
451     Optional<OpAsmParser::OperandType> &optOperand,
452     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
453     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
454   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
455       parseCustomDirectiveResults(parser, operandType, optOperandType,
456                                   varOperandTypes))
457     return failure();
458   return success();
459 }
460 static ParseResult parseCustomDirectiveRegions(
461     OpAsmParser &parser, Region &region,
462     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
463   if (parser.parseRegion(region))
464     return failure();
465   if (failed(parser.parseOptionalComma()))
466     return success();
467   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
468   if (parser.parseRegion(*varRegion))
469     return failure();
470   varRegions.emplace_back(std::move(varRegion));
471   return success();
472 }
473 static ParseResult
474 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
475                                SmallVectorImpl<Block *> &varSuccessors) {
476   if (parser.parseSuccessor(successor))
477     return failure();
478   if (failed(parser.parseOptionalComma()))
479     return success();
480   Block *varSuccessor;
481   if (parser.parseSuccessor(varSuccessor))
482     return failure();
483   varSuccessors.append(2, varSuccessor);
484   return success();
485 }
486 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
487                                                   IntegerAttr &attr,
488                                                   IntegerAttr &optAttr) {
489   if (parser.parseAttribute(attr))
490     return failure();
491   if (succeeded(parser.parseOptionalComma())) {
492     if (parser.parseAttribute(optAttr))
493       return failure();
494   }
495   return success();
496 }
497 
/// Parses an optional attribute dictionary into `attrs` by forwarding to
/// `OpAsmParser::parseOptionalAttrDict`.
static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
                                                NamedAttrList &attrs) {
  return parser.parseOptionalAttrDict(attrs);
}
502 static ParseResult parseCustomDirectiveOptionalOperandRef(
503     OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
504   int64_t operandCount = 0;
505   if (parser.parseInteger(operandCount))
506     return failure();
507   bool expectedOptionalOperand = operandCount == 0;
508   return success(expectedOptionalOperand != optOperand.hasValue());
509 }
510 
511 //===----------------------------------------------------------------------===//
512 // Printing
513 
514 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
515                                          Value operand, Value optOperand,
516                                          OperandRange varOperands) {
517   printer << operand;
518   if (optOperand)
519     printer << ", " << optOperand;
520   printer << " -> (" << varOperands << ")";
521 }
522 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
523                                         Type operandType, Type optOperandType,
524                                         TypeRange varOperandTypes) {
525   printer << " : " << operandType;
526   if (optOperandType)
527     printer << ", " << optOperandType;
528   printer << " -> (" << varOperandTypes << ")";
529 }
/// Prints the ` type_refs_capture ` keyword followed by the types, using the
/// same layout as `printCustomDirectiveResults`.
static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
                                             Operation *op, Type operandType,
                                             Type optOperandType,
                                             TypeRange varOperandTypes) {
  printer << " type_refs_capture ";
  printCustomDirectiveResults(printer, op, operandType, optOperandType,
                              varOperandTypes);
}
/// Prints the operands and then their types by chaining
/// `printCustomDirectiveOperands` and `printCustomDirectiveResults`.
static void printCustomDirectiveOperandsAndTypes(
    OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
    OperandRange varOperands, Type operandType, Type optOperandType,
    TypeRange varOperandTypes) {
  printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
  printCustomDirectiveResults(printer, op, operandType, optOperandType,
                              varOperandTypes);
}
546 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
547                                         Region &region,
548                                         MutableArrayRef<Region> varRegions) {
549   printer.printRegion(region);
550   if (!varRegions.empty()) {
551     printer << ", ";
552     for (Region &region : varRegions)
553       printer.printRegion(region);
554   }
555 }
556 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
557                                            Block *successor,
558                                            SuccessorRange varSuccessors) {
559   printer << successor;
560   if (!varSuccessors.empty())
561     printer << ", " << varSuccessors.front();
562 }
563 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
564                                            Attribute attribute,
565                                            Attribute optAttribute) {
566   printer << attribute;
567   if (optAttribute)
568     printer << ", " << optAttribute;
569 }
570 
/// Prints the entries of `attrs` as an optional attribute dictionary via
/// `OpAsmPrinter::printOptionalAttrDict`.
static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
                                         DictionaryAttr attrs) {
  printer.printOptionalAttrDict(attrs.getValue());
}
575 
576 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
577                                                    Operation *op,
578                                                    Value optOperand) {
579   printer << (optOperand ? "1" : "0");
580 }
581 
582 //===----------------------------------------------------------------------===//
583 // Test IsolatedRegionOp - parse passthrough region arguments.
584 //===----------------------------------------------------------------------===//
585 
586 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
587                                          OperationState &result) {
588   OpAsmParser::OperandType argInfo;
589   Type argType = parser.getBuilder().getIndexType();
590 
591   // Parse the input operand.
592   if (parser.parseOperand(argInfo) ||
593       parser.resolveOperand(argInfo, argType, result.operands))
594     return failure();
595 
596   // Parse the body region, and reuse the operand info as the argument info.
597   Region *body = result.addRegion();
598   return parser.parseRegion(*body, argInfo, argType,
599                             /*enableNameShadowing=*/true);
600 }
601 
602 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
603   p << "test.isolated_region ";
604   p.printOperand(op.getOperand());
605   p.shadowRegionArgs(op.getRegion(), op.getOperand());
606   p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
607 }
608 
609 //===----------------------------------------------------------------------===//
610 // Test SSACFGRegionOp
611 //===----------------------------------------------------------------------===//
612 
/// Reports every region of this op as an SSACFG region, regardless of
/// `index`.
RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
  return RegionKind::SSACFG;
}
616 
617 //===----------------------------------------------------------------------===//
618 // Test GraphRegionOp
619 //===----------------------------------------------------------------------===//
620 
621 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
622                                       OperationState &result) {
623   // Parse the body region, and reuse the operand info as the argument info.
624   Region *body = result.addRegion();
625   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
626 }
627 
628 static void print(OpAsmPrinter &p, GraphRegionOp op) {
629   p << "test.graph_region ";
630   p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
631 }
632 
/// Reports every region of this op as a graph region, regardless of `index`.
RegionKind GraphRegionOp::getRegionKind(unsigned index) {
  return RegionKind::Graph;
}
636 
637 //===----------------------------------------------------------------------===//
638 // Test AffineScopeOp
639 //===----------------------------------------------------------------------===//
640 
641 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
642                                       OperationState &result) {
643   // Parse the body region, and reuse the operand info as the argument info.
644   Region *body = result.addRegion();
645   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
646 }
647 
648 static void print(OpAsmPrinter &p, AffineScopeOp op) {
649   p << "test.affine_scope ";
650   p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
651 }
652 
653 //===----------------------------------------------------------------------===//
654 // Test parser.
655 //===----------------------------------------------------------------------===//
656 
657 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
658                                               OperationState &result) {
659   if (parser.parseOptionalColon())
660     return success();
661   uint64_t numResults;
662   if (parser.parseInteger(numResults))
663     return failure();
664 
665   IndexType type = parser.getBuilder().getIndexType();
666   for (unsigned i = 0; i < numResults; ++i)
667     result.addTypes(type);
668   return success();
669 }
670 
671 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
672   if (unsigned numResults = op->getNumResults())
673     p << " : " << numResults;
674 }
675 
676 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
677                                               OperationState &result) {
678   StringRef keyword;
679   if (parser.parseKeyword(&keyword))
680     return failure();
681   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
682   return success();
683 }
684 
/// Prints a space followed by the op's keyword.
static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
  p << " " << op.getKeyword();
}
688 
689 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
691 
692 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
693                                          OperationState &result) {
694   if (parser.parseKeyword("wraps"))
695     return failure();
696 
697   // Parse the wrapped op in a region
698   Region &body = *result.addRegion();
699   body.push_back(new Block);
700   Block &block = body.back();
701   Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
702   if (!wrappedOp)
703     return failure();
704 
705   // Create a return terminator in the inner region, pass as operand to the
706   // terminator the returned values from the wrapped operation.
707   SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
708   OpBuilder builder(parser.getContext());
709   builder.setInsertionPointToEnd(&block);
710   builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);
711 
712   // Get the results type for the wrapping op from the terminator operands.
713   Operation &returnOp = body.back().back();
714   result.types.append(returnOp.operand_type_begin(),
715                       returnOp.operand_type_end());
716 
717   // Use the location of the wrapped op for the "test.wrapping_region" op.
718   result.location = wrappedOp->getLoc();
719 
720   return success();
721 }
722 
/// Prints ` wraps ` followed by the generic form of the first operation in
/// the first block of the op's region.
static void print(OpAsmPrinter &p, WrappingRegionOp op) {
  p << " wraps ";
  p.printGenericOp(&op.getRegion().front().front());
}
727 
728 //===----------------------------------------------------------------------===//
729 // Test PrettyPrintedRegionOp -  exercising the following parser APIs
730 //   parseGenericOperationAfterOpName
731 //   parseCustomOperationName
732 //===----------------------------------------------------------------------===//
733 
734 static ParseResult parsePrettyPrintedRegionOp(OpAsmParser &parser,
735                                               OperationState &result) {
736 
737   llvm::SMLoc loc = parser.getCurrentLocation();
738   Location currLocation = parser.getEncodedSourceLoc(loc);
739 
740   // Parse the operands.
741   SmallVector<OpAsmParser::OperandType, 2> operands;
742   if (parser.parseOperandList(operands))
743     return failure();
744 
745   // Check if we are parsing the pretty-printed version
746   //  test.pretty_printed_region start <inner-op> end : <functional-type>
747   // Else fallback to parsing the "non pretty-printed" version.
748   if (!succeeded(parser.parseOptionalKeyword("start")))
749     return parser.parseGenericOperationAfterOpName(
750         result, llvm::makeArrayRef(operands));
751 
752   FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName();
753   if (failed(parseOpNameInfo))
754     return failure();
755 
756   StringRef innerOpName = parseOpNameInfo->getStringRef();
757 
758   FunctionType opFntype;
759   Optional<Location> explicitLoc;
760   if (parser.parseKeyword("end") || parser.parseColon() ||
761       parser.parseType(opFntype) ||
762       parser.parseOptionalLocationSpecifier(explicitLoc))
763     return failure();
764 
765   // If location of the op is explicitly provided, then use it; Else use
766   // the parser's current location.
767   Location opLoc = explicitLoc.getValueOr(currLocation);
768 
769   // Derive the SSA-values for op's operands.
770   if (parser.resolveOperands(operands, opFntype.getInputs(), loc,
771                              result.operands))
772     return failure();
773 
774   // Add a region for op.
775   Region &region = *result.addRegion();
776 
777   // Create a basic-block inside op's region.
778   Block &block = region.emplaceBlock();
779 
780   // Create and insert an "inner-op" operation in the block.
781   // Just for testing purposes, we can assume that inner op is a binary op with
782   // result and operand types all same as the test-op's first operand.
783   Type innerOpType = opFntype.getInput(0);
784   Value lhs = block.addArgument(innerOpType, opLoc);
785   Value rhs = block.addArgument(innerOpType, opLoc);
786 
787   OpBuilder builder(parser.getBuilder().getContext());
788   builder.setInsertionPointToStart(&block);
789 
790   OperationState innerOpState(opLoc, innerOpName);
791   innerOpState.operands.push_back(lhs);
792   innerOpState.operands.push_back(rhs);
793   innerOpState.addTypes(innerOpType);
794 
795   Operation *innerOp = builder.createOperation(innerOpState);
796 
797   // Insert a return statement in the block returning the inner-op's result.
798   builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());
799 
800   // Populate the op operation-state with result-type and location.
801   result.addTypes(opFntype.getResults());
802   result.location = innerOp->getLoc();
803 
804   return success();
805 }
806 
807 static void print(OpAsmPrinter &p, PrettyPrintedRegionOp op) {
808   p << ' ';
809   p.printOperands(op.getOperands());
810 
811   Operation &innerOp = op.getRegion().front().front();
812   // Assuming that region has a single non-terminator inner-op, if the inner-op
813   // meets some criteria (which in this case is a simple one  based on the name
814   // of inner-op), then we can print the entire region in a succinct way.
815   // Here we assume that the prototype of "special.op" can be trivially derived
816   // while parsing it back.
817   if (innerOp.getName().getStringRef().equals("special.op")) {
818     p << " start special.op end";
819   } else {
820     p << " (";
821     p.printRegion(op.getRegion());
822     p << ")";
823   }
824 
825   p << " : ";
826   p.printFunctionalType(op);
827 }
828 
829 //===----------------------------------------------------------------------===//
830 // Test PolyForOp - parse list of region arguments.
831 //===----------------------------------------------------------------------===//
832 
833 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
834   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
835   // Parse list of region arguments without a delimiter.
836   if (parser.parseRegionArgumentList(ivsInfo))
837     return failure();
838 
839   // Parse the body region.
840   Region *body = result.addRegion();
841   auto &builder = parser.getBuilder();
842   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
843   return parser.parseRegion(*body, ivsInfo, argTypes);
844 }
845 
846 void PolyForOp::getAsmBlockArgumentNames(Region &region,
847                                          OpAsmSetValueNameFn setNameFn) {
848   auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
849   if (!arrayAttr)
850     return;
851   auto args = getRegion().front().getArguments();
852   auto e = std::min(arrayAttr.size(), args.size());
853   for (unsigned i = 0; i < e; ++i) {
854     if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
855       setNameFn(args[i], strAttr.getValue());
856   }
857 }
858 
859 //===----------------------------------------------------------------------===//
860 // Test removing op with inner ops.
861 //===----------------------------------------------------------------------===//
862 
863 namespace {
864 struct TestRemoveOpWithInnerOps
865     : public OpRewritePattern<TestOpWithRegionPattern> {
866   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
867 
868   void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
869 
870   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
871                                 PatternRewriter &rewriter) const override {
872     rewriter.eraseOp(op);
873     return success();
874   }
875 };
876 } // namespace
877 
878 void TestOpWithRegionPattern::getCanonicalizationPatterns(
879     RewritePatternSet &results, MLIRContext *context) {
880   results.add<TestRemoveOpWithInnerOps>(context);
881 }
882 
883 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
884   return getOperand();
885 }
886 
887 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
888   return getValue();
889 }
890 
891 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
892     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
893   for (Value input : this->getOperands()) {
894     results.push_back(input);
895   }
896   return success();
897 }
898 
899 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
900   assert(operands.size() == 1);
901   if (operands.front()) {
902     (*this)->setAttr("attr", operands.front());
903     return getResult();
904   }
905   return {};
906 }
907 
908 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
909   return getOperand();
910 }
911 
912 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
913     MLIRContext *, Optional<Location> location, ValueRange operands,
914     DictionaryAttr attributes, RegionRange regions,
915     SmallVectorImpl<Type> &inferredReturnTypes) {
916   if (operands[0].getType() != operands[1].getType()) {
917     return emitOptionalError(location, "operand type mismatch ",
918                              operands[0].getType(), " vs ",
919                              operands[1].getType());
920   }
921   inferredReturnTypes.assign({operands[0].getType()});
922   return success();
923 }
924 
925 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
926     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
927     DictionaryAttr attributes, RegionRange regions,
928     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
929   // Create return type consisting of the last element of the first operand.
930   auto operandType = operands.front().getType();
931   auto sval = operandType.dyn_cast<ShapedType>();
932   if (!sval) {
933     return emitOptionalError(location, "only shaped type operands allowed");
934   }
935   int64_t dim =
936       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
937   auto type = IntegerType::get(context, 17);
938   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
939   return success();
940 }
941 
942 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
943     OpBuilder &builder, ValueRange operands,
944     llvm::SmallVectorImpl<Value> &shapes) {
945   shapes = SmallVector<Value, 1>{
946       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
947   return success();
948 }
949 
950 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
951     OpBuilder &builder, ValueRange operands,
952     llvm::SmallVectorImpl<Value> &shapes) {
953   Location loc = getLoc();
954   shapes.reserve(operands.size());
955   for (Value operand : llvm::reverse(operands)) {
956     auto rank = operand.getType().cast<RankedTensorType>().getRank();
957     auto currShape = llvm::to_vector<4>(
958         llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
959           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
960         }));
961     shapes.push_back(builder.create<tensor::FromElementsOp>(
962         getLoc(), RankedTensorType::get({rank}, builder.getIndexType()),
963         currShape));
964   }
965   return success();
966 }
967 
968 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
969     OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
970   Location loc = getLoc();
971   shapes.reserve(getNumOperands());
972   for (Value operand : llvm::reverse(getOperands())) {
973     auto currShape = llvm::to_vector<4>(llvm::map_range(
974         llvm::seq<int64_t>(
975             0, operand.getType().cast<RankedTensorType>().getRank()),
976         [&](int64_t dim) -> Value {
977           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
978         }));
979     shapes.emplace_back(std::move(currShape));
980   }
981   return success();
982 }
983 
984 //===----------------------------------------------------------------------===//
985 // Test SideEffect interfaces
986 //===----------------------------------------------------------------------===//
987 
988 namespace {
989 /// A test resource for side effects.
990 struct TestResource : public SideEffects::Resource::Base<TestResource> {
991   StringRef getName() final { return "<Test>"; }
992 };
993 } // namespace
994 
995 static void testSideEffectOpGetEffect(
996     Operation *op,
997     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
998         &effects) {
999   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
1000   if (!effectsAttr)
1001     return;
1002 
1003   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
1004 }
1005 
1006 void SideEffectOp::getEffects(
1007     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1008   // Check for an effects attribute on the op instance.
1009   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
1010   if (!effectsAttr)
1011     return;
1012 
1013   // If there is one, it is an array of dictionary attributes that hold
1014   // information on the effects of this operation.
1015   for (Attribute element : effectsAttr) {
1016     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
1017 
1018     // Get the specific memory effect.
1019     MemoryEffects::Effect *effect =
1020         StringSwitch<MemoryEffects::Effect *>(
1021             effectElement.get("effect").cast<StringAttr>().getValue())
1022             .Case("allocate", MemoryEffects::Allocate::get())
1023             .Case("free", MemoryEffects::Free::get())
1024             .Case("read", MemoryEffects::Read::get())
1025             .Case("write", MemoryEffects::Write::get());
1026 
1027     // Check for a non-default resource to use.
1028     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
1029     if (effectElement.get("test_resource"))
1030       resource = TestResource::get();
1031 
1032     // Check for a result to affect.
1033     if (effectElement.get("on_result"))
1034       effects.emplace_back(effect, getResult(), resource);
1035     else if (Attribute ref = effectElement.get("on_reference"))
1036       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
1037     else
1038       effects.emplace_back(effect, resource);
1039   }
1040 }
1041 
1042 void SideEffectOp::getEffects(
1043     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
1044   testSideEffectOpGetEffect(getOperation(), effects);
1045 }
1046 
1047 //===----------------------------------------------------------------------===//
1048 // StringAttrPrettyNameOp
1049 //===----------------------------------------------------------------------===//
1050 
1051 // This op has fancy handling of its SSA result name.
1052 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
1053                                                OperationState &result) {
1054   // Add the result types.
1055   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
1056     result.addTypes(parser.getBuilder().getIntegerType(32));
1057 
1058   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1059     return failure();
1060 
1061   // If the attribute dictionary contains no 'names' attribute, infer it from
1062   // the SSA name (if specified).
1063   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
1064     return attr.getName() == "names";
1065   });
1066 
1067   // If there was no name specified, check to see if there was a useful name
1068   // specified in the asm file.
1069   if (hadNames || parser.getNumResults() == 0)
1070     return success();
1071 
1072   SmallVector<StringRef, 4> names;
1073   auto *context = result.getContext();
1074 
1075   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
1076     auto resultName = parser.getResultName(i);
1077     StringRef nameStr;
1078     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
1079       nameStr = resultName.first;
1080 
1081     names.push_back(nameStr);
1082   }
1083 
1084   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
1085   result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
1086   return success();
1087 }
1088 
1089 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
1090   // Note that we only need to print the "name" attribute if the asmprinter
1091   // result name disagrees with it.  This can happen in strange cases, e.g.
1092   // when there are conflicts.
1093   bool namesDisagree = op.getNames().size() != op.getNumResults();
1094 
1095   SmallString<32> resultNameStr;
1096   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
1097     resultNameStr.clear();
1098     llvm::raw_svector_ostream tmpStream(resultNameStr);
1099     p.printOperand(op.getResult(i), tmpStream);
1100 
1101     auto expectedName = op.getNames()[i].dyn_cast<StringAttr>();
1102     if (!expectedName ||
1103         tmpStream.str().drop_front() != expectedName.getValue()) {
1104       namesDisagree = true;
1105     }
1106   }
1107 
1108   if (namesDisagree)
1109     p.printOptionalAttrDictWithKeyword(op->getAttrs());
1110   else
1111     p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
1112 }
1113 
1114 // We set the SSA name in the asm syntax to the contents of the name
1115 // attribute.
1116 void StringAttrPrettyNameOp::getAsmResultNames(
1117     function_ref<void(Value, StringRef)> setNameFn) {
1118 
1119   auto value = getNames();
1120   for (size_t i = 0, e = value.size(); i != e; ++i)
1121     if (auto str = value[i].dyn_cast<StringAttr>())
1122       if (!str.getValue().empty())
1123         setNameFn(getResult(i), str.getValue());
1124 }
1125 
1126 //===----------------------------------------------------------------------===//
1127 // RegionIfOp
1128 //===----------------------------------------------------------------------===//
1129 
1130 static void print(OpAsmPrinter &p, RegionIfOp op) {
1131   p << " ";
1132   p.printOperands(op.getOperands());
1133   p << ": " << op.getOperandTypes();
1134   p.printArrowTypeList(op.getResultTypes());
1135   p << " then";
1136   p.printRegion(op.getThenRegion(),
1137                 /*printEntryBlockArgs=*/true,
1138                 /*printBlockTerminators=*/true);
1139   p << " else";
1140   p.printRegion(op.getElseRegion(),
1141                 /*printEntryBlockArgs=*/true,
1142                 /*printBlockTerminators=*/true);
1143   p << " join";
1144   p.printRegion(op.getJoinRegion(),
1145                 /*printEntryBlockArgs=*/true,
1146                 /*printBlockTerminators=*/true);
1147 }
1148 
1149 static ParseResult parseRegionIfOp(OpAsmParser &parser,
1150                                    OperationState &result) {
1151   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
1152   SmallVector<Type, 2> operandTypes;
1153 
1154   result.regions.reserve(3);
1155   Region *thenRegion = result.addRegion();
1156   Region *elseRegion = result.addRegion();
1157   Region *joinRegion = result.addRegion();
1158 
1159   // Parse operand, type and arrow type lists.
1160   if (parser.parseOperandList(operandInfos) ||
1161       parser.parseColonTypeList(operandTypes) ||
1162       parser.parseArrowTypeList(result.types))
1163     return failure();
1164 
1165   // Parse all attached regions.
1166   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1167       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1168       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1169     return failure();
1170 
1171   return parser.resolveOperands(operandInfos, operandTypes,
1172                                 parser.getCurrentLocation(), result.operands);
1173 }
1174 
1175 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
1176   assert(index < 2 && "invalid region index");
1177   return getOperands();
1178 }
1179 
1180 void RegionIfOp::getSuccessorRegions(
1181     Optional<unsigned> index, ArrayRef<Attribute> operands,
1182     SmallVectorImpl<RegionSuccessor> &regions) {
1183   // We always branch to the join region.
1184   if (index.hasValue()) {
1185     if (index.getValue() < 2)
1186       regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
1187     else
1188       regions.push_back(RegionSuccessor(getResults()));
1189     return;
1190   }
1191 
1192   // The then and else regions are the entry regions of this op.
1193   regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
1194   regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
1195 }
1196 
1197 //===----------------------------------------------------------------------===//
1198 // SingleNoTerminatorCustomAsmOp
1199 //===----------------------------------------------------------------------===//
1200 
1201 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser,
1202                                                       OperationState &state) {
1203   Region *body = state.addRegion();
1204   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1205     return failure();
1206   return success();
1207 }
1208 
1209 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) {
1210   printer.printRegion(
1211       op.getRegion(), /*printEntryBlockArgs=*/false,
1212       // This op has a single block without terminators. But explicitly mark
1213       // as not printing block terminators for testing.
1214       /*printBlockTerminators=*/false);
1215 }
1216 
1217 #include "TestOpEnums.cpp.inc"
1218 #include "TestOpInterfaces.cpp.inc"
1219 #include "TestOpStructs.cpp.inc"
1220 #include "TestTypeInterfaces.cpp.inc"
1221 
1222 #define GET_OP_CLASSES
1223 #include "TestOps.cpp.inc"
1224