1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/DLTI/DLTI.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Reducer/ReductionPatternInterface.h"
21 #include "mlir/Transforms/FoldUtils.h"
22 #include "mlir/Transforms/InliningUtils.h"
23 #include "llvm/ADT/StringSwitch.h"
24 
25 using namespace mlir;
26 using namespace mlir::test;
27 
28 void mlir::test::registerTestDialect(DialectRegistry &registry) {
29   registry.insert<TestDialect>();
30 }
31 
32 //===----------------------------------------------------------------------===//
33 // TestDialect Interfaces
34 //===----------------------------------------------------------------------===//
35 
36 namespace {
37 
38 /// Testing the correctness of some traits.
39 static_assert(
40     llvm::is_detected<OpTrait::has_implicit_terminator_t,
41                       SingleBlockImplicitTerminatorOp>::value,
42     "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
43 static_assert(OpTrait::hasSingleBlockImplicitTerminator<
44                   SingleBlockImplicitTerminatorOp>::value,
45               "hasSingleBlockImplicitTerminator does not match "
46               "SingleBlockImplicitTerminatorOp");
47 
48 // Test support for interacting with the AsmPrinter.
49 struct TestOpAsmInterface : public OpAsmDialectInterface {
50   using OpAsmDialectInterface::OpAsmDialectInterface;
51 
52   LogicalResult getAlias(Attribute attr, raw_ostream &os) const final {
53     StringAttr strAttr = attr.dyn_cast<StringAttr>();
54     if (!strAttr)
55       return failure();
56 
57     // Check the contents of the string attribute to see what the test alias
58     // should be named.
59     Optional<StringRef> aliasName =
60         StringSwitch<Optional<StringRef>>(strAttr.getValue())
61             .Case("alias_test:dot_in_name", StringRef("test.alias"))
62             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
63             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
64             .Case("alias_test:sanitize_conflict_a",
65                   StringRef("test_alias_conflict0"))
66             .Case("alias_test:sanitize_conflict_b",
67                   StringRef("test_alias_conflict0_"))
68             .Default(llvm::None);
69     if (!aliasName)
70       return failure();
71 
72     os << *aliasName;
73     return success();
74   }
75 
76   void getAsmResultNames(Operation *op,
77                          OpAsmSetValueNameFn setNameFn) const final {
78     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
79       setNameFn(asmOp, "result");
80   }
81 
82   void getAsmBlockArgumentNames(Block *block,
83                                 OpAsmSetValueNameFn setNameFn) const final {
84     auto op = block->getParentOp();
85     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
86     if (!arrayAttr)
87       return;
88     auto args = block->getArguments();
89     auto e = std::min(arrayAttr.size(), args.size());
90     for (unsigned i = 0; i < e; ++i) {
91       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
92         setNameFn(args[i], strAttr.getValue());
93     }
94   }
95 };
96 
97 struct TestDialectFoldInterface : public DialectFoldInterface {
98   using DialectFoldInterface::DialectFoldInterface;
99 
100   /// Registered hook to check if the given region, which is attached to an
101   /// operation that is *not* isolated from above, should be used when
102   /// materializing constants.
103   bool shouldMaterializeInto(Region *region) const final {
104     // If this is a one region operation, then insert into it.
105     return isa<OneRegionOp>(region->getParentOp());
106   }
107 };
108 
109 /// This class defines the interface for handling inlining with standard
110 /// operations.
111 struct TestInlinerInterface : public DialectInlinerInterface {
112   using DialectInlinerInterface::DialectInlinerInterface;
113 
114   //===--------------------------------------------------------------------===//
115   // Analysis Hooks
116   //===--------------------------------------------------------------------===//
117 
118   bool isLegalToInline(Operation *call, Operation *callable,
119                        bool wouldBeCloned) const final {
120     // Don't allow inlining calls that are marked `noinline`.
121     return !call->hasAttr("noinline");
122   }
123   bool isLegalToInline(Region *, Region *, bool,
124                        BlockAndValueMapping &) const final {
125     // Inlining into test dialect regions is legal.
126     return true;
127   }
128   bool isLegalToInline(Operation *, Region *, bool,
129                        BlockAndValueMapping &) const final {
130     return true;
131   }
132 
133   bool shouldAnalyzeRecursively(Operation *op) const final {
134     // Analyze recursively if this is not a functional region operation, it
135     // froms a separate functional scope.
136     return !isa<FunctionalRegionOp>(op);
137   }
138 
139   //===--------------------------------------------------------------------===//
140   // Transformation Hooks
141   //===--------------------------------------------------------------------===//
142 
143   /// Handle the given inlined terminator by replacing it with a new operation
144   /// as necessary.
145   void handleTerminator(Operation *op,
146                         ArrayRef<Value> valuesToRepl) const final {
147     // Only handle "test.return" here.
148     auto returnOp = dyn_cast<TestReturnOp>(op);
149     if (!returnOp)
150       return;
151 
152     // Replace the values directly with the return operands.
153     assert(returnOp.getNumOperands() == valuesToRepl.size());
154     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
155       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
156   }
157 
158   /// Attempt to materialize a conversion for a type mismatch between a call
159   /// from this dialect, and a callable region. This method should generate an
160   /// operation that takes 'input' as the only operand, and produces a single
161   /// result of 'resultType'. If a conversion can not be generated, nullptr
162   /// should be returned.
163   Operation *materializeCallConversion(OpBuilder &builder, Value input,
164                                        Type resultType,
165                                        Location conversionLoc) const final {
166     // Only allow conversion for i16/i32 types.
167     if (!(resultType.isSignlessInteger(16) ||
168           resultType.isSignlessInteger(32)) ||
169         !(input.getType().isSignlessInteger(16) ||
170           input.getType().isSignlessInteger(32)))
171       return nullptr;
172     return builder.create<TestCastOp>(conversionLoc, resultType, input);
173   }
174 };
175 
176 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
177 public:
178   TestReductionPatternInterface(Dialect *dialect)
179       : DialectReductionPatternInterface(dialect) {}
180 
181   virtual void
182   populateReductionPatterns(RewritePatternSet &patterns) const final {
183     populateTestReductionPatterns(patterns);
184   }
185 };
186 
187 } // end anonymous namespace
188 
189 //===----------------------------------------------------------------------===//
190 // TestDialect
191 //===----------------------------------------------------------------------===//
192 
193 static void testSideEffectOpGetEffect(
194     Operation *op,
195     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
196 
197 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
198 struct TestOpEffectInterfaceFallback
199     : public TestEffectOpInterface::FallbackModel<
200           TestOpEffectInterfaceFallback> {
201   static bool classof(Operation *op) {
202     bool isSupportedOp =
203         op->getName().getStringRef() == "test.unregistered_side_effect_op";
204     assert(isSupportedOp && "Unexpected dispatch");
205     return isSupportedOp;
206   }
207 
208   void
209   getEffects(Operation *op,
210              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
211                  &effects) const {
212     testSideEffectOpGetEffect(op, effects);
213   }
214 };
215 
/// Registers the dialect's attributes, types, operations (from the generated
/// op list) and dialect interfaces, allows unknown operations, and allocates
/// the fallback `TestEffectOpInterface` model.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface, TestReductionPatternInterface>();
  allowUnknownOperations();

  // Instantiate the fallback op interface model used for one specific
  // unregistered op (see getRegisteredInterfaceForOp). This is a raw owning
  // allocation; it is released in ~TestDialect.
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
/// Frees the fallback model allocated in initialize(). The static_cast restores
/// the concrete type before delete so the right destructor runs.
/// NOTE(review): the member's declared type lives in TestDialect.h (not shown);
/// presumably `void *` — confirm before changing this cast.
TestDialect::~TestDialect() {
  delete static_cast<TestOpEffectInterfaceFallback *>(
      fallbackEffectOpInterfaces);
}
235 
236 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
237                                             Type type, Location loc) {
238   return builder.create<TestOpConstant>(loc, type, value);
239 }
240 
241 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
242                                                OperationName opName) {
243   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
244       typeID == TypeID::get<TestEffectOpInterface>())
245     return fallbackEffectOpInterfaces;
246   return nullptr;
247 }
248 
249 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
250                                                     NamedAttribute namedAttr) {
251   if (namedAttr.first == "test.invalid_attr")
252     return op->emitError() << "invalid to use 'test.invalid_attr'";
253   return success();
254 }
255 
256 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
257                                                     unsigned regionIndex,
258                                                     unsigned argIndex,
259                                                     NamedAttribute namedAttr) {
260   if (namedAttr.first == "test.invalid_attr")
261     return op->emitError() << "invalid to use 'test.invalid_attr'";
262   return success();
263 }
264 
265 LogicalResult
266 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
267                                          unsigned resultIndex,
268                                          NamedAttribute namedAttr) {
269   if (namedAttr.first == "test.invalid_attr")
270     return op->emitError() << "invalid to use 'test.invalid_attr'";
271   return success();
272 }
273 
274 Optional<Dialect::ParseOpHook>
275 TestDialect::getParseOperationHook(StringRef opName) const {
276   if (opName == "test.dialect_custom_printer") {
277     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
278       return parser.parseKeyword("custom_format");
279     }};
280   }
281   return None;
282 }
283 
284 LogicalResult TestDialect::printOperation(Operation *op,
285                                           OpAsmPrinter &printer) const {
286   StringRef opName = op->getName().getStringRef();
287   if (opName == "test.dialect_custom_printer") {
288     printer.getStream() << opName << " custom_format";
289     return success();
290   }
291   return failure();
292 }
293 
294 //===----------------------------------------------------------------------===//
295 // TestBranchOp
296 //===----------------------------------------------------------------------===//
297 
298 Optional<MutableOperandRange>
299 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
300   assert(index == 0 && "invalid successor index");
301   return targetOperandsMutable();
302 }
303 
304 //===----------------------------------------------------------------------===//
305 // TestDialectCanonicalizerOp
306 //===----------------------------------------------------------------------===//
307 
308 static LogicalResult
309 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
310                                PatternRewriter &rewriter) {
311   rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(),
312                                           rewriter.getI32IntegerAttr(42));
313   return success();
314 }
315 
316 void TestDialect::getCanonicalizationPatterns(
317     RewritePatternSet &results) const {
318   results.add(&dialectCanonicalizationPattern);
319 }
320 
321 //===----------------------------------------------------------------------===//
322 // TestFoldToCallOp
323 //===----------------------------------------------------------------------===//
324 
325 namespace {
326 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
327   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
328 
329   LogicalResult matchAndRewrite(FoldToCallOp op,
330                                 PatternRewriter &rewriter) const override {
331     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
332                                         ValueRange());
333     return success();
334   }
335 };
336 } // end anonymous namespace
337 
338 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
339                                                MLIRContext *context) {
340   results.add<FoldToCallOpPattern>(context);
341 }
342 
343 //===----------------------------------------------------------------------===//
344 // Test Format* operations
345 //===----------------------------------------------------------------------===//
346 
347 //===----------------------------------------------------------------------===//
348 // Parsing
349 
350 static ParseResult parseCustomDirectiveOperands(
351     OpAsmParser &parser, OpAsmParser::OperandType &operand,
352     Optional<OpAsmParser::OperandType> &optOperand,
353     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
354   if (parser.parseOperand(operand))
355     return failure();
356   if (succeeded(parser.parseOptionalComma())) {
357     optOperand.emplace();
358     if (parser.parseOperand(*optOperand))
359       return failure();
360   }
361   if (parser.parseArrow() || parser.parseLParen() ||
362       parser.parseOperandList(varOperands) || parser.parseRParen())
363     return failure();
364   return success();
365 }
366 static ParseResult
367 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
368                             Type &optOperandType,
369                             SmallVectorImpl<Type> &varOperandTypes) {
370   if (parser.parseColon())
371     return failure();
372 
373   if (parser.parseType(operandType))
374     return failure();
375   if (succeeded(parser.parseOptionalComma())) {
376     if (parser.parseType(optOperandType))
377       return failure();
378   }
379   if (parser.parseArrow() || parser.parseLParen() ||
380       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
381     return failure();
382   return success();
383 }
384 static ParseResult
385 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
386                                  Type optOperandType,
387                                  const SmallVectorImpl<Type> &varOperandTypes) {
388   if (parser.parseKeyword("type_refs_capture"))
389     return failure();
390 
391   Type operandType2, optOperandType2;
392   SmallVector<Type, 1> varOperandTypes2;
393   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
394                                   varOperandTypes2))
395     return failure();
396 
397   if (operandType != operandType2 || optOperandType != optOperandType2 ||
398       varOperandTypes != varOperandTypes2)
399     return failure();
400 
401   return success();
402 }
403 static ParseResult parseCustomDirectiveOperandsAndTypes(
404     OpAsmParser &parser, OpAsmParser::OperandType &operand,
405     Optional<OpAsmParser::OperandType> &optOperand,
406     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
407     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
408   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
409       parseCustomDirectiveResults(parser, operandType, optOperandType,
410                                   varOperandTypes))
411     return failure();
412   return success();
413 }
414 static ParseResult parseCustomDirectiveRegions(
415     OpAsmParser &parser, Region &region,
416     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
417   if (parser.parseRegion(region))
418     return failure();
419   if (failed(parser.parseOptionalComma()))
420     return success();
421   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
422   if (parser.parseRegion(*varRegion))
423     return failure();
424   varRegions.emplace_back(std::move(varRegion));
425   return success();
426 }
427 static ParseResult
428 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
429                                SmallVectorImpl<Block *> &varSuccessors) {
430   if (parser.parseSuccessor(successor))
431     return failure();
432   if (failed(parser.parseOptionalComma()))
433     return success();
434   Block *varSuccessor;
435   if (parser.parseSuccessor(varSuccessor))
436     return failure();
437   varSuccessors.append(2, varSuccessor);
438   return success();
439 }
440 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
441                                                   IntegerAttr &attr,
442                                                   IntegerAttr &optAttr) {
443   if (parser.parseAttribute(attr))
444     return failure();
445   if (succeeded(parser.parseOptionalComma())) {
446     if (parser.parseAttribute(optAttr))
447       return failure();
448   }
449   return success();
450 }
451 
452 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
453                                                 NamedAttrList &attrs) {
454   return parser.parseOptionalAttrDict(attrs);
455 }
456 static ParseResult parseCustomDirectiveOptionalOperandRef(
457     OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
458   int64_t operandCount = 0;
459   if (parser.parseInteger(operandCount))
460     return failure();
461   bool expectedOptionalOperand = operandCount == 0;
462   return success(expectedOptionalOperand != optOperand.hasValue());
463 }
464 
465 //===----------------------------------------------------------------------===//
466 // Printing
467 
468 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
469                                          Value operand, Value optOperand,
470                                          OperandRange varOperands) {
471   printer << operand;
472   if (optOperand)
473     printer << ", " << optOperand;
474   printer << " -> (" << varOperands << ")";
475 }
476 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
477                                         Type operandType, Type optOperandType,
478                                         TypeRange varOperandTypes) {
479   printer << " : " << operandType;
480   if (optOperandType)
481     printer << ", " << optOperandType;
482   printer << " -> (" << varOperandTypes << ")";
483 }
484 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
485                                              Operation *op, Type operandType,
486                                              Type optOperandType,
487                                              TypeRange varOperandTypes) {
488   printer << " type_refs_capture ";
489   printCustomDirectiveResults(printer, op, operandType, optOperandType,
490                               varOperandTypes);
491 }
492 static void printCustomDirectiveOperandsAndTypes(
493     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
494     OperandRange varOperands, Type operandType, Type optOperandType,
495     TypeRange varOperandTypes) {
496   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
497   printCustomDirectiveResults(printer, op, operandType, optOperandType,
498                               varOperandTypes);
499 }
500 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
501                                         Region &region,
502                                         MutableArrayRef<Region> varRegions) {
503   printer.printRegion(region);
504   if (!varRegions.empty()) {
505     printer << ", ";
506     for (Region &region : varRegions)
507       printer.printRegion(region);
508   }
509 }
510 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
511                                            Block *successor,
512                                            SuccessorRange varSuccessors) {
513   printer << successor;
514   if (!varSuccessors.empty())
515     printer << ", " << varSuccessors.front();
516 }
517 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
518                                            Attribute attribute,
519                                            Attribute optAttribute) {
520   printer << attribute;
521   if (optAttribute)
522     printer << ", " << optAttribute;
523 }
524 
525 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
526                                          DictionaryAttr attrs) {
527   printer.printOptionalAttrDict(attrs.getValue());
528 }
529 
530 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
531                                                    Operation *op,
532                                                    Value optOperand) {
533   printer << (optOperand ? "1" : "0");
534 }
535 
536 //===----------------------------------------------------------------------===//
537 // Test IsolatedRegionOp - parse passthrough region arguments.
538 //===----------------------------------------------------------------------===//
539 
540 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
541                                          OperationState &result) {
542   OpAsmParser::OperandType argInfo;
543   Type argType = parser.getBuilder().getIndexType();
544 
545   // Parse the input operand.
546   if (parser.parseOperand(argInfo) ||
547       parser.resolveOperand(argInfo, argType, result.operands))
548     return failure();
549 
550   // Parse the body region, and reuse the operand info as the argument info.
551   Region *body = result.addRegion();
552   return parser.parseRegion(*body, argInfo, argType,
553                             /*enableNameShadowing=*/true);
554 }
555 
556 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
557   p << "test.isolated_region ";
558   p.printOperand(op.getOperand());
559   p.shadowRegionArgs(op.region(), op.getOperand());
560   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
561 }
562 
563 //===----------------------------------------------------------------------===//
564 // Test SSACFGRegionOp
565 //===----------------------------------------------------------------------===//
566 
567 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
568   return RegionKind::SSACFG;
569 }
570 
571 //===----------------------------------------------------------------------===//
572 // Test GraphRegionOp
573 //===----------------------------------------------------------------------===//
574 
575 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
576                                       OperationState &result) {
577   // Parse the body region, and reuse the operand info as the argument info.
578   Region *body = result.addRegion();
579   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
580 }
581 
582 static void print(OpAsmPrinter &p, GraphRegionOp op) {
583   p << "test.graph_region ";
584   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
585 }
586 
587 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
588   return RegionKind::Graph;
589 }
590 
591 //===----------------------------------------------------------------------===//
592 // Test AffineScopeOp
593 //===----------------------------------------------------------------------===//
594 
595 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
596                                       OperationState &result) {
597   // Parse the body region, and reuse the operand info as the argument info.
598   Region *body = result.addRegion();
599   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
600 }
601 
602 static void print(OpAsmPrinter &p, AffineScopeOp op) {
603   p << "test.affine_scope ";
604   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
605 }
606 
607 //===----------------------------------------------------------------------===//
608 // Test parser.
609 //===----------------------------------------------------------------------===//
610 
611 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
612                                               OperationState &result) {
613   if (parser.parseOptionalColon())
614     return success();
615   uint64_t numResults;
616   if (parser.parseInteger(numResults))
617     return failure();
618 
619   IndexType type = parser.getBuilder().getIndexType();
620   for (unsigned i = 0; i < numResults; ++i)
621     result.addTypes(type);
622   return success();
623 }
624 
625 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
626   p << ParseIntegerLiteralOp::getOperationName();
627   if (unsigned numResults = op->getNumResults())
628     p << " : " << numResults;
629 }
630 
631 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
632                                               OperationState &result) {
633   StringRef keyword;
634   if (parser.parseKeyword(&keyword))
635     return failure();
636   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
637   return success();
638 }
639 
640 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
641   p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword();
642 }
643 
644 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
646 
647 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
648                                          OperationState &result) {
649   if (parser.parseKeyword("wraps"))
650     return failure();
651 
652   // Parse the wrapped op in a region
653   Region &body = *result.addRegion();
654   body.push_back(new Block);
655   Block &block = body.back();
656   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
657   if (!wrapped_op)
658     return failure();
659 
660   // Create a return terminator in the inner region, pass as operand to the
661   // terminator the returned values from the wrapped operation.
662   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
663   OpBuilder builder(parser.getBuilder().getContext());
664   builder.setInsertionPointToEnd(&block);
665   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
666 
667   // Get the results type for the wrapping op from the terminator operands.
668   Operation &return_op = body.back().back();
669   result.types.append(return_op.operand_type_begin(),
670                       return_op.operand_type_end());
671 
672   // Use the location of the wrapped op for the "test.wrapping_region" op.
673   result.location = wrapped_op->getLoc();
674 
675   return success();
676 }
677 
678 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
679   p << op.getOperationName() << " wraps ";
680   p.printGenericOp(&op.region().front().front());
681 }
682 
683 //===----------------------------------------------------------------------===//
684 // Test PolyForOp - parse list of region arguments.
685 //===----------------------------------------------------------------------===//
686 
687 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
688   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
689   // Parse list of region arguments without a delimiter.
690   if (parser.parseRegionArgumentList(ivsInfo))
691     return failure();
692 
693   // Parse the body region.
694   Region *body = result.addRegion();
695   auto &builder = parser.getBuilder();
696   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
697   return parser.parseRegion(*body, ivsInfo, argTypes);
698 }
699 
700 //===----------------------------------------------------------------------===//
701 // Test removing op with inner ops.
702 //===----------------------------------------------------------------------===//
703 
704 namespace {
705 struct TestRemoveOpWithInnerOps
706     : public OpRewritePattern<TestOpWithRegionPattern> {
707   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
708 
709   void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
710 
711   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
712                                 PatternRewriter &rewriter) const override {
713     rewriter.eraseOp(op);
714     return success();
715   }
716 };
717 } // end anonymous namespace
718 
719 void TestOpWithRegionPattern::getCanonicalizationPatterns(
720     RewritePatternSet &results, MLIRContext *context) {
721   results.add<TestRemoveOpWithInnerOps>(context);
722 }
723 
724 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
725   return operand();
726 }
727 
728 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
729   return getValue();
730 }
731 
732 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
733     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
734   for (Value input : this->operands()) {
735     results.push_back(input);
736   }
737   return success();
738 }
739 
740 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
741   assert(operands.size() == 1);
742   if (operands.front()) {
743     (*this)->setAttr("attr", operands.front());
744     return getResult();
745   }
746   return {};
747 }
748 
749 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
750   return getOperand();
751 }
752 
753 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
754     MLIRContext *, Optional<Location> location, ValueRange operands,
755     DictionaryAttr attributes, RegionRange regions,
756     SmallVectorImpl<Type> &inferredReturnTypes) {
757   if (operands[0].getType() != operands[1].getType()) {
758     return emitOptionalError(location, "operand type mismatch ",
759                              operands[0].getType(), " vs ",
760                              operands[1].getType());
761   }
762   inferredReturnTypes.assign({operands[0].getType()});
763   return success();
764 }
765 
766 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
767     MLIRContext *context, Optional<Location> location, ValueRange operands,
768     DictionaryAttr attributes, RegionRange regions,
769     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
770   // Create return type consisting of the last element of the first operand.
771   auto operandType = *operands.getTypes().begin();
772   auto sval = operandType.dyn_cast<ShapedType>();
773   if (!sval) {
774     return emitOptionalError(location, "only shaped type operands allowed");
775   }
776   int64_t dim =
777       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
778   auto type = IntegerType::get(context, 17);
779   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
780   return success();
781 }
782 
783 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
784     OpBuilder &builder, ValueRange operands,
785     llvm::SmallVectorImpl<Value> &shapes) {
786   shapes = SmallVector<Value, 1>{
787       builder.createOrFold<memref::DimOp>(getLoc(), operands.front(), 0)};
788   return success();
789 }
790 
791 LogicalResult
792 OpWithResultShapePerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim(
793     OpBuilder &builder,
794     llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
795   SmallVector<Value> operand1Shape, operand2Shape;
796   Location loc = getLoc();
797   for (auto i :
798        llvm::seq<int>(0, operand1().getType().cast<ShapedType>().getRank())) {
799     operand1Shape.push_back(builder.create<memref::DimOp>(loc, operand1(), i));
800   }
801   for (auto i :
802        llvm::seq<int>(0, operand2().getType().cast<ShapedType>().getRank())) {
803     operand2Shape.push_back(builder.create<memref::DimOp>(loc, operand2(), i));
804   }
805   shapes.emplace_back(std::move(operand2Shape));
806   shapes.emplace_back(std::move(operand1Shape));
807   return success();
808 }
809 
810 //===----------------------------------------------------------------------===//
811 // Test SideEffect interfaces
812 //===----------------------------------------------------------------------===//
813 
814 namespace {
815 /// A test resource for side effects.
816 struct TestResource : public SideEffects::Resource::Base<TestResource> {
817   StringRef getName() final { return "<Test>"; }
818 };
819 } // end anonymous namespace
820 
821 static void testSideEffectOpGetEffect(
822     Operation *op,
823     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
824         &effects) {
825   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
826   if (!effectsAttr)
827     return;
828 
829   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
830 }
831 
832 void SideEffectOp::getEffects(
833     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
834   // Check for an effects attribute on the op instance.
835   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
836   if (!effectsAttr)
837     return;
838 
839   // If there is one, it is an array of dictionary attributes that hold
840   // information on the effects of this operation.
841   for (Attribute element : effectsAttr) {
842     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
843 
844     // Get the specific memory effect.
845     MemoryEffects::Effect *effect =
846         StringSwitch<MemoryEffects::Effect *>(
847             effectElement.get("effect").cast<StringAttr>().getValue())
848             .Case("allocate", MemoryEffects::Allocate::get())
849             .Case("free", MemoryEffects::Free::get())
850             .Case("read", MemoryEffects::Read::get())
851             .Case("write", MemoryEffects::Write::get());
852 
853     // Check for a non-default resource to use.
854     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
855     if (effectElement.get("test_resource"))
856       resource = TestResource::get();
857 
858     // Check for a result to affect.
859     if (effectElement.get("on_result"))
860       effects.emplace_back(effect, getResult(), resource);
861     else if (Attribute ref = effectElement.get("on_reference"))
862       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
863     else
864       effects.emplace_back(effect, resource);
865   }
866 }
867 
868 void SideEffectOp::getEffects(
869     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
870   testSideEffectOpGetEffect(getOperation(), effects);
871 }
872 
873 //===----------------------------------------------------------------------===//
874 // StringAttrPrettyNameOp
875 //===----------------------------------------------------------------------===//
876 
877 // This op has fancy handling of its SSA result name.
878 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
879                                                OperationState &result) {
880   // Add the result types.
881   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
882     result.addTypes(parser.getBuilder().getIntegerType(32));
883 
884   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
885     return failure();
886 
887   // If the attribute dictionary contains no 'names' attribute, infer it from
888   // the SSA name (if specified).
889   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
890     return attr.first == "names";
891   });
892 
893   // If there was no name specified, check to see if there was a useful name
894   // specified in the asm file.
895   if (hadNames || parser.getNumResults() == 0)
896     return success();
897 
898   SmallVector<StringRef, 4> names;
899   auto *context = result.getContext();
900 
901   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
902     auto resultName = parser.getResultName(i);
903     StringRef nameStr;
904     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
905       nameStr = resultName.first;
906 
907     names.push_back(nameStr);
908   }
909 
910   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
911   result.attributes.push_back({Identifier::get("names", context), namesAttr});
912   return success();
913 }
914 
915 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
916   p << "test.string_attr_pretty_name";
917 
918   // Note that we only need to print the "name" attribute if the asmprinter
919   // result name disagrees with it.  This can happen in strange cases, e.g.
920   // when there are conflicts.
921   bool namesDisagree = op.names().size() != op.getNumResults();
922 
923   SmallString<32> resultNameStr;
924   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
925     resultNameStr.clear();
926     llvm::raw_svector_ostream tmpStream(resultNameStr);
927     p.printOperand(op.getResult(i), tmpStream);
928 
929     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
930     if (!expectedName ||
931         tmpStream.str().drop_front() != expectedName.getValue()) {
932       namesDisagree = true;
933     }
934   }
935 
936   if (namesDisagree)
937     p.printOptionalAttrDictWithKeyword(op->getAttrs());
938   else
939     p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
940 }
941 
// Each result's SSA name in the asm syntax is taken from the corresponding
// entry of the 'names' attribute.
944 void StringAttrPrettyNameOp::getAsmResultNames(
945     function_ref<void(Value, StringRef)> setNameFn) {
946 
947   auto value = names();
948   for (size_t i = 0, e = value.size(); i != e; ++i)
949     if (auto str = value[i].dyn_cast<StringAttr>())
950       if (!str.getValue().empty())
951         setNameFn(getResult(i), str.getValue());
952 }
953 
954 //===----------------------------------------------------------------------===//
955 // RegionIfOp
956 //===----------------------------------------------------------------------===//
957 
958 static void print(OpAsmPrinter &p, RegionIfOp op) {
959   p << RegionIfOp::getOperationName() << " ";
960   p.printOperands(op.getOperands());
961   p << ": " << op.getOperandTypes();
962   p.printArrowTypeList(op.getResultTypes());
963   p << " then";
964   p.printRegion(op.thenRegion(),
965                 /*printEntryBlockArgs=*/true,
966                 /*printBlockTerminators=*/true);
967   p << " else";
968   p.printRegion(op.elseRegion(),
969                 /*printEntryBlockArgs=*/true,
970                 /*printBlockTerminators=*/true);
971   p << " join";
972   p.printRegion(op.joinRegion(),
973                 /*printEntryBlockArgs=*/true,
974                 /*printBlockTerminators=*/true);
975 }
976 
977 static ParseResult parseRegionIfOp(OpAsmParser &parser,
978                                    OperationState &result) {
979   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
980   SmallVector<Type, 2> operandTypes;
981 
982   result.regions.reserve(3);
983   Region *thenRegion = result.addRegion();
984   Region *elseRegion = result.addRegion();
985   Region *joinRegion = result.addRegion();
986 
987   // Parse operand, type and arrow type lists.
988   if (parser.parseOperandList(operandInfos) ||
989       parser.parseColonTypeList(operandTypes) ||
990       parser.parseArrowTypeList(result.types))
991     return failure();
992 
993   // Parse all attached regions.
994   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
995       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
996       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
997     return failure();
998 
999   return parser.resolveOperands(operandInfos, operandTypes,
1000                                 parser.getCurrentLocation(), result.operands);
1001 }
1002 
1003 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
1004   assert(index < 2 && "invalid region index");
1005   return getOperands();
1006 }
1007 
1008 void RegionIfOp::getSuccessorRegions(
1009     Optional<unsigned> index, ArrayRef<Attribute> operands,
1010     SmallVectorImpl<RegionSuccessor> &regions) {
1011   // We always branch to the join region.
1012   if (index.hasValue()) {
1013     if (index.getValue() < 2)
1014       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
1015     else
1016       regions.push_back(RegionSuccessor(getResults()));
1017     return;
1018   }
1019 
1020   // The then and else regions are the entry regions of this op.
1021   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
1022   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
1023 }
1024 
1025 //===----------------------------------------------------------------------===//
1026 // SingleNoTerminatorCustomAsmOp
1027 //===----------------------------------------------------------------------===//
1028 
1029 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser,
1030                                                       OperationState &state) {
1031   Region *body = state.addRegion();
1032   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1033     return failure();
1034   return success();
1035 }
1036 
1037 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) {
1038   printer << op.getOperationName();
1039   printer.printRegion(
1040       op.getRegion(), /*printEntryBlockArgs=*/false,
1041       // This op has a single block without terminators. But explicitly mark
1042       // as not printing block terminators for testing.
1043       /*printBlockTerminators=*/false);
1044 }
1045 
1046 #include "TestOpEnums.cpp.inc"
1047 #include "TestOpInterfaces.cpp.inc"
1048 #include "TestOpStructs.cpp.inc"
1049 #include "TestTypeInterfaces.cpp.inc"
1050 
1051 #define GET_OP_CLASSES
1052 #include "TestOps.cpp.inc"
1053