1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestTypes.h"
12 #include "mlir/Dialect/DLTI/DLTI.h"
13 #include "mlir/Dialect/MemRef/IR/MemRef.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/Transforms/FoldUtils.h"
20 #include "mlir/Transforms/InliningUtils.h"
21 #include "llvm/ADT/StringSwitch.h"
22 
23 using namespace mlir;
24 using namespace mlir::test;
25 
26 void mlir::test::registerTestDialect(DialectRegistry &registry) {
27   registry.insert<TestDialect>();
28 }
29 
30 //===----------------------------------------------------------------------===//
31 // TestDialect Interfaces
32 //===----------------------------------------------------------------------===//
33 
34 namespace {
35 
36 /// Testing the correctness of some traits.
37 static_assert(
38     llvm::is_detected<OpTrait::has_implicit_terminator_t,
39                       SingleBlockImplicitTerminatorOp>::value,
40     "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
41 static_assert(OpTrait::hasSingleBlockImplicitTerminator<
42                   SingleBlockImplicitTerminatorOp>::value,
43               "hasSingleBlockImplicitTerminator does not match "
44               "SingleBlockImplicitTerminatorOp");
45 
46 // Test support for interacting with the AsmPrinter.
47 struct TestOpAsmInterface : public OpAsmDialectInterface {
48   using OpAsmDialectInterface::OpAsmDialectInterface;
49 
50   LogicalResult getAlias(Attribute attr, raw_ostream &os) const final {
51     StringAttr strAttr = attr.dyn_cast<StringAttr>();
52     if (!strAttr)
53       return failure();
54 
55     // Check the contents of the string attribute to see what the test alias
56     // should be named.
57     Optional<StringRef> aliasName =
58         StringSwitch<Optional<StringRef>>(strAttr.getValue())
59             .Case("alias_test:dot_in_name", StringRef("test.alias"))
60             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
61             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
62             .Case("alias_test:sanitize_conflict_a",
63                   StringRef("test_alias_conflict0"))
64             .Case("alias_test:sanitize_conflict_b",
65                   StringRef("test_alias_conflict0_"))
66             .Default(llvm::None);
67     if (!aliasName)
68       return failure();
69 
70     os << *aliasName;
71     return success();
72   }
73 
74   void getAsmResultNames(Operation *op,
75                          OpAsmSetValueNameFn setNameFn) const final {
76     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
77       setNameFn(asmOp, "result");
78   }
79 
80   void getAsmBlockArgumentNames(Block *block,
81                                 OpAsmSetValueNameFn setNameFn) const final {
82     auto op = block->getParentOp();
83     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
84     if (!arrayAttr)
85       return;
86     auto args = block->getArguments();
87     auto e = std::min(arrayAttr.size(), args.size());
88     for (unsigned i = 0; i < e; ++i) {
89       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
90         setNameFn(args[i], strAttr.getValue());
91     }
92   }
93 };
94 
95 struct TestDialectFoldInterface : public DialectFoldInterface {
96   using DialectFoldInterface::DialectFoldInterface;
97 
98   /// Registered hook to check if the given region, which is attached to an
99   /// operation that is *not* isolated from above, should be used when
100   /// materializing constants.
101   bool shouldMaterializeInto(Region *region) const final {
102     // If this is a one region operation, then insert into it.
103     return isa<OneRegionOp>(region->getParentOp());
104   }
105 };
106 
107 /// This class defines the interface for handling inlining with standard
108 /// operations.
109 struct TestInlinerInterface : public DialectInlinerInterface {
110   using DialectInlinerInterface::DialectInlinerInterface;
111 
112   //===--------------------------------------------------------------------===//
113   // Analysis Hooks
114   //===--------------------------------------------------------------------===//
115 
116   bool isLegalToInline(Operation *call, Operation *callable,
117                        bool wouldBeCloned) const final {
118     // Don't allow inlining calls that are marked `noinline`.
119     return !call->hasAttr("noinline");
120   }
121   bool isLegalToInline(Region *, Region *, bool,
122                        BlockAndValueMapping &) const final {
123     // Inlining into test dialect regions is legal.
124     return true;
125   }
126   bool isLegalToInline(Operation *, Region *, bool,
127                        BlockAndValueMapping &) const final {
128     return true;
129   }
130 
131   bool shouldAnalyzeRecursively(Operation *op) const final {
132     // Analyze recursively if this is not a functional region operation, it
133     // froms a separate functional scope.
134     return !isa<FunctionalRegionOp>(op);
135   }
136 
137   //===--------------------------------------------------------------------===//
138   // Transformation Hooks
139   //===--------------------------------------------------------------------===//
140 
141   /// Handle the given inlined terminator by replacing it with a new operation
142   /// as necessary.
143   void handleTerminator(Operation *op,
144                         ArrayRef<Value> valuesToRepl) const final {
145     // Only handle "test.return" here.
146     auto returnOp = dyn_cast<TestReturnOp>(op);
147     if (!returnOp)
148       return;
149 
150     // Replace the values directly with the return operands.
151     assert(returnOp.getNumOperands() == valuesToRepl.size());
152     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
153       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
154   }
155 
156   /// Attempt to materialize a conversion for a type mismatch between a call
157   /// from this dialect, and a callable region. This method should generate an
158   /// operation that takes 'input' as the only operand, and produces a single
159   /// result of 'resultType'. If a conversion can not be generated, nullptr
160   /// should be returned.
161   Operation *materializeCallConversion(OpBuilder &builder, Value input,
162                                        Type resultType,
163                                        Location conversionLoc) const final {
164     // Only allow conversion for i16/i32 types.
165     if (!(resultType.isSignlessInteger(16) ||
166           resultType.isSignlessInteger(32)) ||
167         !(input.getType().isSignlessInteger(16) ||
168           input.getType().isSignlessInteger(32)))
169       return nullptr;
170     return builder.create<TestCastOp>(conversionLoc, resultType, input);
171   }
172 };
173 } // end anonymous namespace
174 
175 //===----------------------------------------------------------------------===//
176 // TestDialect
177 //===----------------------------------------------------------------------===//
178 
// Forward declaration: TestOpEffectInterfaceFallback below forwards to this
// helper, which is defined later in the "Test SideEffect interfaces" section.
static void testSideEffectOpGetEffect(
    Operation *op,
    SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
182 
183 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
184 struct TestOpEffectInterfaceFallback
185     : public TestEffectOpInterface::FallbackModel<
186           TestOpEffectInterfaceFallback> {
187   static bool classof(Operation *op) {
188     bool isSupportedOp =
189         op->getName().getStringRef() == "test.unregistered_side_effect_op";
190     assert(isSupportedOp && "Unexpected dispatch");
191     return isSupportedOp;
192   }
193 
194   void
195   getEffects(Operation *op,
196              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
197                  &effects) const {
198     testSideEffectOpGetEffect(op, effects);
199   }
200 };
201 
/// Register everything the test dialect provides: attributes, types, the
/// generated op list from TestOps.cpp.inc, and the dialect interfaces defined
/// above. Unknown operations are allowed, and a fallback side-effect interface
/// model is allocated for the unregistered side-effect test op.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface>();
  allowUnknownOperations();

  // Instantiate our fallback op interface that we'll use on specific
  // unregistered op. The raw allocation is released in ~TestDialect().
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
217 TestDialect::~TestDialect() {
218   delete static_cast<TestOpEffectInterfaceFallback *>(
219       fallbackEffectOpInterfaces);
220 }
221 
222 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
223                                             Type type, Location loc) {
224   return builder.create<TestOpConstant>(loc, type, value);
225 }
226 
227 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
228                                                OperationName opName) {
229   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
230       typeID == TypeID::get<TestEffectOpInterface>())
231     return fallbackEffectOpInterfaces;
232   return nullptr;
233 }
234 
235 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
236                                                     NamedAttribute namedAttr) {
237   if (namedAttr.first == "test.invalid_attr")
238     return op->emitError() << "invalid to use 'test.invalid_attr'";
239   return success();
240 }
241 
242 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
243                                                     unsigned regionIndex,
244                                                     unsigned argIndex,
245                                                     NamedAttribute namedAttr) {
246   if (namedAttr.first == "test.invalid_attr")
247     return op->emitError() << "invalid to use 'test.invalid_attr'";
248   return success();
249 }
250 
251 LogicalResult
252 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
253                                          unsigned resultIndex,
254                                          NamedAttribute namedAttr) {
255   if (namedAttr.first == "test.invalid_attr")
256     return op->emitError() << "invalid to use 'test.invalid_attr'";
257   return success();
258 }
259 
260 Optional<Dialect::ParseOpHook>
261 TestDialect::getParseOperationHook(StringRef opName) const {
262   if (opName == "test.dialect_custom_printer") {
263     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
264       return parser.parseKeyword("custom_format");
265     }};
266   }
267   return None;
268 }
269 
270 LogicalResult TestDialect::printOperation(Operation *op,
271                                           OpAsmPrinter &printer) const {
272   StringRef opName = op->getName().getStringRef();
273   if (opName == "test.dialect_custom_printer") {
274     printer.getStream() << opName << " custom_format";
275     return success();
276   }
277   return failure();
278 }
279 
280 //===----------------------------------------------------------------------===//
281 // TestBranchOp
282 //===----------------------------------------------------------------------===//
283 
284 Optional<MutableOperandRange>
285 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
286   assert(index == 0 && "invalid successor index");
287   return targetOperandsMutable();
288 }
289 
290 //===----------------------------------------------------------------------===//
291 // TestDialectCanonicalizerOp
292 //===----------------------------------------------------------------------===//
293 
294 static LogicalResult
295 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
296                                PatternRewriter &rewriter) {
297   rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(),
298                                           rewriter.getI32IntegerAttr(42));
299   return success();
300 }
301 
302 void TestDialect::getCanonicalizationPatterns(
303     RewritePatternSet &results) const {
304   results.add(&dialectCanonicalizationPattern);
305 }
306 
307 //===----------------------------------------------------------------------===//
308 // TestFoldToCallOp
309 //===----------------------------------------------------------------------===//
310 
311 namespace {
312 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
313   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
314 
315   LogicalResult matchAndRewrite(FoldToCallOp op,
316                                 PatternRewriter &rewriter) const override {
317     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
318                                         ValueRange());
319     return success();
320   }
321 };
322 } // end anonymous namespace
323 
324 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
325                                                MLIRContext *context) {
326   results.add<FoldToCallOpPattern>(context);
327 }
328 
329 //===----------------------------------------------------------------------===//
330 // Test Format* operations
331 //===----------------------------------------------------------------------===//
332 
333 //===----------------------------------------------------------------------===//
334 // Parsing
335 
336 static ParseResult parseCustomDirectiveOperands(
337     OpAsmParser &parser, OpAsmParser::OperandType &operand,
338     Optional<OpAsmParser::OperandType> &optOperand,
339     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
340   if (parser.parseOperand(operand))
341     return failure();
342   if (succeeded(parser.parseOptionalComma())) {
343     optOperand.emplace();
344     if (parser.parseOperand(*optOperand))
345       return failure();
346   }
347   if (parser.parseArrow() || parser.parseLParen() ||
348       parser.parseOperandList(varOperands) || parser.parseRParen())
349     return failure();
350   return success();
351 }
352 static ParseResult
353 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
354                             Type &optOperandType,
355                             SmallVectorImpl<Type> &varOperandTypes) {
356   if (parser.parseColon())
357     return failure();
358 
359   if (parser.parseType(operandType))
360     return failure();
361   if (succeeded(parser.parseOptionalComma())) {
362     if (parser.parseType(optOperandType))
363       return failure();
364   }
365   if (parser.parseArrow() || parser.parseLParen() ||
366       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
367     return failure();
368   return success();
369 }
370 static ParseResult
371 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
372                                  Type optOperandType,
373                                  const SmallVectorImpl<Type> &varOperandTypes) {
374   if (parser.parseKeyword("type_refs_capture"))
375     return failure();
376 
377   Type operandType2, optOperandType2;
378   SmallVector<Type, 1> varOperandTypes2;
379   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
380                                   varOperandTypes2))
381     return failure();
382 
383   if (operandType != operandType2 || optOperandType != optOperandType2 ||
384       varOperandTypes != varOperandTypes2)
385     return failure();
386 
387   return success();
388 }
389 static ParseResult parseCustomDirectiveOperandsAndTypes(
390     OpAsmParser &parser, OpAsmParser::OperandType &operand,
391     Optional<OpAsmParser::OperandType> &optOperand,
392     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
393     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
394   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
395       parseCustomDirectiveResults(parser, operandType, optOperandType,
396                                   varOperandTypes))
397     return failure();
398   return success();
399 }
400 static ParseResult parseCustomDirectiveRegions(
401     OpAsmParser &parser, Region &region,
402     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
403   if (parser.parseRegion(region))
404     return failure();
405   if (failed(parser.parseOptionalComma()))
406     return success();
407   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
408   if (parser.parseRegion(*varRegion))
409     return failure();
410   varRegions.emplace_back(std::move(varRegion));
411   return success();
412 }
413 static ParseResult
414 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
415                                SmallVectorImpl<Block *> &varSuccessors) {
416   if (parser.parseSuccessor(successor))
417     return failure();
418   if (failed(parser.parseOptionalComma()))
419     return success();
420   Block *varSuccessor;
421   if (parser.parseSuccessor(varSuccessor))
422     return failure();
423   varSuccessors.append(2, varSuccessor);
424   return success();
425 }
426 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
427                                                   IntegerAttr &attr,
428                                                   IntegerAttr &optAttr) {
429   if (parser.parseAttribute(attr))
430     return failure();
431   if (succeeded(parser.parseOptionalComma())) {
432     if (parser.parseAttribute(optAttr))
433       return failure();
434   }
435   return success();
436 }
437 
438 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
439                                                 NamedAttrList &attrs) {
440   return parser.parseOptionalAttrDict(attrs);
441 }
442 static ParseResult parseCustomDirectiveOptionalOperandRef(
443     OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
444   int64_t operandCount = 0;
445   if (parser.parseInteger(operandCount))
446     return failure();
447   bool expectedOptionalOperand = operandCount == 0;
448   return success(expectedOptionalOperand != optOperand.hasValue());
449 }
450 
451 //===----------------------------------------------------------------------===//
452 // Printing
453 
454 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
455                                          Value operand, Value optOperand,
456                                          OperandRange varOperands) {
457   printer << operand;
458   if (optOperand)
459     printer << ", " << optOperand;
460   printer << " -> (" << varOperands << ")";
461 }
462 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
463                                         Type operandType, Type optOperandType,
464                                         TypeRange varOperandTypes) {
465   printer << " : " << operandType;
466   if (optOperandType)
467     printer << ", " << optOperandType;
468   printer << " -> (" << varOperandTypes << ")";
469 }
470 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
471                                              Operation *op, Type operandType,
472                                              Type optOperandType,
473                                              TypeRange varOperandTypes) {
474   printer << " type_refs_capture ";
475   printCustomDirectiveResults(printer, op, operandType, optOperandType,
476                               varOperandTypes);
477 }
478 static void printCustomDirectiveOperandsAndTypes(
479     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
480     OperandRange varOperands, Type operandType, Type optOperandType,
481     TypeRange varOperandTypes) {
482   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
483   printCustomDirectiveResults(printer, op, operandType, optOperandType,
484                               varOperandTypes);
485 }
486 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
487                                         Region &region,
488                                         MutableArrayRef<Region> varRegions) {
489   printer.printRegion(region);
490   if (!varRegions.empty()) {
491     printer << ", ";
492     for (Region &region : varRegions)
493       printer.printRegion(region);
494   }
495 }
496 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
497                                            Block *successor,
498                                            SuccessorRange varSuccessors) {
499   printer << successor;
500   if (!varSuccessors.empty())
501     printer << ", " << varSuccessors.front();
502 }
503 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
504                                            Attribute attribute,
505                                            Attribute optAttribute) {
506   printer << attribute;
507   if (optAttribute)
508     printer << ", " << optAttribute;
509 }
510 
511 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
512                                          DictionaryAttr attrs) {
513   printer.printOptionalAttrDict(attrs.getValue());
514 }
515 
516 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
517                                                    Operation *op,
518                                                    Value optOperand) {
519   printer << (optOperand ? "1" : "0");
520 }
521 
522 //===----------------------------------------------------------------------===//
523 // Test IsolatedRegionOp - parse passthrough region arguments.
524 //===----------------------------------------------------------------------===//
525 
526 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
527                                          OperationState &result) {
528   OpAsmParser::OperandType argInfo;
529   Type argType = parser.getBuilder().getIndexType();
530 
531   // Parse the input operand.
532   if (parser.parseOperand(argInfo) ||
533       parser.resolveOperand(argInfo, argType, result.operands))
534     return failure();
535 
536   // Parse the body region, and reuse the operand info as the argument info.
537   Region *body = result.addRegion();
538   return parser.parseRegion(*body, argInfo, argType,
539                             /*enableNameShadowing=*/true);
540 }
541 
542 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
543   p << "test.isolated_region ";
544   p.printOperand(op.getOperand());
545   p.shadowRegionArgs(op.region(), op.getOperand());
546   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
547 }
548 
549 //===----------------------------------------------------------------------===//
550 // Test SSACFGRegionOp
551 //===----------------------------------------------------------------------===//
552 
553 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
554   return RegionKind::SSACFG;
555 }
556 
557 //===----------------------------------------------------------------------===//
558 // Test GraphRegionOp
559 //===----------------------------------------------------------------------===//
560 
561 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
562                                       OperationState &result) {
563   // Parse the body region, and reuse the operand info as the argument info.
564   Region *body = result.addRegion();
565   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
566 }
567 
568 static void print(OpAsmPrinter &p, GraphRegionOp op) {
569   p << "test.graph_region ";
570   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
571 }
572 
573 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
574   return RegionKind::Graph;
575 }
576 
577 //===----------------------------------------------------------------------===//
578 // Test AffineScopeOp
579 //===----------------------------------------------------------------------===//
580 
581 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
582                                       OperationState &result) {
583   // Parse the body region, and reuse the operand info as the argument info.
584   Region *body = result.addRegion();
585   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
586 }
587 
588 static void print(OpAsmPrinter &p, AffineScopeOp op) {
589   p << "test.affine_scope ";
590   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
591 }
592 
593 //===----------------------------------------------------------------------===//
594 // Test parser.
595 //===----------------------------------------------------------------------===//
596 
597 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
598                                               OperationState &result) {
599   if (parser.parseOptionalColon())
600     return success();
601   uint64_t numResults;
602   if (parser.parseInteger(numResults))
603     return failure();
604 
605   IndexType type = parser.getBuilder().getIndexType();
606   for (unsigned i = 0; i < numResults; ++i)
607     result.addTypes(type);
608   return success();
609 }
610 
611 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
612   p << ParseIntegerLiteralOp::getOperationName();
613   if (unsigned numResults = op->getNumResults())
614     p << " : " << numResults;
615 }
616 
617 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
618                                               OperationState &result) {
619   StringRef keyword;
620   if (parser.parseKeyword(&keyword))
621     return failure();
622   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
623   return success();
624 }
625 
626 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
627   p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword();
628 }
629 
630 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
632 
633 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
634                                          OperationState &result) {
635   if (parser.parseKeyword("wraps"))
636     return failure();
637 
638   // Parse the wrapped op in a region
639   Region &body = *result.addRegion();
640   body.push_back(new Block);
641   Block &block = body.back();
642   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
643   if (!wrapped_op)
644     return failure();
645 
646   // Create a return terminator in the inner region, pass as operand to the
647   // terminator the returned values from the wrapped operation.
648   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
649   OpBuilder builder(parser.getBuilder().getContext());
650   builder.setInsertionPointToEnd(&block);
651   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
652 
653   // Get the results type for the wrapping op from the terminator operands.
654   Operation &return_op = body.back().back();
655   result.types.append(return_op.operand_type_begin(),
656                       return_op.operand_type_end());
657 
658   // Use the location of the wrapped op for the "test.wrapping_region" op.
659   result.location = wrapped_op->getLoc();
660 
661   return success();
662 }
663 
664 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
665   p << op.getOperationName() << " wraps ";
666   p.printGenericOp(&op.region().front().front());
667 }
668 
669 //===----------------------------------------------------------------------===//
670 // Test PolyForOp - parse list of region arguments.
671 //===----------------------------------------------------------------------===//
672 
673 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
674   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
675   // Parse list of region arguments without a delimiter.
676   if (parser.parseRegionArgumentList(ivsInfo))
677     return failure();
678 
679   // Parse the body region.
680   Region *body = result.addRegion();
681   auto &builder = parser.getBuilder();
682   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
683   return parser.parseRegion(*body, ivsInfo, argTypes);
684 }
685 
686 //===----------------------------------------------------------------------===//
687 // Test removing op with inner ops.
688 //===----------------------------------------------------------------------===//
689 
690 namespace {
691 struct TestRemoveOpWithInnerOps
692     : public OpRewritePattern<TestOpWithRegionPattern> {
693   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
694 
695   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
696                                 PatternRewriter &rewriter) const override {
697     rewriter.eraseOp(op);
698     return success();
699   }
700 };
701 } // end anonymous namespace
702 
703 void TestOpWithRegionPattern::getCanonicalizationPatterns(
704     RewritePatternSet &results, MLIRContext *context) {
705   results.add<TestRemoveOpWithInnerOps>(context);
706 }
707 
708 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
709   return operand();
710 }
711 
712 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
713   return getValue();
714 }
715 
716 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
717     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
718   for (Value input : this->operands()) {
719     results.push_back(input);
720   }
721   return success();
722 }
723 
724 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
725   assert(operands.size() == 1);
726   if (operands.front()) {
727     (*this)->setAttr("attr", operands.front());
728     return getResult();
729   }
730   return {};
731 }
732 
733 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
734   return getOperand();
735 }
736 
737 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
738     MLIRContext *, Optional<Location> location, ValueRange operands,
739     DictionaryAttr attributes, RegionRange regions,
740     SmallVectorImpl<Type> &inferredReturnTypes) {
741   if (operands[0].getType() != operands[1].getType()) {
742     return emitOptionalError(location, "operand type mismatch ",
743                              operands[0].getType(), " vs ",
744                              operands[1].getType());
745   }
746   inferredReturnTypes.assign({operands[0].getType()});
747   return success();
748 }
749 
750 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
751     MLIRContext *context, Optional<Location> location, ValueRange operands,
752     DictionaryAttr attributes, RegionRange regions,
753     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
754   // Create return type consisting of the last element of the first operand.
755   auto operandType = *operands.getTypes().begin();
756   auto sval = operandType.dyn_cast<ShapedType>();
757   if (!sval) {
758     return emitOptionalError(location, "only shaped type operands allowed");
759   }
760   int64_t dim =
761       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
762   auto type = IntegerType::get(context, 17);
763   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
764   return success();
765 }
766 
767 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
768     OpBuilder &builder, ValueRange operands,
769     llvm::SmallVectorImpl<Value> &shapes) {
770   shapes = SmallVector<Value, 1>{
771       builder.createOrFold<memref::DimOp>(getLoc(), operands.front(), 0)};
772   return success();
773 }
774 
775 LogicalResult
776 OpWithResultShapePerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim(
777     OpBuilder &builder,
778     llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
779   SmallVector<Value> operand1Shape, operand2Shape;
780   Location loc = getLoc();
781   for (auto i :
782        llvm::seq<int>(0, operand1().getType().cast<ShapedType>().getRank())) {
783     operand1Shape.push_back(builder.create<memref::DimOp>(loc, operand1(), i));
784   }
785   for (auto i :
786        llvm::seq<int>(0, operand2().getType().cast<ShapedType>().getRank())) {
787     operand2Shape.push_back(builder.create<memref::DimOp>(loc, operand2(), i));
788   }
789   shapes.emplace_back(std::move(operand2Shape));
790   shapes.emplace_back(std::move(operand1Shape));
791   return success();
792 }
793 
794 //===----------------------------------------------------------------------===//
795 // Test SideEffect interfaces
796 //===----------------------------------------------------------------------===//
797 
798 namespace {
799 /// A test resource for side effects.
800 struct TestResource : public SideEffects::Resource::Base<TestResource> {
801   StringRef getName() final { return "<Test>"; }
802 };
803 } // end anonymous namespace
804 
805 static void testSideEffectOpGetEffect(
806     Operation *op,
807     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
808         &effects) {
809   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
810   if (!effectsAttr)
811     return;
812 
813   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
814 }
815 
816 void SideEffectOp::getEffects(
817     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
818   // Check for an effects attribute on the op instance.
819   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
820   if (!effectsAttr)
821     return;
822 
823   // If there is one, it is an array of dictionary attributes that hold
824   // information on the effects of this operation.
825   for (Attribute element : effectsAttr) {
826     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
827 
828     // Get the specific memory effect.
829     MemoryEffects::Effect *effect =
830         StringSwitch<MemoryEffects::Effect *>(
831             effectElement.get("effect").cast<StringAttr>().getValue())
832             .Case("allocate", MemoryEffects::Allocate::get())
833             .Case("free", MemoryEffects::Free::get())
834             .Case("read", MemoryEffects::Read::get())
835             .Case("write", MemoryEffects::Write::get());
836 
837     // Check for a non-default resource to use.
838     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
839     if (effectElement.get("test_resource"))
840       resource = TestResource::get();
841 
842     // Check for a result to affect.
843     if (effectElement.get("on_result"))
844       effects.emplace_back(effect, getResult(), resource);
845     else if (Attribute ref = effectElement.get("on_reference"))
846       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
847     else
848       effects.emplace_back(effect, resource);
849   }
850 }
851 
852 void SideEffectOp::getEffects(
853     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
854   testSideEffectOpGetEffect(getOperation(), effects);
855 }
856 
857 //===----------------------------------------------------------------------===//
858 // StringAttrPrettyNameOp
859 //===----------------------------------------------------------------------===//
860 
861 // This op has fancy handling of its SSA result name.
862 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
863                                                OperationState &result) {
864   // Add the result types.
865   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
866     result.addTypes(parser.getBuilder().getIntegerType(32));
867 
868   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
869     return failure();
870 
871   // If the attribute dictionary contains no 'names' attribute, infer it from
872   // the SSA name (if specified).
873   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
874     return attr.first == "names";
875   });
876 
877   // If there was no name specified, check to see if there was a useful name
878   // specified in the asm file.
879   if (hadNames || parser.getNumResults() == 0)
880     return success();
881 
882   SmallVector<StringRef, 4> names;
883   auto *context = result.getContext();
884 
885   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
886     auto resultName = parser.getResultName(i);
887     StringRef nameStr;
888     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
889       nameStr = resultName.first;
890 
891     names.push_back(nameStr);
892   }
893 
894   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
895   result.attributes.push_back({Identifier::get("names", context), namesAttr});
896   return success();
897 }
898 
899 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
900   p << "test.string_attr_pretty_name";
901 
902   // Note that we only need to print the "name" attribute if the asmprinter
903   // result name disagrees with it.  This can happen in strange cases, e.g.
904   // when there are conflicts.
905   bool namesDisagree = op.names().size() != op.getNumResults();
906 
907   SmallString<32> resultNameStr;
908   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
909     resultNameStr.clear();
910     llvm::raw_svector_ostream tmpStream(resultNameStr);
911     p.printOperand(op.getResult(i), tmpStream);
912 
913     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
914     if (!expectedName ||
915         tmpStream.str().drop_front() != expectedName.getValue()) {
916       namesDisagree = true;
917     }
918   }
919 
920   if (namesDisagree)
921     p.printOptionalAttrDictWithKeyword(op->getAttrs());
922   else
923     p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
924 }
925 
926 // We set the SSA name in the asm syntax to the contents of the name
927 // attribute.
928 void StringAttrPrettyNameOp::getAsmResultNames(
929     function_ref<void(Value, StringRef)> setNameFn) {
930 
931   auto value = names();
932   for (size_t i = 0, e = value.size(); i != e; ++i)
933     if (auto str = value[i].dyn_cast<StringAttr>())
934       if (!str.getValue().empty())
935         setNameFn(getResult(i), str.getValue());
936 }
937 
938 //===----------------------------------------------------------------------===//
939 // RegionIfOp
940 //===----------------------------------------------------------------------===//
941 
942 static void print(OpAsmPrinter &p, RegionIfOp op) {
943   p << RegionIfOp::getOperationName() << " ";
944   p.printOperands(op.getOperands());
945   p << ": " << op.getOperandTypes();
946   p.printArrowTypeList(op.getResultTypes());
947   p << " then";
948   p.printRegion(op.thenRegion(),
949                 /*printEntryBlockArgs=*/true,
950                 /*printBlockTerminators=*/true);
951   p << " else";
952   p.printRegion(op.elseRegion(),
953                 /*printEntryBlockArgs=*/true,
954                 /*printBlockTerminators=*/true);
955   p << " join";
956   p.printRegion(op.joinRegion(),
957                 /*printEntryBlockArgs=*/true,
958                 /*printBlockTerminators=*/true);
959 }
960 
961 static ParseResult parseRegionIfOp(OpAsmParser &parser,
962                                    OperationState &result) {
963   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
964   SmallVector<Type, 2> operandTypes;
965 
966   result.regions.reserve(3);
967   Region *thenRegion = result.addRegion();
968   Region *elseRegion = result.addRegion();
969   Region *joinRegion = result.addRegion();
970 
971   // Parse operand, type and arrow type lists.
972   if (parser.parseOperandList(operandInfos) ||
973       parser.parseColonTypeList(operandTypes) ||
974       parser.parseArrowTypeList(result.types))
975     return failure();
976 
977   // Parse all attached regions.
978   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
979       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
980       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
981     return failure();
982 
983   return parser.resolveOperands(operandInfos, operandTypes,
984                                 parser.getCurrentLocation(), result.operands);
985 }
986 
987 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
988   assert(index < 2 && "invalid region index");
989   return getOperands();
990 }
991 
992 void RegionIfOp::getSuccessorRegions(
993     Optional<unsigned> index, ArrayRef<Attribute> operands,
994     SmallVectorImpl<RegionSuccessor> &regions) {
995   // We always branch to the join region.
996   if (index.hasValue()) {
997     if (index.getValue() < 2)
998       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
999     else
1000       regions.push_back(RegionSuccessor(getResults()));
1001     return;
1002   }
1003 
1004   // The then and else regions are the entry regions of this op.
1005   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
1006   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
1007 }
1008 
1009 #include "TestOpEnums.cpp.inc"
1010 #include "TestOpInterfaces.cpp.inc"
1011 #include "TestOpStructs.cpp.inc"
1012 #include "TestTypeInterfaces.cpp.inc"
1013 
1014 #define GET_OP_CLASSES
1015 #include "TestOps.cpp.inc"
1016