1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestTypes.h"
12 #include "mlir/Dialect/DLTI/DLTI.h"
13 #include "mlir/Dialect/MemRef/IR/MemRef.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/Transforms/FoldUtils.h"
20 #include "mlir/Transforms/InliningUtils.h"
21 #include "llvm/ADT/StringSwitch.h"
22 
23 using namespace mlir;
24 using namespace mlir::test;
25 
26 void mlir::test::registerTestDialect(DialectRegistry &registry) {
27   registry.insert<TestDialect>();
28 }
29 
30 //===----------------------------------------------------------------------===//
31 // TestDialect Interfaces
32 //===----------------------------------------------------------------------===//
33 
34 namespace {
35 
36 // Test support for interacting with the AsmPrinter.
37 struct TestOpAsmInterface : public OpAsmDialectInterface {
38   using OpAsmDialectInterface::OpAsmDialectInterface;
39 
40   LogicalResult getAlias(Attribute attr, raw_ostream &os) const final {
41     StringAttr strAttr = attr.dyn_cast<StringAttr>();
42     if (!strAttr)
43       return failure();
44 
45     // Check the contents of the string attribute to see what the test alias
46     // should be named.
47     Optional<StringRef> aliasName =
48         StringSwitch<Optional<StringRef>>(strAttr.getValue())
49             .Case("alias_test:dot_in_name", StringRef("test.alias"))
50             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
51             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
52             .Case("alias_test:sanitize_conflict_a",
53                   StringRef("test_alias_conflict0"))
54             .Case("alias_test:sanitize_conflict_b",
55                   StringRef("test_alias_conflict0_"))
56             .Default(llvm::None);
57     if (!aliasName)
58       return failure();
59 
60     os << *aliasName;
61     return success();
62   }
63 
64   void getAsmResultNames(Operation *op,
65                          OpAsmSetValueNameFn setNameFn) const final {
66     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
67       setNameFn(asmOp, "result");
68   }
69 
70   void getAsmBlockArgumentNames(Block *block,
71                                 OpAsmSetValueNameFn setNameFn) const final {
72     auto op = block->getParentOp();
73     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
74     if (!arrayAttr)
75       return;
76     auto args = block->getArguments();
77     auto e = std::min(arrayAttr.size(), args.size());
78     for (unsigned i = 0; i < e; ++i) {
79       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
80         setNameFn(args[i], strAttr.getValue());
81     }
82   }
83 };
84 
85 struct TestDialectFoldInterface : public DialectFoldInterface {
86   using DialectFoldInterface::DialectFoldInterface;
87 
88   /// Registered hook to check if the given region, which is attached to an
89   /// operation that is *not* isolated from above, should be used when
90   /// materializing constants.
91   bool shouldMaterializeInto(Region *region) const final {
92     // If this is a one region operation, then insert into it.
93     return isa<OneRegionOp>(region->getParentOp());
94   }
95 };
96 
97 /// This class defines the interface for handling inlining with standard
98 /// operations.
99 struct TestInlinerInterface : public DialectInlinerInterface {
100   using DialectInlinerInterface::DialectInlinerInterface;
101 
102   //===--------------------------------------------------------------------===//
103   // Analysis Hooks
104   //===--------------------------------------------------------------------===//
105 
106   bool isLegalToInline(Operation *call, Operation *callable,
107                        bool wouldBeCloned) const final {
108     // Don't allow inlining calls that are marked `noinline`.
109     return !call->hasAttr("noinline");
110   }
111   bool isLegalToInline(Region *, Region *, bool,
112                        BlockAndValueMapping &) const final {
113     // Inlining into test dialect regions is legal.
114     return true;
115   }
116   bool isLegalToInline(Operation *, Region *, bool,
117                        BlockAndValueMapping &) const final {
118     return true;
119   }
120 
121   bool shouldAnalyzeRecursively(Operation *op) const final {
122     // Analyze recursively if this is not a functional region operation, it
123     // froms a separate functional scope.
124     return !isa<FunctionalRegionOp>(op);
125   }
126 
127   //===--------------------------------------------------------------------===//
128   // Transformation Hooks
129   //===--------------------------------------------------------------------===//
130 
131   /// Handle the given inlined terminator by replacing it with a new operation
132   /// as necessary.
133   void handleTerminator(Operation *op,
134                         ArrayRef<Value> valuesToRepl) const final {
135     // Only handle "test.return" here.
136     auto returnOp = dyn_cast<TestReturnOp>(op);
137     if (!returnOp)
138       return;
139 
140     // Replace the values directly with the return operands.
141     assert(returnOp.getNumOperands() == valuesToRepl.size());
142     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
143       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
144   }
145 
146   /// Attempt to materialize a conversion for a type mismatch between a call
147   /// from this dialect, and a callable region. This method should generate an
148   /// operation that takes 'input' as the only operand, and produces a single
149   /// result of 'resultType'. If a conversion can not be generated, nullptr
150   /// should be returned.
151   Operation *materializeCallConversion(OpBuilder &builder, Value input,
152                                        Type resultType,
153                                        Location conversionLoc) const final {
154     // Only allow conversion for i16/i32 types.
155     if (!(resultType.isSignlessInteger(16) ||
156           resultType.isSignlessInteger(32)) ||
157         !(input.getType().isSignlessInteger(16) ||
158           input.getType().isSignlessInteger(32)))
159       return nullptr;
160     return builder.create<TestCastOp>(conversionLoc, resultType, input);
161   }
162 };
163 } // end anonymous namespace
164 
165 //===----------------------------------------------------------------------===//
166 // TestDialect
167 //===----------------------------------------------------------------------===//
168 
169 static void testSideEffectOpGetEffect(
170     Operation *op,
171     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
172 
173 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
174 struct TestOpEffectInterfaceFallback
175     : public TestEffectOpInterface::FallbackModel<
176           TestOpEffectInterfaceFallback> {
177   static bool classof(Operation *op) {
178     bool isSupportedOp =
179         op->getName().getStringRef() == "test.unregistered_side_effect_op";
180     assert(isSupportedOp && "Unexpected dispatch");
181     return isSupportedOp;
182   }
183 
184   void
185   getEffects(Operation *op,
186              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
187                  &effects) const {
188     testSideEffectOpGetEffect(op, effects);
189   }
190 };
191 
/// Populates the test dialect with:
///   - its attributes and types;
///   - the generated operation list;
///   - the dialect interfaces defined earlier in this file;
///   - the fallback `TestEffectOpInterface` model for one unregistered op.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  // The op list is expanded from the generated TestOps.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface>();
  // Permit unregistered ops such as `test.unregistered_side_effect_op`.
  allowUnknownOperations();

  // Instantiate the fallback op interface model that serves one specific
  // unregistered op; it is deleted in ~TestDialect().
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
207 TestDialect::~TestDialect() {
208   delete static_cast<TestOpEffectInterfaceFallback *>(
209       fallbackEffectOpInterfaces);
210 }
211 
212 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
213                                             Type type, Location loc) {
214   return builder.create<TestOpConstant>(loc, type, value);
215 }
216 
217 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
218                                                OperationName opName) {
219   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
220       typeID == TypeID::get<TestEffectOpInterface>())
221     return fallbackEffectOpInterfaces;
222   return nullptr;
223 }
224 
225 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
226                                                     NamedAttribute namedAttr) {
227   if (namedAttr.first == "test.invalid_attr")
228     return op->emitError() << "invalid to use 'test.invalid_attr'";
229   return success();
230 }
231 
232 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
233                                                     unsigned regionIndex,
234                                                     unsigned argIndex,
235                                                     NamedAttribute namedAttr) {
236   if (namedAttr.first == "test.invalid_attr")
237     return op->emitError() << "invalid to use 'test.invalid_attr'";
238   return success();
239 }
240 
241 LogicalResult
242 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
243                                          unsigned resultIndex,
244                                          NamedAttribute namedAttr) {
245   if (namedAttr.first == "test.invalid_attr")
246     return op->emitError() << "invalid to use 'test.invalid_attr'";
247   return success();
248 }
249 
250 Optional<Dialect::ParseOpHook>
251 TestDialect::getParseOperationHook(StringRef opName) const {
252   if (opName == "test.dialect_custom_printer") {
253     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
254       return parser.parseKeyword("custom_format");
255     }};
256   }
257   return None;
258 }
259 
260 LogicalResult TestDialect::printOperation(Operation *op,
261                                           OpAsmPrinter &printer) const {
262   StringRef opName = op->getName().getStringRef();
263   if (opName == "test.dialect_custom_printer") {
264     printer.getStream() << opName << " custom_format";
265     return success();
266   }
267   return failure();
268 }
269 
270 //===----------------------------------------------------------------------===//
271 // TestBranchOp
272 //===----------------------------------------------------------------------===//
273 
274 Optional<MutableOperandRange>
275 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
276   assert(index == 0 && "invalid successor index");
277   return targetOperandsMutable();
278 }
279 
280 //===----------------------------------------------------------------------===//
281 // TestFoldToCallOp
282 //===----------------------------------------------------------------------===//
283 
284 namespace {
285 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
286   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
287 
288   LogicalResult matchAndRewrite(FoldToCallOp op,
289                                 PatternRewriter &rewriter) const override {
290     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
291                                         ValueRange());
292     return success();
293   }
294 };
295 } // end anonymous namespace
296 
297 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
298                                                MLIRContext *context) {
299   results.add<FoldToCallOpPattern>(context);
300 }
301 
302 //===----------------------------------------------------------------------===//
303 // Test Format* operations
304 //===----------------------------------------------------------------------===//
305 
306 //===----------------------------------------------------------------------===//
307 // Parsing
308 
309 static ParseResult parseCustomDirectiveOperands(
310     OpAsmParser &parser, OpAsmParser::OperandType &operand,
311     Optional<OpAsmParser::OperandType> &optOperand,
312     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
313   if (parser.parseOperand(operand))
314     return failure();
315   if (succeeded(parser.parseOptionalComma())) {
316     optOperand.emplace();
317     if (parser.parseOperand(*optOperand))
318       return failure();
319   }
320   if (parser.parseArrow() || parser.parseLParen() ||
321       parser.parseOperandList(varOperands) || parser.parseRParen())
322     return failure();
323   return success();
324 }
325 static ParseResult
326 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
327                             Type &optOperandType,
328                             SmallVectorImpl<Type> &varOperandTypes) {
329   if (parser.parseColon())
330     return failure();
331 
332   if (parser.parseType(operandType))
333     return failure();
334   if (succeeded(parser.parseOptionalComma())) {
335     if (parser.parseType(optOperandType))
336       return failure();
337   }
338   if (parser.parseArrow() || parser.parseLParen() ||
339       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
340     return failure();
341   return success();
342 }
343 static ParseResult
344 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
345                                  Type optOperandType,
346                                  const SmallVectorImpl<Type> &varOperandTypes) {
347   if (parser.parseKeyword("type_refs_capture"))
348     return failure();
349 
350   Type operandType2, optOperandType2;
351   SmallVector<Type, 1> varOperandTypes2;
352   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
353                                   varOperandTypes2))
354     return failure();
355 
356   if (operandType != operandType2 || optOperandType != optOperandType2 ||
357       varOperandTypes != varOperandTypes2)
358     return failure();
359 
360   return success();
361 }
362 static ParseResult parseCustomDirectiveOperandsAndTypes(
363     OpAsmParser &parser, OpAsmParser::OperandType &operand,
364     Optional<OpAsmParser::OperandType> &optOperand,
365     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
366     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
367   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
368       parseCustomDirectiveResults(parser, operandType, optOperandType,
369                                   varOperandTypes))
370     return failure();
371   return success();
372 }
373 static ParseResult parseCustomDirectiveRegions(
374     OpAsmParser &parser, Region &region,
375     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
376   if (parser.parseRegion(region))
377     return failure();
378   if (failed(parser.parseOptionalComma()))
379     return success();
380   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
381   if (parser.parseRegion(*varRegion))
382     return failure();
383   varRegions.emplace_back(std::move(varRegion));
384   return success();
385 }
386 static ParseResult
387 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
388                                SmallVectorImpl<Block *> &varSuccessors) {
389   if (parser.parseSuccessor(successor))
390     return failure();
391   if (failed(parser.parseOptionalComma()))
392     return success();
393   Block *varSuccessor;
394   if (parser.parseSuccessor(varSuccessor))
395     return failure();
396   varSuccessors.append(2, varSuccessor);
397   return success();
398 }
399 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
400                                                   IntegerAttr &attr,
401                                                   IntegerAttr &optAttr) {
402   if (parser.parseAttribute(attr))
403     return failure();
404   if (succeeded(parser.parseOptionalComma())) {
405     if (parser.parseAttribute(optAttr))
406       return failure();
407   }
408   return success();
409 }
410 
411 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
412                                                 NamedAttrList &attrs) {
413   return parser.parseOptionalAttrDict(attrs);
414 }
415 static ParseResult parseCustomDirectiveOptionalOperandRef(
416     OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
417   int64_t operandCount = 0;
418   if (parser.parseInteger(operandCount))
419     return failure();
420   bool expectedOptionalOperand = operandCount == 0;
421   return success(expectedOptionalOperand != optOperand.hasValue());
422 }
423 
424 //===----------------------------------------------------------------------===//
425 // Printing
426 
427 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
428                                          Value operand, Value optOperand,
429                                          OperandRange varOperands) {
430   printer << operand;
431   if (optOperand)
432     printer << ", " << optOperand;
433   printer << " -> (" << varOperands << ")";
434 }
435 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
436                                         Type operandType, Type optOperandType,
437                                         TypeRange varOperandTypes) {
438   printer << " : " << operandType;
439   if (optOperandType)
440     printer << ", " << optOperandType;
441   printer << " -> (" << varOperandTypes << ")";
442 }
443 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
444                                              Operation *op, Type operandType,
445                                              Type optOperandType,
446                                              TypeRange varOperandTypes) {
447   printer << " type_refs_capture ";
448   printCustomDirectiveResults(printer, op, operandType, optOperandType,
449                               varOperandTypes);
450 }
451 static void printCustomDirectiveOperandsAndTypes(
452     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
453     OperandRange varOperands, Type operandType, Type optOperandType,
454     TypeRange varOperandTypes) {
455   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
456   printCustomDirectiveResults(printer, op, operandType, optOperandType,
457                               varOperandTypes);
458 }
459 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
460                                         Region &region,
461                                         MutableArrayRef<Region> varRegions) {
462   printer.printRegion(region);
463   if (!varRegions.empty()) {
464     printer << ", ";
465     for (Region &region : varRegions)
466       printer.printRegion(region);
467   }
468 }
469 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
470                                            Block *successor,
471                                            SuccessorRange varSuccessors) {
472   printer << successor;
473   if (!varSuccessors.empty())
474     printer << ", " << varSuccessors.front();
475 }
476 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
477                                            Attribute attribute,
478                                            Attribute optAttribute) {
479   printer << attribute;
480   if (optAttribute)
481     printer << ", " << optAttribute;
482 }
483 
484 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
485                                          DictionaryAttr attrs) {
486   printer.printOptionalAttrDict(attrs.getValue());
487 }
488 
489 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
490                                                    Operation *op,
491                                                    Value optOperand) {
492   printer << (optOperand ? "1" : "0");
493 }
494 
495 //===----------------------------------------------------------------------===//
496 // Test IsolatedRegionOp - parse passthrough region arguments.
497 //===----------------------------------------------------------------------===//
498 
499 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
500                                          OperationState &result) {
501   OpAsmParser::OperandType argInfo;
502   Type argType = parser.getBuilder().getIndexType();
503 
504   // Parse the input operand.
505   if (parser.parseOperand(argInfo) ||
506       parser.resolveOperand(argInfo, argType, result.operands))
507     return failure();
508 
509   // Parse the body region, and reuse the operand info as the argument info.
510   Region *body = result.addRegion();
511   return parser.parseRegion(*body, argInfo, argType,
512                             /*enableNameShadowing=*/true);
513 }
514 
515 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
516   p << "test.isolated_region ";
517   p.printOperand(op.getOperand());
518   p.shadowRegionArgs(op.region(), op.getOperand());
519   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
520 }
521 
522 //===----------------------------------------------------------------------===//
523 // Test SSACFGRegionOp
524 //===----------------------------------------------------------------------===//
525 
526 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
527   return RegionKind::SSACFG;
528 }
529 
530 //===----------------------------------------------------------------------===//
531 // Test GraphRegionOp
532 //===----------------------------------------------------------------------===//
533 
534 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
535                                       OperationState &result) {
536   // Parse the body region, and reuse the operand info as the argument info.
537   Region *body = result.addRegion();
538   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
539 }
540 
541 static void print(OpAsmPrinter &p, GraphRegionOp op) {
542   p << "test.graph_region ";
543   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
544 }
545 
546 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
547   return RegionKind::Graph;
548 }
549 
550 //===----------------------------------------------------------------------===//
551 // Test AffineScopeOp
552 //===----------------------------------------------------------------------===//
553 
554 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
555                                       OperationState &result) {
556   // Parse the body region, and reuse the operand info as the argument info.
557   Region *body = result.addRegion();
558   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
559 }
560 
561 static void print(OpAsmPrinter &p, AffineScopeOp op) {
562   p << "test.affine_scope ";
563   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
564 }
565 
566 //===----------------------------------------------------------------------===//
567 // Test parser.
568 //===----------------------------------------------------------------------===//
569 
570 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
571                                               OperationState &result) {
572   if (parser.parseOptionalColon())
573     return success();
574   uint64_t numResults;
575   if (parser.parseInteger(numResults))
576     return failure();
577 
578   IndexType type = parser.getBuilder().getIndexType();
579   for (unsigned i = 0; i < numResults; ++i)
580     result.addTypes(type);
581   return success();
582 }
583 
584 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
585   p << ParseIntegerLiteralOp::getOperationName();
586   if (unsigned numResults = op->getNumResults())
587     p << " : " << numResults;
588 }
589 
590 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
591                                               OperationState &result) {
592   StringRef keyword;
593   if (parser.parseKeyword(&keyword))
594     return failure();
595   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
596   return success();
597 }
598 
599 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
600   p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword();
601 }
602 
603 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
605 
606 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
607                                          OperationState &result) {
608   if (parser.parseKeyword("wraps"))
609     return failure();
610 
611   // Parse the wrapped op in a region
612   Region &body = *result.addRegion();
613   body.push_back(new Block);
614   Block &block = body.back();
615   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
616   if (!wrapped_op)
617     return failure();
618 
619   // Create a return terminator in the inner region, pass as operand to the
620   // terminator the returned values from the wrapped operation.
621   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
622   OpBuilder builder(parser.getBuilder().getContext());
623   builder.setInsertionPointToEnd(&block);
624   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
625 
626   // Get the results type for the wrapping op from the terminator operands.
627   Operation &return_op = body.back().back();
628   result.types.append(return_op.operand_type_begin(),
629                       return_op.operand_type_end());
630 
631   // Use the location of the wrapped op for the "test.wrapping_region" op.
632   result.location = wrapped_op->getLoc();
633 
634   return success();
635 }
636 
637 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
638   p << op.getOperationName() << " wraps ";
639   p.printGenericOp(&op.region().front().front());
640 }
641 
642 //===----------------------------------------------------------------------===//
643 // Test PolyForOp - parse list of region arguments.
644 //===----------------------------------------------------------------------===//
645 
646 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
647   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
648   // Parse list of region arguments without a delimiter.
649   if (parser.parseRegionArgumentList(ivsInfo))
650     return failure();
651 
652   // Parse the body region.
653   Region *body = result.addRegion();
654   auto &builder = parser.getBuilder();
655   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
656   return parser.parseRegion(*body, ivsInfo, argTypes);
657 }
658 
659 //===----------------------------------------------------------------------===//
660 // Test removing op with inner ops.
661 //===----------------------------------------------------------------------===//
662 
663 namespace {
664 struct TestRemoveOpWithInnerOps
665     : public OpRewritePattern<TestOpWithRegionPattern> {
666   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
667 
668   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
669                                 PatternRewriter &rewriter) const override {
670     rewriter.eraseOp(op);
671     return success();
672   }
673 };
674 } // end anonymous namespace
675 
676 void TestOpWithRegionPattern::getCanonicalizationPatterns(
677     RewritePatternSet &results, MLIRContext *context) {
678   results.add<TestRemoveOpWithInnerOps>(context);
679 }
680 
681 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
682   return operand();
683 }
684 
685 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
686   return getValue();
687 }
688 
689 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
690     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
691   for (Value input : this->operands()) {
692     results.push_back(input);
693   }
694   return success();
695 }
696 
697 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
698   assert(operands.size() == 1);
699   if (operands.front()) {
700     (*this)->setAttr("attr", operands.front());
701     return getResult();
702   }
703   return {};
704 }
705 
706 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
707   return getOperand();
708 }
709 
710 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
711     MLIRContext *, Optional<Location> location, ValueRange operands,
712     DictionaryAttr attributes, RegionRange regions,
713     SmallVectorImpl<Type> &inferredReturnTypes) {
714   if (operands[0].getType() != operands[1].getType()) {
715     return emitOptionalError(location, "operand type mismatch ",
716                              operands[0].getType(), " vs ",
717                              operands[1].getType());
718   }
719   inferredReturnTypes.assign({operands[0].getType()});
720   return success();
721 }
722 
723 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
724     MLIRContext *context, Optional<Location> location, ValueRange operands,
725     DictionaryAttr attributes, RegionRange regions,
726     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
727   // Create return type consisting of the last element of the first operand.
728   auto operandType = *operands.getTypes().begin();
729   auto sval = operandType.dyn_cast<ShapedType>();
730   if (!sval) {
731     return emitOptionalError(location, "only shaped type operands allowed");
732   }
733   int64_t dim =
734       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
735   auto type = IntegerType::get(context, 17);
736   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
737   return success();
738 }
739 
740 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
741     OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
742   shapes = SmallVector<Value, 1>{
743       builder.createOrFold<memref::DimOp>(getLoc(), getOperand(0), 0)};
744   return success();
745 }
746 
747 //===----------------------------------------------------------------------===//
748 // Test SideEffect interfaces
749 //===----------------------------------------------------------------------===//
750 
751 namespace {
752 /// A test resource for side effects.
753 struct TestResource : public SideEffects::Resource::Base<TestResource> {
754   StringRef getName() final { return "<Test>"; }
755 };
756 } // end anonymous namespace
757 
758 static void testSideEffectOpGetEffect(
759     Operation *op,
760     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
761         &effects) {
762   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
763   if (!effectsAttr)
764     return;
765 
766   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
767 }
768 
769 void SideEffectOp::getEffects(
770     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
771   // Check for an effects attribute on the op instance.
772   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
773   if (!effectsAttr)
774     return;
775 
776   // If there is one, it is an array of dictionary attributes that hold
777   // information on the effects of this operation.
778   for (Attribute element : effectsAttr) {
779     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
780 
781     // Get the specific memory effect.
782     MemoryEffects::Effect *effect =
783         StringSwitch<MemoryEffects::Effect *>(
784             effectElement.get("effect").cast<StringAttr>().getValue())
785             .Case("allocate", MemoryEffects::Allocate::get())
786             .Case("free", MemoryEffects::Free::get())
787             .Case("read", MemoryEffects::Read::get())
788             .Case("write", MemoryEffects::Write::get());
789 
790     // Check for a non-default resource to use.
791     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
792     if (effectElement.get("test_resource"))
793       resource = TestResource::get();
794 
795     // Check for a result to affect.
796     if (effectElement.get("on_result"))
797       effects.emplace_back(effect, getResult(), resource);
798     else if (Attribute ref = effectElement.get("on_reference"))
799       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
800     else
801       effects.emplace_back(effect, resource);
802   }
803 }
804 
805 void SideEffectOp::getEffects(
806     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
807   testSideEffectOpGetEffect(getOperation(), effects);
808 }
809 
810 //===----------------------------------------------------------------------===//
811 // StringAttrPrettyNameOp
812 //===----------------------------------------------------------------------===//
813 
814 // This op has fancy handling of its SSA result name.
815 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
816                                                OperationState &result) {
817   // Add the result types.
818   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
819     result.addTypes(parser.getBuilder().getIntegerType(32));
820 
821   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
822     return failure();
823 
824   // If the attribute dictionary contains no 'names' attribute, infer it from
825   // the SSA name (if specified).
826   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
827     return attr.first == "names";
828   });
829 
830   // If there was no name specified, check to see if there was a useful name
831   // specified in the asm file.
832   if (hadNames || parser.getNumResults() == 0)
833     return success();
834 
835   SmallVector<StringRef, 4> names;
836   auto *context = result.getContext();
837 
838   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
839     auto resultName = parser.getResultName(i);
840     StringRef nameStr;
841     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
842       nameStr = resultName.first;
843 
844     names.push_back(nameStr);
845   }
846 
847   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
848   result.attributes.push_back({Identifier::get("names", context), namesAttr});
849   return success();
850 }
851 
852 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
853   p << "test.string_attr_pretty_name";
854 
855   // Note that we only need to print the "name" attribute if the asmprinter
856   // result name disagrees with it.  This can happen in strange cases, e.g.
857   // when there are conflicts.
858   bool namesDisagree = op.names().size() != op.getNumResults();
859 
860   SmallString<32> resultNameStr;
861   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
862     resultNameStr.clear();
863     llvm::raw_svector_ostream tmpStream(resultNameStr);
864     p.printOperand(op.getResult(i), tmpStream);
865 
866     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
867     if (!expectedName ||
868         tmpStream.str().drop_front() != expectedName.getValue()) {
869       namesDisagree = true;
870     }
871   }
872 
873   if (namesDisagree)
874     p.printOptionalAttrDictWithKeyword(op->getAttrs());
875   else
876     p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
877 }
878 
879 // We set the SSA name in the asm syntax to the contents of the name
880 // attribute.
881 void StringAttrPrettyNameOp::getAsmResultNames(
882     function_ref<void(Value, StringRef)> setNameFn) {
883 
884   auto value = names();
885   for (size_t i = 0, e = value.size(); i != e; ++i)
886     if (auto str = value[i].dyn_cast<StringAttr>())
887       if (!str.getValue().empty())
888         setNameFn(getResult(i), str.getValue());
889 }
890 
891 //===----------------------------------------------------------------------===//
892 // RegionIfOp
893 //===----------------------------------------------------------------------===//
894 
895 static void print(OpAsmPrinter &p, RegionIfOp op) {
896   p << RegionIfOp::getOperationName() << " ";
897   p.printOperands(op.getOperands());
898   p << ": " << op.getOperandTypes();
899   p.printArrowTypeList(op.getResultTypes());
900   p << " then";
901   p.printRegion(op.thenRegion(),
902                 /*printEntryBlockArgs=*/true,
903                 /*printBlockTerminators=*/true);
904   p << " else";
905   p.printRegion(op.elseRegion(),
906                 /*printEntryBlockArgs=*/true,
907                 /*printBlockTerminators=*/true);
908   p << " join";
909   p.printRegion(op.joinRegion(),
910                 /*printEntryBlockArgs=*/true,
911                 /*printBlockTerminators=*/true);
912 }
913 
914 static ParseResult parseRegionIfOp(OpAsmParser &parser,
915                                    OperationState &result) {
916   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
917   SmallVector<Type, 2> operandTypes;
918 
919   result.regions.reserve(3);
920   Region *thenRegion = result.addRegion();
921   Region *elseRegion = result.addRegion();
922   Region *joinRegion = result.addRegion();
923 
924   // Parse operand, type and arrow type lists.
925   if (parser.parseOperandList(operandInfos) ||
926       parser.parseColonTypeList(operandTypes) ||
927       parser.parseArrowTypeList(result.types))
928     return failure();
929 
930   // Parse all attached regions.
931   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
932       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
933       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
934     return failure();
935 
936   return parser.resolveOperands(operandInfos, operandTypes,
937                                 parser.getCurrentLocation(), result.operands);
938 }
939 
940 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
941   assert(index < 2 && "invalid region index");
942   return getOperands();
943 }
944 
945 void RegionIfOp::getSuccessorRegions(
946     Optional<unsigned> index, ArrayRef<Attribute> operands,
947     SmallVectorImpl<RegionSuccessor> &regions) {
948   // We always branch to the join region.
949   if (index.hasValue()) {
950     if (index.getValue() < 2)
951       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
952     else
953       regions.push_back(RegionSuccessor(getResults()));
954     return;
955   }
956 
957   // The then and else regions are the entry regions of this op.
958   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
959   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
960 }
961 
962 #include "TestOpEnums.cpp.inc"
963 #include "TestOpInterfaces.cpp.inc"
964 #include "TestOpStructs.cpp.inc"
965 #include "TestTypeInterfaces.cpp.inc"
966 
967 #define GET_OP_CLASSES
968 #include "TestOps.cpp.inc"
969