1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestTypes.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/Function.h"
14 #include "mlir/IR/Module.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/TypeUtilities.h"
17 #include "mlir/Transforms/FoldUtils.h"
18 #include "mlir/Transforms/InliningUtils.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/StringSwitch.h"
21 
22 using namespace mlir;
23 using namespace mlir::test;
24 
25 void mlir::test::registerTestDialect(DialectRegistry &registry) {
26   registry.insert<TestDialect>();
27 }
28 
29 //===----------------------------------------------------------------------===//
30 // TestDialect Interfaces
31 //===----------------------------------------------------------------------===//
32 
33 namespace {
34 
35 // Test support for interacting with the AsmPrinter.
36 struct TestOpAsmInterface : public OpAsmDialectInterface {
37   using OpAsmDialectInterface::OpAsmDialectInterface;
38 
39   LogicalResult getAlias(Attribute attr, raw_ostream &os) const final {
40     StringAttr strAttr = attr.dyn_cast<StringAttr>();
41     if (!strAttr)
42       return failure();
43 
44     // Check the contents of the string attribute to see what the test alias
45     // should be named.
46     Optional<StringRef> aliasName =
47         StringSwitch<Optional<StringRef>>(strAttr.getValue())
48             .Case("alias_test:dot_in_name", StringRef("test.alias"))
49             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
50             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
51             .Case("alias_test:sanitize_conflict_a",
52                   StringRef("test_alias_conflict0"))
53             .Case("alias_test:sanitize_conflict_b",
54                   StringRef("test_alias_conflict0_"))
55             .Default(llvm::None);
56     if (!aliasName)
57       return failure();
58 
59     os << *aliasName;
60     return success();
61   }
62 
63   void getAsmResultNames(Operation *op,
64                          OpAsmSetValueNameFn setNameFn) const final {
65     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
66       setNameFn(asmOp, "result");
67   }
68 
69   void getAsmBlockArgumentNames(Block *block,
70                                 OpAsmSetValueNameFn setNameFn) const final {
71     auto op = block->getParentOp();
72     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
73     if (!arrayAttr)
74       return;
75     auto args = block->getArguments();
76     auto e = std::min(arrayAttr.size(), args.size());
77     for (unsigned i = 0; i < e; ++i) {
78       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
79         setNameFn(args[i], strAttr.getValue());
80     }
81   }
82 };
83 
84 struct TestDialectFoldInterface : public DialectFoldInterface {
85   using DialectFoldInterface::DialectFoldInterface;
86 
87   /// Registered hook to check if the given region, which is attached to an
88   /// operation that is *not* isolated from above, should be used when
89   /// materializing constants.
90   bool shouldMaterializeInto(Region *region) const final {
91     // If this is a one region operation, then insert into it.
92     return isa<OneRegionOp>(region->getParentOp());
93   }
94 };
95 
96 /// This class defines the interface for handling inlining with standard
97 /// operations.
98 struct TestInlinerInterface : public DialectInlinerInterface {
99   using DialectInlinerInterface::DialectInlinerInterface;
100 
101   //===--------------------------------------------------------------------===//
102   // Analysis Hooks
103   //===--------------------------------------------------------------------===//
104 
105   bool isLegalToInline(Operation *call, Operation *callable,
106                        bool wouldBeCloned) const final {
107     // Don't allow inlining calls that are marked `noinline`.
108     return !call->hasAttr("noinline");
109   }
110   bool isLegalToInline(Region *, Region *, bool,
111                        BlockAndValueMapping &) const final {
112     // Inlining into test dialect regions is legal.
113     return true;
114   }
115   bool isLegalToInline(Operation *, Region *, bool,
116                        BlockAndValueMapping &) const final {
117     return true;
118   }
119 
120   bool shouldAnalyzeRecursively(Operation *op) const final {
121     // Analyze recursively if this is not a functional region operation, it
122     // froms a separate functional scope.
123     return !isa<FunctionalRegionOp>(op);
124   }
125 
126   //===--------------------------------------------------------------------===//
127   // Transformation Hooks
128   //===--------------------------------------------------------------------===//
129 
130   /// Handle the given inlined terminator by replacing it with a new operation
131   /// as necessary.
132   void handleTerminator(Operation *op,
133                         ArrayRef<Value> valuesToRepl) const final {
134     // Only handle "test.return" here.
135     auto returnOp = dyn_cast<TestReturnOp>(op);
136     if (!returnOp)
137       return;
138 
139     // Replace the values directly with the return operands.
140     assert(returnOp.getNumOperands() == valuesToRepl.size());
141     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
142       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
143   }
144 
145   /// Attempt to materialize a conversion for a type mismatch between a call
146   /// from this dialect, and a callable region. This method should generate an
147   /// operation that takes 'input' as the only operand, and produces a single
148   /// result of 'resultType'. If a conversion can not be generated, nullptr
149   /// should be returned.
150   Operation *materializeCallConversion(OpBuilder &builder, Value input,
151                                        Type resultType,
152                                        Location conversionLoc) const final {
153     // Only allow conversion for i16/i32 types.
154     if (!(resultType.isSignlessInteger(16) ||
155           resultType.isSignlessInteger(32)) ||
156         !(input.getType().isSignlessInteger(16) ||
157           input.getType().isSignlessInteger(32)))
158       return nullptr;
159     return builder.create<TestCastOp>(conversionLoc, resultType, input);
160   }
161 };
162 } // end anonymous namespace
163 
164 //===----------------------------------------------------------------------===//
165 // TestDialect
166 //===----------------------------------------------------------------------===//
167 
/// Registers the test dialect's operations, dialect interfaces (asm, fold,
/// inliner) and types, then calls allowUnknownOperations() (presumably so that
/// unregistered operations in this dialect's namespace are accepted).
void TestDialect::initialize() {
  // The operation list is pulled in from the generated TestOps.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface>();
  // Hand-written types first, followed by the types listed in the generated
  // TestTypeDefs.cpp.inc.
  addTypes<TestType, TestRecursiveType,
#define GET_TYPEDEF_LIST
#include "TestTypeDefs.cpp.inc"
           >();
  allowUnknownOperations();
}
181 
182 static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
183                           llvm::SetVector<Type> &stack) {
184   StringRef typeTag;
185   if (failed(parser.parseKeyword(&typeTag)))
186     return Type();
187 
188   auto genType = generatedTypeParser(ctxt, parser, typeTag);
189   if (genType != Type())
190     return genType;
191 
192   if (typeTag == "test_type")
193     return TestType::get(parser.getBuilder().getContext());
194 
195   if (typeTag != "test_rec")
196     return Type();
197 
198   StringRef name;
199   if (parser.parseLess() || parser.parseKeyword(&name))
200     return Type();
201   auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name);
202 
203   // If this type already has been parsed above in the stack, expect just the
204   // name.
205   if (stack.contains(rec)) {
206     if (failed(parser.parseGreater()))
207       return Type();
208     return rec;
209   }
210 
211   // Otherwise, parse the body and update the type.
212   if (failed(parser.parseComma()))
213     return Type();
214   stack.insert(rec);
215   Type subtype = parseTestType(ctxt, parser, stack);
216   stack.pop_back();
217   if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
218     return Type();
219 
220   return rec;
221 }
222 
223 Type TestDialect::parseType(DialectAsmParser &parser) const {
224   llvm::SetVector<Type> stack;
225   return parseTestType(getContext(), parser, stack);
226 }
227 
228 static void printTestType(Type type, DialectAsmPrinter &printer,
229                           llvm::SetVector<Type> &stack) {
230   if (succeeded(generatedTypePrinter(type, printer)))
231     return;
232   if (type.isa<TestType>()) {
233     printer << "test_type";
234     return;
235   }
236 
237   auto rec = type.cast<TestRecursiveType>();
238   printer << "test_rec<" << rec.getName();
239   if (!stack.contains(rec)) {
240     printer << ", ";
241     stack.insert(rec);
242     printTestType(rec.getBody(), printer, stack);
243     stack.pop_back();
244   }
245   printer << ">";
246 }
247 
248 void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
249   llvm::SetVector<Type> stack;
250   printTestType(type, printer, stack);
251 }
252 
253 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
254                                                     NamedAttribute namedAttr) {
255   if (namedAttr.first == "test.invalid_attr")
256     return op->emitError() << "invalid to use 'test.invalid_attr'";
257   return success();
258 }
259 
260 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
261                                                     unsigned regionIndex,
262                                                     unsigned argIndex,
263                                                     NamedAttribute namedAttr) {
264   if (namedAttr.first == "test.invalid_attr")
265     return op->emitError() << "invalid to use 'test.invalid_attr'";
266   return success();
267 }
268 
269 LogicalResult
270 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
271                                          unsigned resultIndex,
272                                          NamedAttribute namedAttr) {
273   if (namedAttr.first == "test.invalid_attr")
274     return op->emitError() << "invalid to use 'test.invalid_attr'";
275   return success();
276 }
277 
278 //===----------------------------------------------------------------------===//
279 // TestBranchOp
280 //===----------------------------------------------------------------------===//
281 
282 Optional<MutableOperandRange>
283 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
284   assert(index == 0 && "invalid successor index");
285   return targetOperandsMutable();
286 }
287 
288 //===----------------------------------------------------------------------===//
289 // TestFoldToCallOp
290 //===----------------------------------------------------------------------===//
291 
292 namespace {
293 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
294   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
295 
296   LogicalResult matchAndRewrite(FoldToCallOp op,
297                                 PatternRewriter &rewriter) const override {
298     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
299                                         ValueRange());
300     return success();
301   }
302 };
303 } // end anonymous namespace
304 
305 void FoldToCallOp::getCanonicalizationPatterns(
306     OwningRewritePatternList &results, MLIRContext *context) {
307   results.insert<FoldToCallOpPattern>(context);
308 }
309 
310 //===----------------------------------------------------------------------===//
311 // Test Format* operations
312 //===----------------------------------------------------------------------===//
313 
314 //===----------------------------------------------------------------------===//
315 // Parsing
316 
317 static ParseResult parseCustomDirectiveOperands(
318     OpAsmParser &parser, OpAsmParser::OperandType &operand,
319     Optional<OpAsmParser::OperandType> &optOperand,
320     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
321   if (parser.parseOperand(operand))
322     return failure();
323   if (succeeded(parser.parseOptionalComma())) {
324     optOperand.emplace();
325     if (parser.parseOperand(*optOperand))
326       return failure();
327   }
328   if (parser.parseArrow() || parser.parseLParen() ||
329       parser.parseOperandList(varOperands) || parser.parseRParen())
330     return failure();
331   return success();
332 }
333 static ParseResult
334 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
335                             Type &optOperandType,
336                             SmallVectorImpl<Type> &varOperandTypes) {
337   if (parser.parseColon())
338     return failure();
339 
340   if (parser.parseType(operandType))
341     return failure();
342   if (succeeded(parser.parseOptionalComma())) {
343     if (parser.parseType(optOperandType))
344       return failure();
345   }
346   if (parser.parseArrow() || parser.parseLParen() ||
347       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
348     return failure();
349   return success();
350 }
351 static ParseResult
352 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
353                                  Type optOperandType,
354                                  const SmallVectorImpl<Type> &varOperandTypes) {
355   if (parser.parseKeyword("type_refs_capture"))
356     return failure();
357 
358   Type operandType2, optOperandType2;
359   SmallVector<Type, 1> varOperandTypes2;
360   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
361                                   varOperandTypes2))
362     return failure();
363 
364   if (operandType != operandType2 || optOperandType != optOperandType2 ||
365       varOperandTypes != varOperandTypes2)
366     return failure();
367 
368   return success();
369 }
370 static ParseResult parseCustomDirectiveOperandsAndTypes(
371     OpAsmParser &parser, OpAsmParser::OperandType &operand,
372     Optional<OpAsmParser::OperandType> &optOperand,
373     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
374     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
375   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
376       parseCustomDirectiveResults(parser, operandType, optOperandType,
377                                   varOperandTypes))
378     return failure();
379   return success();
380 }
381 static ParseResult parseCustomDirectiveRegions(
382     OpAsmParser &parser, Region &region,
383     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
384   if (parser.parseRegion(region))
385     return failure();
386   if (failed(parser.parseOptionalComma()))
387     return success();
388   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
389   if (parser.parseRegion(*varRegion))
390     return failure();
391   varRegions.emplace_back(std::move(varRegion));
392   return success();
393 }
394 static ParseResult
395 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
396                                SmallVectorImpl<Block *> &varSuccessors) {
397   if (parser.parseSuccessor(successor))
398     return failure();
399   if (failed(parser.parseOptionalComma()))
400     return success();
401   Block *varSuccessor;
402   if (parser.parseSuccessor(varSuccessor))
403     return failure();
404   varSuccessors.append(2, varSuccessor);
405   return success();
406 }
407 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
408                                                   IntegerAttr &attr,
409                                                   IntegerAttr &optAttr) {
410   if (parser.parseAttribute(attr))
411     return failure();
412   if (succeeded(parser.parseOptionalComma())) {
413     if (parser.parseAttribute(optAttr))
414       return failure();
415   }
416   return success();
417 }
418 
419 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
420                                                 NamedAttrList &attrs) {
421   return parser.parseOptionalAttrDict(attrs);
422 }
423 
424 //===----------------------------------------------------------------------===//
425 // Printing
426 
427 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
428                                          Value operand, Value optOperand,
429                                          OperandRange varOperands) {
430   printer << operand;
431   if (optOperand)
432     printer << ", " << optOperand;
433   printer << " -> (" << varOperands << ")";
434 }
435 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
436                                         Type operandType, Type optOperandType,
437                                         TypeRange varOperandTypes) {
438   printer << " : " << operandType;
439   if (optOperandType)
440     printer << ", " << optOperandType;
441   printer << " -> (" << varOperandTypes << ")";
442 }
443 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
444                                              Operation *op, Type operandType,
445                                              Type optOperandType,
446                                              TypeRange varOperandTypes) {
447   printer << " type_refs_capture ";
448   printCustomDirectiveResults(printer, op, operandType, optOperandType,
449                               varOperandTypes);
450 }
451 static void printCustomDirectiveOperandsAndTypes(
452     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
453     OperandRange varOperands, Type operandType, Type optOperandType,
454     TypeRange varOperandTypes) {
455   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
456   printCustomDirectiveResults(printer, op, operandType, optOperandType,
457                               varOperandTypes);
458 }
459 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
460                                         Region &region,
461                                         MutableArrayRef<Region> varRegions) {
462   printer.printRegion(region);
463   if (!varRegions.empty()) {
464     printer << ", ";
465     for (Region &region : varRegions)
466       printer.printRegion(region);
467   }
468 }
469 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
470                                            Block *successor,
471                                            SuccessorRange varSuccessors) {
472   printer << successor;
473   if (!varSuccessors.empty())
474     printer << ", " << varSuccessors.front();
475 }
476 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
477                                            Attribute attribute,
478                                            Attribute optAttribute) {
479   printer << attribute;
480   if (optAttribute)
481     printer << ", " << optAttribute;
482 }
483 
484 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
485                                          MutableDictionaryAttr attrs) {
486   printer.printOptionalAttrDict(attrs.getAttrs());
487 }
488 //===----------------------------------------------------------------------===//
489 // Test IsolatedRegionOp - parse passthrough region arguments.
490 //===----------------------------------------------------------------------===//
491 
492 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
493                                          OperationState &result) {
494   OpAsmParser::OperandType argInfo;
495   Type argType = parser.getBuilder().getIndexType();
496 
497   // Parse the input operand.
498   if (parser.parseOperand(argInfo) ||
499       parser.resolveOperand(argInfo, argType, result.operands))
500     return failure();
501 
502   // Parse the body region, and reuse the operand info as the argument info.
503   Region *body = result.addRegion();
504   return parser.parseRegion(*body, argInfo, argType,
505                             /*enableNameShadowing=*/true);
506 }
507 
508 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
509   p << "test.isolated_region ";
510   p.printOperand(op.getOperand());
511   p.shadowRegionArgs(op.region(), op.getOperand());
512   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
513 }
514 
515 //===----------------------------------------------------------------------===//
516 // Test SSACFGRegionOp
517 //===----------------------------------------------------------------------===//
518 
519 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
520   return RegionKind::SSACFG;
521 }
522 
523 //===----------------------------------------------------------------------===//
524 // Test GraphRegionOp
525 //===----------------------------------------------------------------------===//
526 
527 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
528                                       OperationState &result) {
529   // Parse the body region, and reuse the operand info as the argument info.
530   Region *body = result.addRegion();
531   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
532 }
533 
534 static void print(OpAsmPrinter &p, GraphRegionOp op) {
535   p << "test.graph_region ";
536   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
537 }
538 
539 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
540   return RegionKind::Graph;
541 }
542 
543 //===----------------------------------------------------------------------===//
544 // Test AffineScopeOp
545 //===----------------------------------------------------------------------===//
546 
547 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
548                                       OperationState &result) {
549   // Parse the body region, and reuse the operand info as the argument info.
550   Region *body = result.addRegion();
551   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
552 }
553 
554 static void print(OpAsmPrinter &p, AffineScopeOp op) {
555   p << "test.affine_scope ";
556   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
557 }
558 
559 //===----------------------------------------------------------------------===//
560 // Test parser.
561 //===----------------------------------------------------------------------===//
562 
563 static ParseResult parseWrappedKeywordOp(OpAsmParser &parser,
564                                          OperationState &result) {
565   StringRef keyword;
566   if (parser.parseKeyword(&keyword))
567     return failure();
568   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
569   return success();
570 }
571 
572 static void print(OpAsmPrinter &p, WrappedKeywordOp op) {
573   p << WrappedKeywordOp::getOperationName() << " " << op.keyword();
574 }
575 
576 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
578 
579 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
580                                          OperationState &result) {
581   if (parser.parseKeyword("wraps"))
582     return failure();
583 
584   // Parse the wrapped op in a region
585   Region &body = *result.addRegion();
586   body.push_back(new Block);
587   Block &block = body.back();
588   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
589   if (!wrapped_op)
590     return failure();
591 
592   // Create a return terminator in the inner region, pass as operand to the
593   // terminator the returned values from the wrapped operation.
594   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
595   OpBuilder builder(parser.getBuilder().getContext());
596   builder.setInsertionPointToEnd(&block);
597   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
598 
599   // Get the results type for the wrapping op from the terminator operands.
600   Operation &return_op = body.back().back();
601   result.types.append(return_op.operand_type_begin(),
602                       return_op.operand_type_end());
603 
604   // Use the location of the wrapped op for the "test.wrapping_region" op.
605   result.location = wrapped_op->getLoc();
606 
607   return success();
608 }
609 
610 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
611   p << op.getOperationName() << " wraps ";
612   p.printGenericOp(&op.region().front().front());
613 }
614 
615 //===----------------------------------------------------------------------===//
616 // Test PolyForOp - parse list of region arguments.
617 //===----------------------------------------------------------------------===//
618 
619 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
620   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
621   // Parse list of region arguments without a delimiter.
622   if (parser.parseRegionArgumentList(ivsInfo))
623     return failure();
624 
625   // Parse the body region.
626   Region *body = result.addRegion();
627   auto &builder = parser.getBuilder();
628   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
629   return parser.parseRegion(*body, ivsInfo, argTypes);
630 }
631 
632 //===----------------------------------------------------------------------===//
633 // Test removing op with inner ops.
634 //===----------------------------------------------------------------------===//
635 
636 namespace {
637 struct TestRemoveOpWithInnerOps
638     : public OpRewritePattern<TestOpWithRegionPattern> {
639   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
640 
641   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
642                                 PatternRewriter &rewriter) const override {
643     rewriter.eraseOp(op);
644     return success();
645   }
646 };
647 } // end anonymous namespace
648 
649 void TestOpWithRegionPattern::getCanonicalizationPatterns(
650     OwningRewritePatternList &results, MLIRContext *context) {
651   results.insert<TestRemoveOpWithInnerOps>(context);
652 }
653 
654 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
655   return operand();
656 }
657 
658 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
659   return getValue();
660 }
661 
662 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
663     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
664   for (Value input : this->operands()) {
665     results.push_back(input);
666   }
667   return success();
668 }
669 
670 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
671   assert(operands.size() == 1);
672   if (operands.front()) {
673     setAttr("attr", operands.front());
674     return getResult();
675   }
676   return {};
677 }
678 
679 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
680     MLIRContext *, Optional<Location> location, ValueRange operands,
681     DictionaryAttr attributes, RegionRange regions,
682     SmallVectorImpl<Type> &inferredReturnTypes) {
683   if (operands[0].getType() != operands[1].getType()) {
684     return emitOptionalError(location, "operand type mismatch ",
685                              operands[0].getType(), " vs ",
686                              operands[1].getType());
687   }
688   inferredReturnTypes.assign({operands[0].getType()});
689   return success();
690 }
691 
692 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
693     MLIRContext *context, Optional<Location> location, ValueRange operands,
694     DictionaryAttr attributes, RegionRange regions,
695     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
696   // Create return type consisting of the last element of the first operand.
697   auto operandType = *operands.getTypes().begin();
698   auto sval = operandType.dyn_cast<ShapedType>();
699   if (!sval) {
700     return emitOptionalError(location, "only shaped type operands allowed");
701   }
702   int64_t dim =
703       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
704   auto type = IntegerType::get(17, context);
705   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
706   return success();
707 }
708 
709 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
710     OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
711   shapes = SmallVector<Value, 1>{
712       builder.createOrFold<DimOp>(getLoc(), getOperand(0), 0)};
713   return success();
714 }
715 
716 //===----------------------------------------------------------------------===//
717 // Test SideEffect interfaces
718 //===----------------------------------------------------------------------===//
719 
720 namespace {
721 /// A test resource for side effects.
722 struct TestResource : public SideEffects::Resource::Base<TestResource> {
723   StringRef getName() final { return "<Test>"; }
724 };
725 } // end anonymous namespace
726 
727 void SideEffectOp::getEffects(
728     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
729   // Check for an effects attribute on the op instance.
730   ArrayAttr effectsAttr = getAttrOfType<ArrayAttr>("effects");
731   if (!effectsAttr)
732     return;
733 
734   // If there is one, it is an array of dictionary attributes that hold
735   // information on the effects of this operation.
736   for (Attribute element : effectsAttr) {
737     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
738 
739     // Get the specific memory effect.
740     MemoryEffects::Effect *effect =
741         StringSwitch<MemoryEffects::Effect *>(
742             effectElement.get("effect").cast<StringAttr>().getValue())
743             .Case("allocate", MemoryEffects::Allocate::get())
744             .Case("free", MemoryEffects::Free::get())
745             .Case("read", MemoryEffects::Read::get())
746             .Case("write", MemoryEffects::Write::get());
747 
748     // Check for a result to affect.
749     Value value;
750     if (effectElement.get("on_result"))
751       value = getResult();
752 
753     // Check for a non-default resource to use.
754     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
755     if (effectElement.get("test_resource"))
756       resource = TestResource::get();
757 
758     effects.emplace_back(effect, value, resource);
759   }
760 }
761 
762 //===----------------------------------------------------------------------===//
763 // StringAttrPrettyNameOp
764 //===----------------------------------------------------------------------===//
765 
766 // This op has fancy handling of its SSA result name.
767 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
768                                                OperationState &result) {
769   // Add the result types.
770   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
771     result.addTypes(parser.getBuilder().getIntegerType(32));
772 
773   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
774     return failure();
775 
776   // If the attribute dictionary contains no 'names' attribute, infer it from
777   // the SSA name (if specified).
778   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
779     return attr.first == "names";
780   });
781 
782   // If there was no name specified, check to see if there was a useful name
783   // specified in the asm file.
784   if (hadNames || parser.getNumResults() == 0)
785     return success();
786 
787   SmallVector<StringRef, 4> names;
788   auto *context = result.getContext();
789 
790   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
791     auto resultName = parser.getResultName(i);
792     StringRef nameStr;
793     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
794       nameStr = resultName.first;
795 
796     names.push_back(nameStr);
797   }
798 
799   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
800   result.attributes.push_back({Identifier::get("names", context), namesAttr});
801   return success();
802 }
803 
804 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
805   p << "test.string_attr_pretty_name";
806 
807   // Note that we only need to print the "name" attribute if the asmprinter
808   // result name disagrees with it.  This can happen in strange cases, e.g.
809   // when there are conflicts.
810   bool namesDisagree = op.names().size() != op.getNumResults();
811 
812   SmallString<32> resultNameStr;
813   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
814     resultNameStr.clear();
815     llvm::raw_svector_ostream tmpStream(resultNameStr);
816     p.printOperand(op.getResult(i), tmpStream);
817 
818     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
819     if (!expectedName ||
820         tmpStream.str().drop_front() != expectedName.getValue()) {
821       namesDisagree = true;
822     }
823   }
824 
825   if (namesDisagree)
826     p.printOptionalAttrDictWithKeyword(op.getAttrs());
827   else
828     p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"});
829 }
830 
831 // We set the SSA name in the asm syntax to the contents of the name
832 // attribute.
833 void StringAttrPrettyNameOp::getAsmResultNames(
834     function_ref<void(Value, StringRef)> setNameFn) {
835 
836   auto value = names();
837   for (size_t i = 0, e = value.size(); i != e; ++i)
838     if (auto str = value[i].dyn_cast<StringAttr>())
839       if (!str.getValue().empty())
840         setNameFn(getResult(i), str.getValue());
841 }
842 
843 //===----------------------------------------------------------------------===//
844 // RegionIfOp
845 //===----------------------------------------------------------------------===//
846 
847 static void print(OpAsmPrinter &p, RegionIfOp op) {
848   p << RegionIfOp::getOperationName() << " ";
849   p.printOperands(op.getOperands());
850   p << ": " << op.getOperandTypes();
851   p.printArrowTypeList(op.getResultTypes());
852   p << " then";
853   p.printRegion(op.thenRegion(),
854                 /*printEntryBlockArgs=*/true,
855                 /*printBlockTerminators=*/true);
856   p << " else";
857   p.printRegion(op.elseRegion(),
858                 /*printEntryBlockArgs=*/true,
859                 /*printBlockTerminators=*/true);
860   p << " join";
861   p.printRegion(op.joinRegion(),
862                 /*printEntryBlockArgs=*/true,
863                 /*printBlockTerminators=*/true);
864 }
865 
866 static ParseResult parseRegionIfOp(OpAsmParser &parser,
867                                    OperationState &result) {
868   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
869   SmallVector<Type, 2> operandTypes;
870 
871   result.regions.reserve(3);
872   Region *thenRegion = result.addRegion();
873   Region *elseRegion = result.addRegion();
874   Region *joinRegion = result.addRegion();
875 
876   // Parse operand, type and arrow type lists.
877   if (parser.parseOperandList(operandInfos) ||
878       parser.parseColonTypeList(operandTypes) ||
879       parser.parseArrowTypeList(result.types))
880     return failure();
881 
882   // Parse all attached regions.
883   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
884       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
885       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
886     return failure();
887 
888   return parser.resolveOperands(operandInfos, operandTypes,
889                                 parser.getCurrentLocation(), result.operands);
890 }
891 
892 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
893   assert(index < 2 && "invalid region index");
894   return getOperands();
895 }
896 
897 void RegionIfOp::getSuccessorRegions(
898     Optional<unsigned> index, ArrayRef<Attribute> operands,
899     SmallVectorImpl<RegionSuccessor> &regions) {
900   // We always branch to the join region.
901   if (index.hasValue()) {
902     if (index.getValue() < 2)
903       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
904     else
905       regions.push_back(RegionSuccessor(getResults()));
906     return;
907   }
908 
909   // The then and else regions are the entry regions of this op.
910   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
911   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
912 }
913 
914 #include "TestOpEnums.cpp.inc"
915 #include "TestOpStructs.cpp.inc"
916 #include "TestTypeInterfaces.cpp.inc"
917 
918 #define GET_OP_CLASSES
919 #include "TestOps.cpp.inc"
920