1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/DLTI/DLTI.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Reducer/ReductionPatternInterface.h"
21 #include "mlir/Transforms/FoldUtils.h"
22 #include "mlir/Transforms/InliningUtils.h"
23 #include "llvm/ADT/StringSwitch.h"
24 
25 using namespace mlir;
26 using namespace mlir::test;
27 
28 void mlir::test::registerTestDialect(DialectRegistry &registry) {
29   registry.insert<TestDialect>();
30 }
31 
32 //===----------------------------------------------------------------------===//
33 // TestDialect Interfaces
34 //===----------------------------------------------------------------------===//
35 
36 namespace {
37 
38 /// Testing the correctness of some traits.
39 static_assert(
40     llvm::is_detected<OpTrait::has_implicit_terminator_t,
41                       SingleBlockImplicitTerminatorOp>::value,
42     "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
43 static_assert(OpTrait::hasSingleBlockImplicitTerminator<
44                   SingleBlockImplicitTerminatorOp>::value,
45               "hasSingleBlockImplicitTerminator does not match "
46               "SingleBlockImplicitTerminatorOp");
47 
48 // Test support for interacting with the AsmPrinter.
49 struct TestOpAsmInterface : public OpAsmDialectInterface {
50   using OpAsmDialectInterface::OpAsmDialectInterface;
51 
52   LogicalResult getAlias(Attribute attr, raw_ostream &os) const final {
53     StringAttr strAttr = attr.dyn_cast<StringAttr>();
54     if (!strAttr)
55       return failure();
56 
57     // Check the contents of the string attribute to see what the test alias
58     // should be named.
59     Optional<StringRef> aliasName =
60         StringSwitch<Optional<StringRef>>(strAttr.getValue())
61             .Case("alias_test:dot_in_name", StringRef("test.alias"))
62             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
63             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
64             .Case("alias_test:sanitize_conflict_a",
65                   StringRef("test_alias_conflict0"))
66             .Case("alias_test:sanitize_conflict_b",
67                   StringRef("test_alias_conflict0_"))
68             .Default(llvm::None);
69     if (!aliasName)
70       return failure();
71 
72     os << *aliasName;
73     return success();
74   }
75 
76   void getAsmResultNames(Operation *op,
77                          OpAsmSetValueNameFn setNameFn) const final {
78     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
79       setNameFn(asmOp, "result");
80   }
81 
82   void getAsmBlockArgumentNames(Block *block,
83                                 OpAsmSetValueNameFn setNameFn) const final {
84     auto op = block->getParentOp();
85     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
86     if (!arrayAttr)
87       return;
88     auto args = block->getArguments();
89     auto e = std::min(arrayAttr.size(), args.size());
90     for (unsigned i = 0; i < e; ++i) {
91       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
92         setNameFn(args[i], strAttr.getValue());
93     }
94   }
95 };
96 
97 struct TestDialectFoldInterface : public DialectFoldInterface {
98   using DialectFoldInterface::DialectFoldInterface;
99 
100   /// Registered hook to check if the given region, which is attached to an
101   /// operation that is *not* isolated from above, should be used when
102   /// materializing constants.
103   bool shouldMaterializeInto(Region *region) const final {
104     // If this is a one region operation, then insert into it.
105     return isa<OneRegionOp>(region->getParentOp());
106   }
107 };
108 
109 /// This class defines the interface for handling inlining with standard
110 /// operations.
111 struct TestInlinerInterface : public DialectInlinerInterface {
112   using DialectInlinerInterface::DialectInlinerInterface;
113 
114   //===--------------------------------------------------------------------===//
115   // Analysis Hooks
116   //===--------------------------------------------------------------------===//
117 
118   bool isLegalToInline(Operation *call, Operation *callable,
119                        bool wouldBeCloned) const final {
120     // Don't allow inlining calls that are marked `noinline`.
121     return !call->hasAttr("noinline");
122   }
123   bool isLegalToInline(Region *, Region *, bool,
124                        BlockAndValueMapping &) const final {
125     // Inlining into test dialect regions is legal.
126     return true;
127   }
128   bool isLegalToInline(Operation *, Region *, bool,
129                        BlockAndValueMapping &) const final {
130     return true;
131   }
132 
133   bool shouldAnalyzeRecursively(Operation *op) const final {
134     // Analyze recursively if this is not a functional region operation, it
135     // froms a separate functional scope.
136     return !isa<FunctionalRegionOp>(op);
137   }
138 
139   //===--------------------------------------------------------------------===//
140   // Transformation Hooks
141   //===--------------------------------------------------------------------===//
142 
143   /// Handle the given inlined terminator by replacing it with a new operation
144   /// as necessary.
145   void handleTerminator(Operation *op,
146                         ArrayRef<Value> valuesToRepl) const final {
147     // Only handle "test.return" here.
148     auto returnOp = dyn_cast<TestReturnOp>(op);
149     if (!returnOp)
150       return;
151 
152     // Replace the values directly with the return operands.
153     assert(returnOp.getNumOperands() == valuesToRepl.size());
154     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
155       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
156   }
157 
158   /// Attempt to materialize a conversion for a type mismatch between a call
159   /// from this dialect, and a callable region. This method should generate an
160   /// operation that takes 'input' as the only operand, and produces a single
161   /// result of 'resultType'. If a conversion can not be generated, nullptr
162   /// should be returned.
163   Operation *materializeCallConversion(OpBuilder &builder, Value input,
164                                        Type resultType,
165                                        Location conversionLoc) const final {
166     // Only allow conversion for i16/i32 types.
167     if (!(resultType.isSignlessInteger(16) ||
168           resultType.isSignlessInteger(32)) ||
169         !(input.getType().isSignlessInteger(16) ||
170           input.getType().isSignlessInteger(32)))
171       return nullptr;
172     return builder.create<TestCastOp>(conversionLoc, resultType, input);
173   }
174 
175   void processInlinedCallBlocks(
176       Operation *call,
177       iterator_range<Region::iterator> inlinedBlocks) const final {
178     if (!isa<ConversionCallOp>(call))
179       return;
180 
181     // Set attributed on all ops in the inlined blocks.
182     for (Block &block : inlinedBlocks) {
183       block.walk([&](Operation *op) {
184         op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
185       });
186     }
187   }
188 };
189 
190 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
191 public:
192   TestReductionPatternInterface(Dialect *dialect)
193       : DialectReductionPatternInterface(dialect) {}
194 
195   virtual void
196   populateReductionPatterns(RewritePatternSet &patterns) const final {
197     populateTestReductionPatterns(patterns);
198   }
199 };
200 
201 } // end anonymous namespace
202 
203 //===----------------------------------------------------------------------===//
204 // TestDialect
205 //===----------------------------------------------------------------------===//
206 
/// Presumably fills `effects` with the side effects of `op` (the definition is
/// not visible here). Declared up front so TestOpEffectInterfaceFallback can
/// forward to it; being `static`, it must be defined in this translation unit.
static void testSideEffectOpGetEffect(
    Operation *op,
    SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
210 
211 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
212 struct TestOpEffectInterfaceFallback
213     : public TestEffectOpInterface::FallbackModel<
214           TestOpEffectInterfaceFallback> {
215   static bool classof(Operation *op) {
216     bool isSupportedOp =
217         op->getName().getStringRef() == "test.unregistered_side_effect_op";
218     assert(isSupportedOp && "Unexpected dispatch");
219     return isSupportedOp;
220   }
221 
222   void
223   getEffects(Operation *op,
224              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
225                  &effects) const {
226     testSideEffectOpGetEffect(op, effects);
227   }
228 };
229 
/// Registers the test dialect's attributes, types, generated operations and
/// dialect interfaces, and allocates the fallback effect interface used for
/// "test.unregistered_side_effect_op".
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  // The operation list comes from the generated TestOps.cpp.inc; defining
  // GET_OP_LIST presumably makes it expand to the op class list — confirm
  // against the ODS-generated file.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface, TestReductionPatternInterface>();
  // Accept operations in this dialect's namespace that are not registered
  // (e.g. "test.unregistered_side_effect_op").
  allowUnknownOperations();

  // Instantiate the fallback op interface used for one specific unregistered
  // op. The raw allocation is owned by the dialect and released in
  // ~TestDialect().
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
245 TestDialect::~TestDialect() {
246   delete static_cast<TestOpEffectInterfaceFallback *>(
247       fallbackEffectOpInterfaces);
248 }
249 
250 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
251                                             Type type, Location loc) {
252   return builder.create<TestOpConstant>(loc, type, value);
253 }
254 
255 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
256                                                OperationName opName) {
257   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
258       typeID == TypeID::get<TestEffectOpInterface>())
259     return fallbackEffectOpInterfaces;
260   return nullptr;
261 }
262 
263 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
264                                                     NamedAttribute namedAttr) {
265   if (namedAttr.first == "test.invalid_attr")
266     return op->emitError() << "invalid to use 'test.invalid_attr'";
267   return success();
268 }
269 
270 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
271                                                     unsigned regionIndex,
272                                                     unsigned argIndex,
273                                                     NamedAttribute namedAttr) {
274   if (namedAttr.first == "test.invalid_attr")
275     return op->emitError() << "invalid to use 'test.invalid_attr'";
276   return success();
277 }
278 
279 LogicalResult
280 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
281                                          unsigned resultIndex,
282                                          NamedAttribute namedAttr) {
283   if (namedAttr.first == "test.invalid_attr")
284     return op->emitError() << "invalid to use 'test.invalid_attr'";
285   return success();
286 }
287 
288 Optional<Dialect::ParseOpHook>
289 TestDialect::getParseOperationHook(StringRef opName) const {
290   if (opName == "test.dialect_custom_printer") {
291     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
292       return parser.parseKeyword("custom_format");
293     }};
294   }
295   return None;
296 }
297 
298 LogicalResult TestDialect::printOperation(Operation *op,
299                                           OpAsmPrinter &printer) const {
300   StringRef opName = op->getName().getStringRef();
301   if (opName == "test.dialect_custom_printer") {
302     printer.getStream() << opName << " custom_format";
303     return success();
304   }
305   return failure();
306 }
307 
308 //===----------------------------------------------------------------------===//
309 // TestBranchOp
310 //===----------------------------------------------------------------------===//
311 
312 Optional<MutableOperandRange>
313 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
314   assert(index == 0 && "invalid successor index");
315   return targetOperandsMutable();
316 }
317 
318 //===----------------------------------------------------------------------===//
319 // TestDialectCanonicalizerOp
320 //===----------------------------------------------------------------------===//
321 
322 static LogicalResult
323 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
324                                PatternRewriter &rewriter) {
325   rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(),
326                                           rewriter.getI32IntegerAttr(42));
327   return success();
328 }
329 
330 void TestDialect::getCanonicalizationPatterns(
331     RewritePatternSet &results) const {
332   results.add(&dialectCanonicalizationPattern);
333 }
334 
335 //===----------------------------------------------------------------------===//
336 // TestFoldToCallOp
337 //===----------------------------------------------------------------------===//
338 
339 namespace {
340 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
341   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
342 
343   LogicalResult matchAndRewrite(FoldToCallOp op,
344                                 PatternRewriter &rewriter) const override {
345     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
346                                         ValueRange());
347     return success();
348   }
349 };
350 } // end anonymous namespace
351 
352 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
353                                                MLIRContext *context) {
354   results.add<FoldToCallOpPattern>(context);
355 }
356 
357 //===----------------------------------------------------------------------===//
358 // Test Format* operations
359 //===----------------------------------------------------------------------===//
360 
361 //===----------------------------------------------------------------------===//
362 // Parsing
363 
364 static ParseResult parseCustomDirectiveOperands(
365     OpAsmParser &parser, OpAsmParser::OperandType &operand,
366     Optional<OpAsmParser::OperandType> &optOperand,
367     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
368   if (parser.parseOperand(operand))
369     return failure();
370   if (succeeded(parser.parseOptionalComma())) {
371     optOperand.emplace();
372     if (parser.parseOperand(*optOperand))
373       return failure();
374   }
375   if (parser.parseArrow() || parser.parseLParen() ||
376       parser.parseOperandList(varOperands) || parser.parseRParen())
377     return failure();
378   return success();
379 }
380 static ParseResult
381 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
382                             Type &optOperandType,
383                             SmallVectorImpl<Type> &varOperandTypes) {
384   if (parser.parseColon())
385     return failure();
386 
387   if (parser.parseType(operandType))
388     return failure();
389   if (succeeded(parser.parseOptionalComma())) {
390     if (parser.parseType(optOperandType))
391       return failure();
392   }
393   if (parser.parseArrow() || parser.parseLParen() ||
394       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
395     return failure();
396   return success();
397 }
398 static ParseResult
399 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
400                                  Type optOperandType,
401                                  const SmallVectorImpl<Type> &varOperandTypes) {
402   if (parser.parseKeyword("type_refs_capture"))
403     return failure();
404 
405   Type operandType2, optOperandType2;
406   SmallVector<Type, 1> varOperandTypes2;
407   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
408                                   varOperandTypes2))
409     return failure();
410 
411   if (operandType != operandType2 || optOperandType != optOperandType2 ||
412       varOperandTypes != varOperandTypes2)
413     return failure();
414 
415   return success();
416 }
417 static ParseResult parseCustomDirectiveOperandsAndTypes(
418     OpAsmParser &parser, OpAsmParser::OperandType &operand,
419     Optional<OpAsmParser::OperandType> &optOperand,
420     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
421     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
422   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
423       parseCustomDirectiveResults(parser, operandType, optOperandType,
424                                   varOperandTypes))
425     return failure();
426   return success();
427 }
428 static ParseResult parseCustomDirectiveRegions(
429     OpAsmParser &parser, Region &region,
430     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
431   if (parser.parseRegion(region))
432     return failure();
433   if (failed(parser.parseOptionalComma()))
434     return success();
435   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
436   if (parser.parseRegion(*varRegion))
437     return failure();
438   varRegions.emplace_back(std::move(varRegion));
439   return success();
440 }
441 static ParseResult
442 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
443                                SmallVectorImpl<Block *> &varSuccessors) {
444   if (parser.parseSuccessor(successor))
445     return failure();
446   if (failed(parser.parseOptionalComma()))
447     return success();
448   Block *varSuccessor;
449   if (parser.parseSuccessor(varSuccessor))
450     return failure();
451   varSuccessors.append(2, varSuccessor);
452   return success();
453 }
454 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
455                                                   IntegerAttr &attr,
456                                                   IntegerAttr &optAttr) {
457   if (parser.parseAttribute(attr))
458     return failure();
459   if (succeeded(parser.parseOptionalComma())) {
460     if (parser.parseAttribute(optAttr))
461       return failure();
462   }
463   return success();
464 }
465 
466 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
467                                                 NamedAttrList &attrs) {
468   return parser.parseOptionalAttrDict(attrs);
469 }
470 static ParseResult parseCustomDirectiveOptionalOperandRef(
471     OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
472   int64_t operandCount = 0;
473   if (parser.parseInteger(operandCount))
474     return failure();
475   bool expectedOptionalOperand = operandCount == 0;
476   return success(expectedOptionalOperand != optOperand.hasValue());
477 }
478 
479 //===----------------------------------------------------------------------===//
480 // Printing
481 
482 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
483                                          Value operand, Value optOperand,
484                                          OperandRange varOperands) {
485   printer << operand;
486   if (optOperand)
487     printer << ", " << optOperand;
488   printer << " -> (" << varOperands << ")";
489 }
490 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
491                                         Type operandType, Type optOperandType,
492                                         TypeRange varOperandTypes) {
493   printer << " : " << operandType;
494   if (optOperandType)
495     printer << ", " << optOperandType;
496   printer << " -> (" << varOperandTypes << ")";
497 }
498 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
499                                              Operation *op, Type operandType,
500                                              Type optOperandType,
501                                              TypeRange varOperandTypes) {
502   printer << " type_refs_capture ";
503   printCustomDirectiveResults(printer, op, operandType, optOperandType,
504                               varOperandTypes);
505 }
506 static void printCustomDirectiveOperandsAndTypes(
507     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
508     OperandRange varOperands, Type operandType, Type optOperandType,
509     TypeRange varOperandTypes) {
510   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
511   printCustomDirectiveResults(printer, op, operandType, optOperandType,
512                               varOperandTypes);
513 }
514 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
515                                         Region &region,
516                                         MutableArrayRef<Region> varRegions) {
517   printer.printRegion(region);
518   if (!varRegions.empty()) {
519     printer << ", ";
520     for (Region &region : varRegions)
521       printer.printRegion(region);
522   }
523 }
524 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
525                                            Block *successor,
526                                            SuccessorRange varSuccessors) {
527   printer << successor;
528   if (!varSuccessors.empty())
529     printer << ", " << varSuccessors.front();
530 }
531 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
532                                            Attribute attribute,
533                                            Attribute optAttribute) {
534   printer << attribute;
535   if (optAttribute)
536     printer << ", " << optAttribute;
537 }
538 
539 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
540                                          DictionaryAttr attrs) {
541   printer.printOptionalAttrDict(attrs.getValue());
542 }
543 
544 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
545                                                    Operation *op,
546                                                    Value optOperand) {
547   printer << (optOperand ? "1" : "0");
548 }
549 
550 //===----------------------------------------------------------------------===//
551 // Test IsolatedRegionOp - parse passthrough region arguments.
552 //===----------------------------------------------------------------------===//
553 
554 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
555                                          OperationState &result) {
556   OpAsmParser::OperandType argInfo;
557   Type argType = parser.getBuilder().getIndexType();
558 
559   // Parse the input operand.
560   if (parser.parseOperand(argInfo) ||
561       parser.resolveOperand(argInfo, argType, result.operands))
562     return failure();
563 
564   // Parse the body region, and reuse the operand info as the argument info.
565   Region *body = result.addRegion();
566   return parser.parseRegion(*body, argInfo, argType,
567                             /*enableNameShadowing=*/true);
568 }
569 
570 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
571   p << "test.isolated_region ";
572   p.printOperand(op.getOperand());
573   p.shadowRegionArgs(op.region(), op.getOperand());
574   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
575 }
576 
577 //===----------------------------------------------------------------------===//
578 // Test SSACFGRegionOp
579 //===----------------------------------------------------------------------===//
580 
581 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
582   return RegionKind::SSACFG;
583 }
584 
585 //===----------------------------------------------------------------------===//
586 // Test GraphRegionOp
587 //===----------------------------------------------------------------------===//
588 
589 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
590                                       OperationState &result) {
591   // Parse the body region, and reuse the operand info as the argument info.
592   Region *body = result.addRegion();
593   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
594 }
595 
596 static void print(OpAsmPrinter &p, GraphRegionOp op) {
597   p << "test.graph_region ";
598   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
599 }
600 
601 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
602   return RegionKind::Graph;
603 }
604 
605 //===----------------------------------------------------------------------===//
606 // Test AffineScopeOp
607 //===----------------------------------------------------------------------===//
608 
609 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
610                                       OperationState &result) {
611   // Parse the body region, and reuse the operand info as the argument info.
612   Region *body = result.addRegion();
613   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
614 }
615 
616 static void print(OpAsmPrinter &p, AffineScopeOp op) {
617   p << "test.affine_scope ";
618   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
619 }
620 
621 //===----------------------------------------------------------------------===//
622 // Test parser.
623 //===----------------------------------------------------------------------===//
624 
625 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
626                                               OperationState &result) {
627   if (parser.parseOptionalColon())
628     return success();
629   uint64_t numResults;
630   if (parser.parseInteger(numResults))
631     return failure();
632 
633   IndexType type = parser.getBuilder().getIndexType();
634   for (unsigned i = 0; i < numResults; ++i)
635     result.addTypes(type);
636   return success();
637 }
638 
639 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
640   p << ParseIntegerLiteralOp::getOperationName();
641   if (unsigned numResults = op->getNumResults())
642     p << " : " << numResults;
643 }
644 
645 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
646                                               OperationState &result) {
647   StringRef keyword;
648   if (parser.parseKeyword(&keyword))
649     return failure();
650   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
651   return success();
652 }
653 
654 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
655   p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword();
656 }
657 
658 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
660 
661 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
662                                          OperationState &result) {
663   if (parser.parseKeyword("wraps"))
664     return failure();
665 
666   // Parse the wrapped op in a region
667   Region &body = *result.addRegion();
668   body.push_back(new Block);
669   Block &block = body.back();
670   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
671   if (!wrapped_op)
672     return failure();
673 
674   // Create a return terminator in the inner region, pass as operand to the
675   // terminator the returned values from the wrapped operation.
676   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
677   OpBuilder builder(parser.getBuilder().getContext());
678   builder.setInsertionPointToEnd(&block);
679   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
680 
681   // Get the results type for the wrapping op from the terminator operands.
682   Operation &return_op = body.back().back();
683   result.types.append(return_op.operand_type_begin(),
684                       return_op.operand_type_end());
685 
686   // Use the location of the wrapped op for the "test.wrapping_region" op.
687   result.location = wrapped_op->getLoc();
688 
689   return success();
690 }
691 
692 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
693   p << op.getOperationName() << " wraps ";
694   p.printGenericOp(&op.region().front().front());
695 }
696 
697 //===----------------------------------------------------------------------===//
698 // Test PolyForOp - parse list of region arguments.
699 //===----------------------------------------------------------------------===//
700 
701 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
702   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
703   // Parse list of region arguments without a delimiter.
704   if (parser.parseRegionArgumentList(ivsInfo))
705     return failure();
706 
707   // Parse the body region.
708   Region *body = result.addRegion();
709   auto &builder = parser.getBuilder();
710   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
711   return parser.parseRegion(*body, ivsInfo, argTypes);
712 }
713 
714 //===----------------------------------------------------------------------===//
715 // Test removing op with inner ops.
716 //===----------------------------------------------------------------------===//
717 
718 namespace {
719 struct TestRemoveOpWithInnerOps
720     : public OpRewritePattern<TestOpWithRegionPattern> {
721   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
722 
723   void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
724 
725   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
726                                 PatternRewriter &rewriter) const override {
727     rewriter.eraseOp(op);
728     return success();
729   }
730 };
731 } // end anonymous namespace
732 
733 void TestOpWithRegionPattern::getCanonicalizationPatterns(
734     RewritePatternSet &results, MLIRContext *context) {
735   results.add<TestRemoveOpWithInnerOps>(context);
736 }
737 
738 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
739   return operand();
740 }
741 
742 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
743   return getValue();
744 }
745 
746 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
747     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
748   for (Value input : this->operands()) {
749     results.push_back(input);
750   }
751   return success();
752 }
753 
754 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
755   assert(operands.size() == 1);
756   if (operands.front()) {
757     (*this)->setAttr("attr", operands.front());
758     return getResult();
759   }
760   return {};
761 }
762 
763 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
764   return getOperand();
765 }
766 
767 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
768     MLIRContext *, Optional<Location> location, ValueRange operands,
769     DictionaryAttr attributes, RegionRange regions,
770     SmallVectorImpl<Type> &inferredReturnTypes) {
771   if (operands[0].getType() != operands[1].getType()) {
772     return emitOptionalError(location, "operand type mismatch ",
773                              operands[0].getType(), " vs ",
774                              operands[1].getType());
775   }
776   inferredReturnTypes.assign({operands[0].getType()});
777   return success();
778 }
779 
780 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
781     MLIRContext *context, Optional<Location> location, ValueRange operands,
782     DictionaryAttr attributes, RegionRange regions,
783     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
784   // Create return type consisting of the last element of the first operand.
785   auto operandType = *operands.getTypes().begin();
786   auto sval = operandType.dyn_cast<ShapedType>();
787   if (!sval) {
788     return emitOptionalError(location, "only shaped type operands allowed");
789   }
790   int64_t dim =
791       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
792   auto type = IntegerType::get(context, 17);
793   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
794   return success();
795 }
796 
797 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
798     OpBuilder &builder, ValueRange operands,
799     llvm::SmallVectorImpl<Value> &shapes) {
800   shapes = SmallVector<Value, 1>{
801       builder.createOrFold<memref::DimOp>(getLoc(), operands.front(), 0)};
802   return success();
803 }
804 
805 LogicalResult
806 OpWithResultShapePerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim(
807     OpBuilder &builder,
808     llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
809   SmallVector<Value> operand1Shape, operand2Shape;
810   Location loc = getLoc();
811   for (auto i :
812        llvm::seq<int>(0, operand1().getType().cast<ShapedType>().getRank())) {
813     operand1Shape.push_back(builder.create<memref::DimOp>(loc, operand1(), i));
814   }
815   for (auto i :
816        llvm::seq<int>(0, operand2().getType().cast<ShapedType>().getRank())) {
817     operand2Shape.push_back(builder.create<memref::DimOp>(loc, operand2(), i));
818   }
819   shapes.emplace_back(std::move(operand2Shape));
820   shapes.emplace_back(std::move(operand1Shape));
821   return success();
822 }
823 
824 //===----------------------------------------------------------------------===//
825 // Test SideEffect interfaces
826 //===----------------------------------------------------------------------===//
827 
828 namespace {
829 /// A test resource for side effects.
830 struct TestResource : public SideEffects::Resource::Base<TestResource> {
831   StringRef getName() final { return "<Test>"; }
832 };
833 } // end anonymous namespace
834 
835 static void testSideEffectOpGetEffect(
836     Operation *op,
837     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
838         &effects) {
839   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
840   if (!effectsAttr)
841     return;
842 
843   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
844 }
845 
846 void SideEffectOp::getEffects(
847     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
848   // Check for an effects attribute on the op instance.
849   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
850   if (!effectsAttr)
851     return;
852 
853   // If there is one, it is an array of dictionary attributes that hold
854   // information on the effects of this operation.
855   for (Attribute element : effectsAttr) {
856     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
857 
858     // Get the specific memory effect.
859     MemoryEffects::Effect *effect =
860         StringSwitch<MemoryEffects::Effect *>(
861             effectElement.get("effect").cast<StringAttr>().getValue())
862             .Case("allocate", MemoryEffects::Allocate::get())
863             .Case("free", MemoryEffects::Free::get())
864             .Case("read", MemoryEffects::Read::get())
865             .Case("write", MemoryEffects::Write::get());
866 
867     // Check for a non-default resource to use.
868     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
869     if (effectElement.get("test_resource"))
870       resource = TestResource::get();
871 
872     // Check for a result to affect.
873     if (effectElement.get("on_result"))
874       effects.emplace_back(effect, getResult(), resource);
875     else if (Attribute ref = effectElement.get("on_reference"))
876       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
877     else
878       effects.emplace_back(effect, resource);
879   }
880 }
881 
882 void SideEffectOp::getEffects(
883     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
884   testSideEffectOpGetEffect(getOperation(), effects);
885 }
886 
887 //===----------------------------------------------------------------------===//
888 // StringAttrPrettyNameOp
889 //===----------------------------------------------------------------------===//
890 
891 // This op has fancy handling of its SSA result name.
892 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
893                                                OperationState &result) {
894   // Add the result types.
895   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
896     result.addTypes(parser.getBuilder().getIntegerType(32));
897 
898   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
899     return failure();
900 
901   // If the attribute dictionary contains no 'names' attribute, infer it from
902   // the SSA name (if specified).
903   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
904     return attr.first == "names";
905   });
906 
907   // If there was no name specified, check to see if there was a useful name
908   // specified in the asm file.
909   if (hadNames || parser.getNumResults() == 0)
910     return success();
911 
912   SmallVector<StringRef, 4> names;
913   auto *context = result.getContext();
914 
915   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
916     auto resultName = parser.getResultName(i);
917     StringRef nameStr;
918     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
919       nameStr = resultName.first;
920 
921     names.push_back(nameStr);
922   }
923 
924   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
925   result.attributes.push_back({Identifier::get("names", context), namesAttr});
926   return success();
927 }
928 
929 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
930   p << "test.string_attr_pretty_name";
931 
932   // Note that we only need to print the "name" attribute if the asmprinter
933   // result name disagrees with it.  This can happen in strange cases, e.g.
934   // when there are conflicts.
935   bool namesDisagree = op.names().size() != op.getNumResults();
936 
937   SmallString<32> resultNameStr;
938   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
939     resultNameStr.clear();
940     llvm::raw_svector_ostream tmpStream(resultNameStr);
941     p.printOperand(op.getResult(i), tmpStream);
942 
943     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
944     if (!expectedName ||
945         tmpStream.str().drop_front() != expectedName.getValue()) {
946       namesDisagree = true;
947     }
948   }
949 
950   if (namesDisagree)
951     p.printOptionalAttrDictWithKeyword(op->getAttrs());
952   else
953     p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
954 }
955 
956 // We set the SSA name in the asm syntax to the contents of the name
957 // attribute.
958 void StringAttrPrettyNameOp::getAsmResultNames(
959     function_ref<void(Value, StringRef)> setNameFn) {
960 
961   auto value = names();
962   for (size_t i = 0, e = value.size(); i != e; ++i)
963     if (auto str = value[i].dyn_cast<StringAttr>())
964       if (!str.getValue().empty())
965         setNameFn(getResult(i), str.getValue());
966 }
967 
968 //===----------------------------------------------------------------------===//
969 // RegionIfOp
970 //===----------------------------------------------------------------------===//
971 
972 static void print(OpAsmPrinter &p, RegionIfOp op) {
973   p << RegionIfOp::getOperationName() << " ";
974   p.printOperands(op.getOperands());
975   p << ": " << op.getOperandTypes();
976   p.printArrowTypeList(op.getResultTypes());
977   p << " then";
978   p.printRegion(op.thenRegion(),
979                 /*printEntryBlockArgs=*/true,
980                 /*printBlockTerminators=*/true);
981   p << " else";
982   p.printRegion(op.elseRegion(),
983                 /*printEntryBlockArgs=*/true,
984                 /*printBlockTerminators=*/true);
985   p << " join";
986   p.printRegion(op.joinRegion(),
987                 /*printEntryBlockArgs=*/true,
988                 /*printBlockTerminators=*/true);
989 }
990 
991 static ParseResult parseRegionIfOp(OpAsmParser &parser,
992                                    OperationState &result) {
993   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
994   SmallVector<Type, 2> operandTypes;
995 
996   result.regions.reserve(3);
997   Region *thenRegion = result.addRegion();
998   Region *elseRegion = result.addRegion();
999   Region *joinRegion = result.addRegion();
1000 
1001   // Parse operand, type and arrow type lists.
1002   if (parser.parseOperandList(operandInfos) ||
1003       parser.parseColonTypeList(operandTypes) ||
1004       parser.parseArrowTypeList(result.types))
1005     return failure();
1006 
1007   // Parse all attached regions.
1008   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1009       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1010       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1011     return failure();
1012 
1013   return parser.resolveOperands(operandInfos, operandTypes,
1014                                 parser.getCurrentLocation(), result.operands);
1015 }
1016 
1017 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
1018   assert(index < 2 && "invalid region index");
1019   return getOperands();
1020 }
1021 
1022 void RegionIfOp::getSuccessorRegions(
1023     Optional<unsigned> index, ArrayRef<Attribute> operands,
1024     SmallVectorImpl<RegionSuccessor> &regions) {
1025   // We always branch to the join region.
1026   if (index.hasValue()) {
1027     if (index.getValue() < 2)
1028       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
1029     else
1030       regions.push_back(RegionSuccessor(getResults()));
1031     return;
1032   }
1033 
1034   // The then and else regions are the entry regions of this op.
1035   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
1036   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
1037 }
1038 
1039 //===----------------------------------------------------------------------===//
1040 // SingleNoTerminatorCustomAsmOp
1041 //===----------------------------------------------------------------------===//
1042 
1043 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser,
1044                                                       OperationState &state) {
1045   Region *body = state.addRegion();
1046   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1047     return failure();
1048   return success();
1049 }
1050 
1051 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) {
1052   printer << op.getOperationName();
1053   printer.printRegion(
1054       op.getRegion(), /*printEntryBlockArgs=*/false,
1055       // This op has a single block without terminators. But explicitly mark
1056       // as not printing block terminators for testing.
1057       /*printBlockTerminators=*/false);
1058 }
1059 
1060 #include "TestOpEnums.cpp.inc"
1061 #include "TestOpInterfaces.cpp.inc"
1062 #include "TestOpStructs.cpp.inc"
1063 #include "TestTypeInterfaces.cpp.inc"
1064 
1065 #define GET_OP_CLASSES
1066 #include "TestOps.cpp.inc"
1067