1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/DLTI/DLTI.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Reducer/ReductionPatternInterface.h"
21 #include "mlir/Transforms/FoldUtils.h"
22 #include "mlir/Transforms/InliningUtils.h"
23 #include "llvm/ADT/StringSwitch.h"
24 
25 using namespace mlir;
26 using namespace mlir::test;
27 
28 #include "TestOpsDialect.cpp.inc"
29 
30 void mlir::test::registerTestDialect(DialectRegistry &registry) {
31   registry.insert<TestDialect>();
32 }
33 
34 //===----------------------------------------------------------------------===//
35 // TestDialect Interfaces
36 //===----------------------------------------------------------------------===//
37 
38 namespace {
39 
40 /// Testing the correctness of some traits.
41 static_assert(
42     llvm::is_detected<OpTrait::has_implicit_terminator_t,
43                       SingleBlockImplicitTerminatorOp>::value,
44     "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
45 static_assert(OpTrait::hasSingleBlockImplicitTerminator<
46                   SingleBlockImplicitTerminatorOp>::value,
47               "hasSingleBlockImplicitTerminator does not match "
48               "SingleBlockImplicitTerminatorOp");
49 
50 // Test support for interacting with the AsmPrinter.
51 struct TestOpAsmInterface : public OpAsmDialectInterface {
52   using OpAsmDialectInterface::OpAsmDialectInterface;
53 
54   LogicalResult getAlias(Attribute attr, raw_ostream &os) const final {
55     StringAttr strAttr = attr.dyn_cast<StringAttr>();
56     if (!strAttr)
57       return failure();
58 
59     // Check the contents of the string attribute to see what the test alias
60     // should be named.
61     Optional<StringRef> aliasName =
62         StringSwitch<Optional<StringRef>>(strAttr.getValue())
63             .Case("alias_test:dot_in_name", StringRef("test.alias"))
64             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
65             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
66             .Case("alias_test:sanitize_conflict_a",
67                   StringRef("test_alias_conflict0"))
68             .Case("alias_test:sanitize_conflict_b",
69                   StringRef("test_alias_conflict0_"))
70             .Default(llvm::None);
71     if (!aliasName)
72       return failure();
73 
74     os << *aliasName;
75     return success();
76   }
77 
78   void getAsmResultNames(Operation *op,
79                          OpAsmSetValueNameFn setNameFn) const final {
80     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
81       setNameFn(asmOp, "result");
82   }
83 
84   void getAsmBlockArgumentNames(Block *block,
85                                 OpAsmSetValueNameFn setNameFn) const final {
86     auto op = block->getParentOp();
87     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
88     if (!arrayAttr)
89       return;
90     auto args = block->getArguments();
91     auto e = std::min(arrayAttr.size(), args.size());
92     for (unsigned i = 0; i < e; ++i) {
93       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
94         setNameFn(args[i], strAttr.getValue());
95     }
96   }
97 };
98 
99 struct TestDialectFoldInterface : public DialectFoldInterface {
100   using DialectFoldInterface::DialectFoldInterface;
101 
102   /// Registered hook to check if the given region, which is attached to an
103   /// operation that is *not* isolated from above, should be used when
104   /// materializing constants.
105   bool shouldMaterializeInto(Region *region) const final {
106     // If this is a one region operation, then insert into it.
107     return isa<OneRegionOp>(region->getParentOp());
108   }
109 };
110 
111 /// This class defines the interface for handling inlining with standard
112 /// operations.
113 struct TestInlinerInterface : public DialectInlinerInterface {
114   using DialectInlinerInterface::DialectInlinerInterface;
115 
116   //===--------------------------------------------------------------------===//
117   // Analysis Hooks
118   //===--------------------------------------------------------------------===//
119 
120   bool isLegalToInline(Operation *call, Operation *callable,
121                        bool wouldBeCloned) const final {
122     // Don't allow inlining calls that are marked `noinline`.
123     return !call->hasAttr("noinline");
124   }
125   bool isLegalToInline(Region *, Region *, bool,
126                        BlockAndValueMapping &) const final {
127     // Inlining into test dialect regions is legal.
128     return true;
129   }
130   bool isLegalToInline(Operation *, Region *, bool,
131                        BlockAndValueMapping &) const final {
132     return true;
133   }
134 
135   bool shouldAnalyzeRecursively(Operation *op) const final {
136     // Analyze recursively if this is not a functional region operation, it
137     // froms a separate functional scope.
138     return !isa<FunctionalRegionOp>(op);
139   }
140 
141   //===--------------------------------------------------------------------===//
142   // Transformation Hooks
143   //===--------------------------------------------------------------------===//
144 
145   /// Handle the given inlined terminator by replacing it with a new operation
146   /// as necessary.
147   void handleTerminator(Operation *op,
148                         ArrayRef<Value> valuesToRepl) const final {
149     // Only handle "test.return" here.
150     auto returnOp = dyn_cast<TestReturnOp>(op);
151     if (!returnOp)
152       return;
153 
154     // Replace the values directly with the return operands.
155     assert(returnOp.getNumOperands() == valuesToRepl.size());
156     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
157       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
158   }
159 
160   /// Attempt to materialize a conversion for a type mismatch between a call
161   /// from this dialect, and a callable region. This method should generate an
162   /// operation that takes 'input' as the only operand, and produces a single
163   /// result of 'resultType'. If a conversion can not be generated, nullptr
164   /// should be returned.
165   Operation *materializeCallConversion(OpBuilder &builder, Value input,
166                                        Type resultType,
167                                        Location conversionLoc) const final {
168     // Only allow conversion for i16/i32 types.
169     if (!(resultType.isSignlessInteger(16) ||
170           resultType.isSignlessInteger(32)) ||
171         !(input.getType().isSignlessInteger(16) ||
172           input.getType().isSignlessInteger(32)))
173       return nullptr;
174     return builder.create<TestCastOp>(conversionLoc, resultType, input);
175   }
176 
177   void processInlinedCallBlocks(
178       Operation *call,
179       iterator_range<Region::iterator> inlinedBlocks) const final {
180     if (!isa<ConversionCallOp>(call))
181       return;
182 
183     // Set attributed on all ops in the inlined blocks.
184     for (Block &block : inlinedBlocks) {
185       block.walk([&](Operation *op) {
186         op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
187       });
188     }
189   }
190 };
191 
192 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
193 public:
194   TestReductionPatternInterface(Dialect *dialect)
195       : DialectReductionPatternInterface(dialect) {}
196 
197   virtual void
198   populateReductionPatterns(RewritePatternSet &patterns) const final {
199     populateTestReductionPatterns(patterns);
200   }
201 };
202 
203 } // end anonymous namespace
204 
205 //===----------------------------------------------------------------------===//
206 // TestDialect
207 //===----------------------------------------------------------------------===//
208 
209 static void testSideEffectOpGetEffect(
210     Operation *op,
211     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
212 
213 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
214 struct TestOpEffectInterfaceFallback
215     : public TestEffectOpInterface::FallbackModel<
216           TestOpEffectInterfaceFallback> {
217   static bool classof(Operation *op) {
218     bool isSupportedOp =
219         op->getName().getStringRef() == "test.unregistered_side_effect_op";
220     assert(isSupportedOp && "Unexpected dispatch");
221     return isSupportedOp;
222   }
223 
224   void
225   getEffects(Operation *op,
226              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
227                  &effects) const {
228     testSideEffectOpGetEffect(op, effects);
229   }
230 };
231 
/// Sets up the test dialect. It registers the dialect's attributes, types,
/// operations (listed in TestOps.cpp.inc) and dialect interfaces, permits
/// unknown operations, and allocates the fallback `TestEffectOpInterface`
/// model.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  // The operation list is expanded from the GET_OP_LIST section of
  // TestOps.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface, TestReductionPatternInterface>();
  // Presumably what lets unregistered ops such as
  // "test.unregistered_side_effect_op" appear in test IR — confirm.
  allowUnknownOperations();

  // Instantiate our fallback op interface that we'll use on specific
  // unregistered op. This raw allocation is released in ~TestDialect().
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
247 TestDialect::~TestDialect() {
248   delete static_cast<TestOpEffectInterfaceFallback *>(
249       fallbackEffectOpInterfaces);
250 }
251 
252 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
253                                             Type type, Location loc) {
254   return builder.create<TestOpConstant>(loc, type, value);
255 }
256 
257 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
258                                                OperationName opName) {
259   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
260       typeID == TypeID::get<TestEffectOpInterface>())
261     return fallbackEffectOpInterfaces;
262   return nullptr;
263 }
264 
265 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
266                                                     NamedAttribute namedAttr) {
267   if (namedAttr.first == "test.invalid_attr")
268     return op->emitError() << "invalid to use 'test.invalid_attr'";
269   return success();
270 }
271 
272 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
273                                                     unsigned regionIndex,
274                                                     unsigned argIndex,
275                                                     NamedAttribute namedAttr) {
276   if (namedAttr.first == "test.invalid_attr")
277     return op->emitError() << "invalid to use 'test.invalid_attr'";
278   return success();
279 }
280 
281 LogicalResult
282 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
283                                          unsigned resultIndex,
284                                          NamedAttribute namedAttr) {
285   if (namedAttr.first == "test.invalid_attr")
286     return op->emitError() << "invalid to use 'test.invalid_attr'";
287   return success();
288 }
289 
290 Optional<Dialect::ParseOpHook>
291 TestDialect::getParseOperationHook(StringRef opName) const {
292   if (opName == "test.dialect_custom_printer") {
293     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
294       return parser.parseKeyword("custom_format");
295     }};
296   }
297   return None;
298 }
299 
300 LogicalResult TestDialect::printOperation(Operation *op,
301                                           OpAsmPrinter &printer) const {
302   StringRef opName = op->getName().getStringRef();
303   if (opName == "test.dialect_custom_printer") {
304     printer.getStream() << opName << " custom_format";
305     return success();
306   }
307   return failure();
308 }
309 
310 //===----------------------------------------------------------------------===//
311 // TestBranchOp
312 //===----------------------------------------------------------------------===//
313 
314 Optional<MutableOperandRange>
315 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
316   assert(index == 0 && "invalid successor index");
317   return targetOperandsMutable();
318 }
319 
320 //===----------------------------------------------------------------------===//
321 // TestDialectCanonicalizerOp
322 //===----------------------------------------------------------------------===//
323 
324 static LogicalResult
325 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
326                                PatternRewriter &rewriter) {
327   rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(),
328                                           rewriter.getI32IntegerAttr(42));
329   return success();
330 }
331 
332 void TestDialect::getCanonicalizationPatterns(
333     RewritePatternSet &results) const {
334   results.add(&dialectCanonicalizationPattern);
335 }
336 
337 //===----------------------------------------------------------------------===//
338 // TestFoldToCallOp
339 //===----------------------------------------------------------------------===//
340 
341 namespace {
342 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
343   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
344 
345   LogicalResult matchAndRewrite(FoldToCallOp op,
346                                 PatternRewriter &rewriter) const override {
347     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
348                                         ValueRange());
349     return success();
350   }
351 };
352 } // end anonymous namespace
353 
354 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
355                                                MLIRContext *context) {
356   results.add<FoldToCallOpPattern>(context);
357 }
358 
359 //===----------------------------------------------------------------------===//
360 // Test Format* operations
361 //===----------------------------------------------------------------------===//
362 
363 //===----------------------------------------------------------------------===//
364 // Parsing
365 
366 static ParseResult parseCustomDirectiveOperands(
367     OpAsmParser &parser, OpAsmParser::OperandType &operand,
368     Optional<OpAsmParser::OperandType> &optOperand,
369     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
370   if (parser.parseOperand(operand))
371     return failure();
372   if (succeeded(parser.parseOptionalComma())) {
373     optOperand.emplace();
374     if (parser.parseOperand(*optOperand))
375       return failure();
376   }
377   if (parser.parseArrow() || parser.parseLParen() ||
378       parser.parseOperandList(varOperands) || parser.parseRParen())
379     return failure();
380   return success();
381 }
382 static ParseResult
383 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
384                             Type &optOperandType,
385                             SmallVectorImpl<Type> &varOperandTypes) {
386   if (parser.parseColon())
387     return failure();
388 
389   if (parser.parseType(operandType))
390     return failure();
391   if (succeeded(parser.parseOptionalComma())) {
392     if (parser.parseType(optOperandType))
393       return failure();
394   }
395   if (parser.parseArrow() || parser.parseLParen() ||
396       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
397     return failure();
398   return success();
399 }
400 static ParseResult
401 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
402                                  Type optOperandType,
403                                  const SmallVectorImpl<Type> &varOperandTypes) {
404   if (parser.parseKeyword("type_refs_capture"))
405     return failure();
406 
407   Type operandType2, optOperandType2;
408   SmallVector<Type, 1> varOperandTypes2;
409   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
410                                   varOperandTypes2))
411     return failure();
412 
413   if (operandType != operandType2 || optOperandType != optOperandType2 ||
414       varOperandTypes != varOperandTypes2)
415     return failure();
416 
417   return success();
418 }
419 static ParseResult parseCustomDirectiveOperandsAndTypes(
420     OpAsmParser &parser, OpAsmParser::OperandType &operand,
421     Optional<OpAsmParser::OperandType> &optOperand,
422     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
423     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
424   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
425       parseCustomDirectiveResults(parser, operandType, optOperandType,
426                                   varOperandTypes))
427     return failure();
428   return success();
429 }
430 static ParseResult parseCustomDirectiveRegions(
431     OpAsmParser &parser, Region &region,
432     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
433   if (parser.parseRegion(region))
434     return failure();
435   if (failed(parser.parseOptionalComma()))
436     return success();
437   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
438   if (parser.parseRegion(*varRegion))
439     return failure();
440   varRegions.emplace_back(std::move(varRegion));
441   return success();
442 }
443 static ParseResult
444 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
445                                SmallVectorImpl<Block *> &varSuccessors) {
446   if (parser.parseSuccessor(successor))
447     return failure();
448   if (failed(parser.parseOptionalComma()))
449     return success();
450   Block *varSuccessor;
451   if (parser.parseSuccessor(varSuccessor))
452     return failure();
453   varSuccessors.append(2, varSuccessor);
454   return success();
455 }
456 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
457                                                   IntegerAttr &attr,
458                                                   IntegerAttr &optAttr) {
459   if (parser.parseAttribute(attr))
460     return failure();
461   if (succeeded(parser.parseOptionalComma())) {
462     if (parser.parseAttribute(optAttr))
463       return failure();
464   }
465   return success();
466 }
467 
468 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
469                                                 NamedAttrList &attrs) {
470   return parser.parseOptionalAttrDict(attrs);
471 }
472 static ParseResult parseCustomDirectiveOptionalOperandRef(
473     OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
474   int64_t operandCount = 0;
475   if (parser.parseInteger(operandCount))
476     return failure();
477   bool expectedOptionalOperand = operandCount == 0;
478   return success(expectedOptionalOperand != optOperand.hasValue());
479 }
480 
481 //===----------------------------------------------------------------------===//
482 // Printing
483 
484 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
485                                          Value operand, Value optOperand,
486                                          OperandRange varOperands) {
487   printer << operand;
488   if (optOperand)
489     printer << ", " << optOperand;
490   printer << " -> (" << varOperands << ")";
491 }
492 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
493                                         Type operandType, Type optOperandType,
494                                         TypeRange varOperandTypes) {
495   printer << " : " << operandType;
496   if (optOperandType)
497     printer << ", " << optOperandType;
498   printer << " -> (" << varOperandTypes << ")";
499 }
500 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
501                                              Operation *op, Type operandType,
502                                              Type optOperandType,
503                                              TypeRange varOperandTypes) {
504   printer << " type_refs_capture ";
505   printCustomDirectiveResults(printer, op, operandType, optOperandType,
506                               varOperandTypes);
507 }
508 static void printCustomDirectiveOperandsAndTypes(
509     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
510     OperandRange varOperands, Type operandType, Type optOperandType,
511     TypeRange varOperandTypes) {
512   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
513   printCustomDirectiveResults(printer, op, operandType, optOperandType,
514                               varOperandTypes);
515 }
516 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
517                                         Region &region,
518                                         MutableArrayRef<Region> varRegions) {
519   printer.printRegion(region);
520   if (!varRegions.empty()) {
521     printer << ", ";
522     for (Region &region : varRegions)
523       printer.printRegion(region);
524   }
525 }
526 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
527                                            Block *successor,
528                                            SuccessorRange varSuccessors) {
529   printer << successor;
530   if (!varSuccessors.empty())
531     printer << ", " << varSuccessors.front();
532 }
533 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
534                                            Attribute attribute,
535                                            Attribute optAttribute) {
536   printer << attribute;
537   if (optAttribute)
538     printer << ", " << optAttribute;
539 }
540 
541 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
542                                          DictionaryAttr attrs) {
543   printer.printOptionalAttrDict(attrs.getValue());
544 }
545 
546 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
547                                                    Operation *op,
548                                                    Value optOperand) {
549   printer << (optOperand ? "1" : "0");
550 }
551 
552 //===----------------------------------------------------------------------===//
553 // Test IsolatedRegionOp - parse passthrough region arguments.
554 //===----------------------------------------------------------------------===//
555 
556 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
557                                          OperationState &result) {
558   OpAsmParser::OperandType argInfo;
559   Type argType = parser.getBuilder().getIndexType();
560 
561   // Parse the input operand.
562   if (parser.parseOperand(argInfo) ||
563       parser.resolveOperand(argInfo, argType, result.operands))
564     return failure();
565 
566   // Parse the body region, and reuse the operand info as the argument info.
567   Region *body = result.addRegion();
568   return parser.parseRegion(*body, argInfo, argType,
569                             /*enableNameShadowing=*/true);
570 }
571 
572 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
573   p << "test.isolated_region ";
574   p.printOperand(op.getOperand());
575   p.shadowRegionArgs(op.region(), op.getOperand());
576   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
577 }
578 
579 //===----------------------------------------------------------------------===//
580 // Test SSACFGRegionOp
581 //===----------------------------------------------------------------------===//
582 
583 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
584   return RegionKind::SSACFG;
585 }
586 
587 //===----------------------------------------------------------------------===//
588 // Test GraphRegionOp
589 //===----------------------------------------------------------------------===//
590 
591 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
592                                       OperationState &result) {
593   // Parse the body region, and reuse the operand info as the argument info.
594   Region *body = result.addRegion();
595   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
596 }
597 
598 static void print(OpAsmPrinter &p, GraphRegionOp op) {
599   p << "test.graph_region ";
600   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
601 }
602 
603 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
604   return RegionKind::Graph;
605 }
606 
607 //===----------------------------------------------------------------------===//
608 // Test AffineScopeOp
609 //===----------------------------------------------------------------------===//
610 
611 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
612                                       OperationState &result) {
613   // Parse the body region, and reuse the operand info as the argument info.
614   Region *body = result.addRegion();
615   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
616 }
617 
618 static void print(OpAsmPrinter &p, AffineScopeOp op) {
619   p << "test.affine_scope ";
620   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
621 }
622 
623 //===----------------------------------------------------------------------===//
624 // Test parser.
625 //===----------------------------------------------------------------------===//
626 
627 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
628                                               OperationState &result) {
629   if (parser.parseOptionalColon())
630     return success();
631   uint64_t numResults;
632   if (parser.parseInteger(numResults))
633     return failure();
634 
635   IndexType type = parser.getBuilder().getIndexType();
636   for (unsigned i = 0; i < numResults; ++i)
637     result.addTypes(type);
638   return success();
639 }
640 
641 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
642   p << ParseIntegerLiteralOp::getOperationName();
643   if (unsigned numResults = op->getNumResults())
644     p << " : " << numResults;
645 }
646 
647 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
648                                               OperationState &result) {
649   StringRef keyword;
650   if (parser.parseKeyword(&keyword))
651     return failure();
652   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
653   return success();
654 }
655 
656 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
657   p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword();
658 }
659 
660 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
662 
663 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
664                                          OperationState &result) {
665   if (parser.parseKeyword("wraps"))
666     return failure();
667 
668   // Parse the wrapped op in a region
669   Region &body = *result.addRegion();
670   body.push_back(new Block);
671   Block &block = body.back();
672   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
673   if (!wrapped_op)
674     return failure();
675 
676   // Create a return terminator in the inner region, pass as operand to the
677   // terminator the returned values from the wrapped operation.
678   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
679   OpBuilder builder(parser.getBuilder().getContext());
680   builder.setInsertionPointToEnd(&block);
681   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
682 
683   // Get the results type for the wrapping op from the terminator operands.
684   Operation &return_op = body.back().back();
685   result.types.append(return_op.operand_type_begin(),
686                       return_op.operand_type_end());
687 
688   // Use the location of the wrapped op for the "test.wrapping_region" op.
689   result.location = wrapped_op->getLoc();
690 
691   return success();
692 }
693 
694 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
695   p << op.getOperationName() << " wraps ";
696   p.printGenericOp(&op.region().front().front());
697 }
698 
699 //===----------------------------------------------------------------------===//
700 // Test PolyForOp - parse list of region arguments.
701 //===----------------------------------------------------------------------===//
702 
703 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
704   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
705   // Parse list of region arguments without a delimiter.
706   if (parser.parseRegionArgumentList(ivsInfo))
707     return failure();
708 
709   // Parse the body region.
710   Region *body = result.addRegion();
711   auto &builder = parser.getBuilder();
712   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
713   return parser.parseRegion(*body, ivsInfo, argTypes);
714 }
715 
716 //===----------------------------------------------------------------------===//
717 // Test removing op with inner ops.
718 //===----------------------------------------------------------------------===//
719 
720 namespace {
721 struct TestRemoveOpWithInnerOps
722     : public OpRewritePattern<TestOpWithRegionPattern> {
723   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
724 
725   void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
726 
727   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
728                                 PatternRewriter &rewriter) const override {
729     rewriter.eraseOp(op);
730     return success();
731   }
732 };
733 } // end anonymous namespace
734 
735 void TestOpWithRegionPattern::getCanonicalizationPatterns(
736     RewritePatternSet &results, MLIRContext *context) {
737   results.add<TestRemoveOpWithInnerOps>(context);
738 }
739 
740 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
741   return operand();
742 }
743 
744 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
745   return getValue();
746 }
747 
748 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
749     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
750   for (Value input : this->operands()) {
751     results.push_back(input);
752   }
753   return success();
754 }
755 
756 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
757   assert(operands.size() == 1);
758   if (operands.front()) {
759     (*this)->setAttr("attr", operands.front());
760     return getResult();
761   }
762   return {};
763 }
764 
765 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
766   return getOperand();
767 }
768 
769 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
770     MLIRContext *, Optional<Location> location, ValueRange operands,
771     DictionaryAttr attributes, RegionRange regions,
772     SmallVectorImpl<Type> &inferredReturnTypes) {
773   if (operands[0].getType() != operands[1].getType()) {
774     return emitOptionalError(location, "operand type mismatch ",
775                              operands[0].getType(), " vs ",
776                              operands[1].getType());
777   }
778   inferredReturnTypes.assign({operands[0].getType()});
779   return success();
780 }
781 
782 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
783     MLIRContext *context, Optional<Location> location, ValueRange operands,
784     DictionaryAttr attributes, RegionRange regions,
785     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
786   // Create return type consisting of the last element of the first operand.
787   auto operandType = *operands.getTypes().begin();
788   auto sval = operandType.dyn_cast<ShapedType>();
789   if (!sval) {
790     return emitOptionalError(location, "only shaped type operands allowed");
791   }
792   int64_t dim =
793       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
794   auto type = IntegerType::get(context, 17);
795   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
796   return success();
797 }
798 
799 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
800     OpBuilder &builder, ValueRange operands,
801     llvm::SmallVectorImpl<Value> &shapes) {
802   shapes = SmallVector<Value, 1>{
803       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
804   return success();
805 }
806 
807 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
808     OpBuilder &builder, ValueRange operands,
809     llvm::SmallVectorImpl<Value> &shapes) {
810   Location loc = getLoc();
811   shapes.reserve(operands.size());
812   for (Value operand : llvm::reverse(operands)) {
813     auto currShape = llvm::to_vector<4>(llvm::map_range(
814         llvm::seq<int64_t>(
815             0, operand.getType().cast<RankedTensorType>().getRank()),
816         [&](int64_t dim) -> Value {
817           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
818         }));
819     shapes.push_back(builder.create<tensor::FromElementsOp>(
820         getLoc(), builder.getIndexType(), currShape));
821   }
822   return success();
823 }
824 
825 LogicalResult
826 OpWithResultShapePerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim(
827     OpBuilder &builder,
828     llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
829   Location loc = getLoc();
830   shapes.reserve(getNumOperands());
831   for (Value operand : llvm::reverse(getOperands())) {
832     auto currShape = llvm::to_vector<4>(llvm::map_range(
833         llvm::seq<int64_t>(
834             0, operand.getType().cast<RankedTensorType>().getRank()),
835         [&](int64_t dim) -> Value {
836           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
837         }));
838     shapes.emplace_back(std::move(currShape));
839   }
840   return success();
841 }
842 
843 LogicalResult OpWithResultShapeAndPerDimInterfaceOp::reifyReturnTypeShapes(
844     OpBuilder &builder, ValueRange operands,
845     llvm::SmallVectorImpl<Value> &shapes) {
846   Location loc = getLoc();
847   shapes.reserve(operands.size());
848   for (Value operand : llvm::reverse(operands)) {
849     auto currShape = llvm::to_vector<4>(llvm::map_range(
850         llvm::seq<int64_t>(
851             0, operand.getType().cast<RankedTensorType>().getRank()),
852         [&](int64_t dim) -> Value {
853           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
854         }));
855     shapes.push_back(builder.create<tensor::FromElementsOp>(
856         getLoc(), builder.getIndexType(), currShape));
857   }
858   return success();
859 }
860 
861 LogicalResult
862 OpWithResultShapeAndPerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim(
863     OpBuilder &builder,
864     llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
865   Location loc = getLoc();
866   shapes.reserve(getNumOperands());
867   for (Value operand : llvm::reverse(getOperands())) {
868     auto currShape = llvm::to_vector<4>(llvm::map_range(
869         llvm::seq<int64_t>(
870             0, operand.getType().cast<RankedTensorType>().getRank()),
871         [&](int64_t dim) -> Value {
872           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
873         }));
874     shapes.emplace_back(std::move(currShape));
875   }
876   return success();
877 }
878 
879 //===----------------------------------------------------------------------===//
880 // Test SideEffect interfaces
881 //===----------------------------------------------------------------------===//
882 
883 namespace {
884 /// A test resource for side effects.
885 struct TestResource : public SideEffects::Resource::Base<TestResource> {
886   StringRef getName() final { return "<Test>"; }
887 };
888 } // end anonymous namespace
889 
890 static void testSideEffectOpGetEffect(
891     Operation *op,
892     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
893         &effects) {
894   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
895   if (!effectsAttr)
896     return;
897 
898   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
899 }
900 
901 void SideEffectOp::getEffects(
902     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
903   // Check for an effects attribute on the op instance.
904   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
905   if (!effectsAttr)
906     return;
907 
908   // If there is one, it is an array of dictionary attributes that hold
909   // information on the effects of this operation.
910   for (Attribute element : effectsAttr) {
911     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
912 
913     // Get the specific memory effect.
914     MemoryEffects::Effect *effect =
915         StringSwitch<MemoryEffects::Effect *>(
916             effectElement.get("effect").cast<StringAttr>().getValue())
917             .Case("allocate", MemoryEffects::Allocate::get())
918             .Case("free", MemoryEffects::Free::get())
919             .Case("read", MemoryEffects::Read::get())
920             .Case("write", MemoryEffects::Write::get());
921 
922     // Check for a non-default resource to use.
923     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
924     if (effectElement.get("test_resource"))
925       resource = TestResource::get();
926 
927     // Check for a result to affect.
928     if (effectElement.get("on_result"))
929       effects.emplace_back(effect, getResult(), resource);
930     else if (Attribute ref = effectElement.get("on_reference"))
931       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
932     else
933       effects.emplace_back(effect, resource);
934   }
935 }
936 
937 void SideEffectOp::getEffects(
938     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
939   testSideEffectOpGetEffect(getOperation(), effects);
940 }
941 
942 //===----------------------------------------------------------------------===//
943 // StringAttrPrettyNameOp
944 //===----------------------------------------------------------------------===//
945 
946 // This op has fancy handling of its SSA result name.
947 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
948                                                OperationState &result) {
949   // Add the result types.
950   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
951     result.addTypes(parser.getBuilder().getIntegerType(32));
952 
953   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
954     return failure();
955 
956   // If the attribute dictionary contains no 'names' attribute, infer it from
957   // the SSA name (if specified).
958   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
959     return attr.first == "names";
960   });
961 
962   // If there was no name specified, check to see if there was a useful name
963   // specified in the asm file.
964   if (hadNames || parser.getNumResults() == 0)
965     return success();
966 
967   SmallVector<StringRef, 4> names;
968   auto *context = result.getContext();
969 
970   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
971     auto resultName = parser.getResultName(i);
972     StringRef nameStr;
973     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
974       nameStr = resultName.first;
975 
976     names.push_back(nameStr);
977   }
978 
979   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
980   result.attributes.push_back({Identifier::get("names", context), namesAttr});
981   return success();
982 }
983 
984 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
985   p << "test.string_attr_pretty_name";
986 
987   // Note that we only need to print the "name" attribute if the asmprinter
988   // result name disagrees with it.  This can happen in strange cases, e.g.
989   // when there are conflicts.
990   bool namesDisagree = op.names().size() != op.getNumResults();
991 
992   SmallString<32> resultNameStr;
993   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
994     resultNameStr.clear();
995     llvm::raw_svector_ostream tmpStream(resultNameStr);
996     p.printOperand(op.getResult(i), tmpStream);
997 
998     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
999     if (!expectedName ||
1000         tmpStream.str().drop_front() != expectedName.getValue()) {
1001       namesDisagree = true;
1002     }
1003   }
1004 
1005   if (namesDisagree)
1006     p.printOptionalAttrDictWithKeyword(op->getAttrs());
1007   else
1008     p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
1009 }
1010 
1011 // We set the SSA name in the asm syntax to the contents of the name
1012 // attribute.
1013 void StringAttrPrettyNameOp::getAsmResultNames(
1014     function_ref<void(Value, StringRef)> setNameFn) {
1015 
1016   auto value = names();
1017   for (size_t i = 0, e = value.size(); i != e; ++i)
1018     if (auto str = value[i].dyn_cast<StringAttr>())
1019       if (!str.getValue().empty())
1020         setNameFn(getResult(i), str.getValue());
1021 }
1022 
1023 //===----------------------------------------------------------------------===//
1024 // RegionIfOp
1025 //===----------------------------------------------------------------------===//
1026 
1027 static void print(OpAsmPrinter &p, RegionIfOp op) {
1028   p << RegionIfOp::getOperationName() << " ";
1029   p.printOperands(op.getOperands());
1030   p << ": " << op.getOperandTypes();
1031   p.printArrowTypeList(op.getResultTypes());
1032   p << " then";
1033   p.printRegion(op.thenRegion(),
1034                 /*printEntryBlockArgs=*/true,
1035                 /*printBlockTerminators=*/true);
1036   p << " else";
1037   p.printRegion(op.elseRegion(),
1038                 /*printEntryBlockArgs=*/true,
1039                 /*printBlockTerminators=*/true);
1040   p << " join";
1041   p.printRegion(op.joinRegion(),
1042                 /*printEntryBlockArgs=*/true,
1043                 /*printBlockTerminators=*/true);
1044 }
1045 
1046 static ParseResult parseRegionIfOp(OpAsmParser &parser,
1047                                    OperationState &result) {
1048   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
1049   SmallVector<Type, 2> operandTypes;
1050 
1051   result.regions.reserve(3);
1052   Region *thenRegion = result.addRegion();
1053   Region *elseRegion = result.addRegion();
1054   Region *joinRegion = result.addRegion();
1055 
1056   // Parse operand, type and arrow type lists.
1057   if (parser.parseOperandList(operandInfos) ||
1058       parser.parseColonTypeList(operandTypes) ||
1059       parser.parseArrowTypeList(result.types))
1060     return failure();
1061 
1062   // Parse all attached regions.
1063   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1064       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1065       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1066     return failure();
1067 
1068   return parser.resolveOperands(operandInfos, operandTypes,
1069                                 parser.getCurrentLocation(), result.operands);
1070 }
1071 
1072 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
1073   assert(index < 2 && "invalid region index");
1074   return getOperands();
1075 }
1076 
1077 void RegionIfOp::getSuccessorRegions(
1078     Optional<unsigned> index, ArrayRef<Attribute> operands,
1079     SmallVectorImpl<RegionSuccessor> &regions) {
1080   // We always branch to the join region.
1081   if (index.hasValue()) {
1082     if (index.getValue() < 2)
1083       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
1084     else
1085       regions.push_back(RegionSuccessor(getResults()));
1086     return;
1087   }
1088 
1089   // The then and else regions are the entry regions of this op.
1090   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
1091   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
1092 }
1093 
1094 //===----------------------------------------------------------------------===//
1095 // SingleNoTerminatorCustomAsmOp
1096 //===----------------------------------------------------------------------===//
1097 
1098 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser,
1099                                                       OperationState &state) {
1100   Region *body = state.addRegion();
1101   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1102     return failure();
1103   return success();
1104 }
1105 
1106 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) {
1107   printer << op.getOperationName();
1108   printer.printRegion(
1109       op.getRegion(), /*printEntryBlockArgs=*/false,
1110       // This op has a single block without terminators. But explicitly mark
1111       // as not printing block terminators for testing.
1112       /*printBlockTerminators=*/false);
1113 }
1114 
1115 #include "TestOpEnums.cpp.inc"
1116 #include "TestOpInterfaces.cpp.inc"
1117 #include "TestOpStructs.cpp.inc"
1118 #include "TestTypeInterfaces.cpp.inc"
1119 
1120 #define GET_OP_CLASSES
1121 #include "TestOps.cpp.inc"
1122