1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestTypes.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/DialectImplementation.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/IR/TypeUtilities.h"
16 #include "mlir/Transforms/FoldUtils.h"
17 #include "mlir/Transforms/InliningUtils.h"
18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/ADT/StringSwitch.h"
20 
21 using namespace mlir;
22 using namespace mlir::test;
23 
24 void mlir::test::registerTestDialect(DialectRegistry &registry) {
25   registry.insert<TestDialect>();
26 }
27 
28 //===----------------------------------------------------------------------===//
29 // TestDialect Interfaces
30 //===----------------------------------------------------------------------===//
31 
32 namespace {
33 
34 // Test support for interacting with the AsmPrinter.
35 struct TestOpAsmInterface : public OpAsmDialectInterface {
36   using OpAsmDialectInterface::OpAsmDialectInterface;
37 
38   LogicalResult getAlias(Attribute attr, raw_ostream &os) const final {
39     StringAttr strAttr = attr.dyn_cast<StringAttr>();
40     if (!strAttr)
41       return failure();
42 
43     // Check the contents of the string attribute to see what the test alias
44     // should be named.
45     Optional<StringRef> aliasName =
46         StringSwitch<Optional<StringRef>>(strAttr.getValue())
47             .Case("alias_test:dot_in_name", StringRef("test.alias"))
48             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
49             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
50             .Case("alias_test:sanitize_conflict_a",
51                   StringRef("test_alias_conflict0"))
52             .Case("alias_test:sanitize_conflict_b",
53                   StringRef("test_alias_conflict0_"))
54             .Default(llvm::None);
55     if (!aliasName)
56       return failure();
57 
58     os << *aliasName;
59     return success();
60   }
61 
62   void getAsmResultNames(Operation *op,
63                          OpAsmSetValueNameFn setNameFn) const final {
64     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
65       setNameFn(asmOp, "result");
66   }
67 
68   void getAsmBlockArgumentNames(Block *block,
69                                 OpAsmSetValueNameFn setNameFn) const final {
70     auto op = block->getParentOp();
71     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
72     if (!arrayAttr)
73       return;
74     auto args = block->getArguments();
75     auto e = std::min(arrayAttr.size(), args.size());
76     for (unsigned i = 0; i < e; ++i) {
77       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
78         setNameFn(args[i], strAttr.getValue());
79     }
80   }
81 };
82 
83 struct TestDialectFoldInterface : public DialectFoldInterface {
84   using DialectFoldInterface::DialectFoldInterface;
85 
86   /// Registered hook to check if the given region, which is attached to an
87   /// operation that is *not* isolated from above, should be used when
88   /// materializing constants.
89   bool shouldMaterializeInto(Region *region) const final {
90     // If this is a one region operation, then insert into it.
91     return isa<OneRegionOp>(region->getParentOp());
92   }
93 };
94 
95 /// This class defines the interface for handling inlining with standard
96 /// operations.
97 struct TestInlinerInterface : public DialectInlinerInterface {
98   using DialectInlinerInterface::DialectInlinerInterface;
99 
100   //===--------------------------------------------------------------------===//
101   // Analysis Hooks
102   //===--------------------------------------------------------------------===//
103 
104   bool isLegalToInline(Operation *call, Operation *callable,
105                        bool wouldBeCloned) const final {
106     // Don't allow inlining calls that are marked `noinline`.
107     return !call->hasAttr("noinline");
108   }
109   bool isLegalToInline(Region *, Region *, bool,
110                        BlockAndValueMapping &) const final {
111     // Inlining into test dialect regions is legal.
112     return true;
113   }
114   bool isLegalToInline(Operation *, Region *, bool,
115                        BlockAndValueMapping &) const final {
116     return true;
117   }
118 
119   bool shouldAnalyzeRecursively(Operation *op) const final {
120     // Analyze recursively if this is not a functional region operation, it
121     // froms a separate functional scope.
122     return !isa<FunctionalRegionOp>(op);
123   }
124 
125   //===--------------------------------------------------------------------===//
126   // Transformation Hooks
127   //===--------------------------------------------------------------------===//
128 
129   /// Handle the given inlined terminator by replacing it with a new operation
130   /// as necessary.
131   void handleTerminator(Operation *op,
132                         ArrayRef<Value> valuesToRepl) const final {
133     // Only handle "test.return" here.
134     auto returnOp = dyn_cast<TestReturnOp>(op);
135     if (!returnOp)
136       return;
137 
138     // Replace the values directly with the return operands.
139     assert(returnOp.getNumOperands() == valuesToRepl.size());
140     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
141       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
142   }
143 
144   /// Attempt to materialize a conversion for a type mismatch between a call
145   /// from this dialect, and a callable region. This method should generate an
146   /// operation that takes 'input' as the only operand, and produces a single
147   /// result of 'resultType'. If a conversion can not be generated, nullptr
148   /// should be returned.
149   Operation *materializeCallConversion(OpBuilder &builder, Value input,
150                                        Type resultType,
151                                        Location conversionLoc) const final {
152     // Only allow conversion for i16/i32 types.
153     if (!(resultType.isSignlessInteger(16) ||
154           resultType.isSignlessInteger(32)) ||
155         !(input.getType().isSignlessInteger(16) ||
156           input.getType().isSignlessInteger(32)))
157       return nullptr;
158     return builder.create<TestCastOp>(conversionLoc, resultType, input);
159   }
160 };
161 } // end anonymous namespace
162 
163 //===----------------------------------------------------------------------===//
164 // TestDialect
165 //===----------------------------------------------------------------------===//
166 
/// Registers the dialect's operations, dialect interfaces and types, and lets
/// the dialect accept operations it does not itself register.
void TestDialect::initialize() {
  // Operations listed in the generated TestOps.cpp.inc (GET_OP_LIST section).
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  // Printer, folding and inliner hooks defined earlier in this file.
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface>();
  // Hand-written types first, then the generated TestTypeDefs.cpp.inc list.
  addTypes<TestType, TestRecursiveType,
#define GET_TYPEDEF_LIST
#include "TestTypeDefs.cpp.inc"
           >();
  // Accept operations in this dialect's namespace that were not registered
  // above instead of rejecting them.
  allowUnknownOperations();
}
180 
181 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
182                                             Type type, Location loc) {
183   return builder.create<TestOpConstant>(loc, type, value);
184 }
185 
186 static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
187                           llvm::SetVector<Type> &stack) {
188   StringRef typeTag;
189   if (failed(parser.parseKeyword(&typeTag)))
190     return Type();
191 
192   auto genType = generatedTypeParser(ctxt, parser, typeTag);
193   if (genType != Type())
194     return genType;
195 
196   if (typeTag == "test_type")
197     return TestType::get(parser.getBuilder().getContext());
198 
199   if (typeTag != "test_rec")
200     return Type();
201 
202   StringRef name;
203   if (parser.parseLess() || parser.parseKeyword(&name))
204     return Type();
205   auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name);
206 
207   // If this type already has been parsed above in the stack, expect just the
208   // name.
209   if (stack.contains(rec)) {
210     if (failed(parser.parseGreater()))
211       return Type();
212     return rec;
213   }
214 
215   // Otherwise, parse the body and update the type.
216   if (failed(parser.parseComma()))
217     return Type();
218   stack.insert(rec);
219   Type subtype = parseTestType(ctxt, parser, stack);
220   stack.pop_back();
221   if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
222     return Type();
223 
224   return rec;
225 }
226 
227 Type TestDialect::parseType(DialectAsmParser &parser) const {
228   llvm::SetVector<Type> stack;
229   return parseTestType(getContext(), parser, stack);
230 }
231 
232 static void printTestType(Type type, DialectAsmPrinter &printer,
233                           llvm::SetVector<Type> &stack) {
234   if (succeeded(generatedTypePrinter(type, printer)))
235     return;
236   if (type.isa<TestType>()) {
237     printer << "test_type";
238     return;
239   }
240 
241   auto rec = type.cast<TestRecursiveType>();
242   printer << "test_rec<" << rec.getName();
243   if (!stack.contains(rec)) {
244     printer << ", ";
245     stack.insert(rec);
246     printTestType(rec.getBody(), printer, stack);
247     stack.pop_back();
248   }
249   printer << ">";
250 }
251 
252 void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
253   llvm::SetVector<Type> stack;
254   printTestType(type, printer, stack);
255 }
256 
257 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
258                                                     NamedAttribute namedAttr) {
259   if (namedAttr.first == "test.invalid_attr")
260     return op->emitError() << "invalid to use 'test.invalid_attr'";
261   return success();
262 }
263 
264 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
265                                                     unsigned regionIndex,
266                                                     unsigned argIndex,
267                                                     NamedAttribute namedAttr) {
268   if (namedAttr.first == "test.invalid_attr")
269     return op->emitError() << "invalid to use 'test.invalid_attr'";
270   return success();
271 }
272 
273 LogicalResult
274 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
275                                          unsigned resultIndex,
276                                          NamedAttribute namedAttr) {
277   if (namedAttr.first == "test.invalid_attr")
278     return op->emitError() << "invalid to use 'test.invalid_attr'";
279   return success();
280 }
281 
282 //===----------------------------------------------------------------------===//
283 // TestBranchOp
284 //===----------------------------------------------------------------------===//
285 
286 Optional<MutableOperandRange>
287 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
288   assert(index == 0 && "invalid successor index");
289   return targetOperandsMutable();
290 }
291 
292 //===----------------------------------------------------------------------===//
293 // TestFoldToCallOp
294 //===----------------------------------------------------------------------===//
295 
296 namespace {
297 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
298   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
299 
300   LogicalResult matchAndRewrite(FoldToCallOp op,
301                                 PatternRewriter &rewriter) const override {
302     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
303                                         ValueRange());
304     return success();
305   }
306 };
307 } // end anonymous namespace
308 
309 void FoldToCallOp::getCanonicalizationPatterns(
310     OwningRewritePatternList &results, MLIRContext *context) {
311   results.insert<FoldToCallOpPattern>(context);
312 }
313 
314 //===----------------------------------------------------------------------===//
315 // Test Format* operations
316 //===----------------------------------------------------------------------===//
317 
318 //===----------------------------------------------------------------------===//
319 // Parsing
320 
321 static ParseResult parseCustomDirectiveOperands(
322     OpAsmParser &parser, OpAsmParser::OperandType &operand,
323     Optional<OpAsmParser::OperandType> &optOperand,
324     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
325   if (parser.parseOperand(operand))
326     return failure();
327   if (succeeded(parser.parseOptionalComma())) {
328     optOperand.emplace();
329     if (parser.parseOperand(*optOperand))
330       return failure();
331   }
332   if (parser.parseArrow() || parser.parseLParen() ||
333       parser.parseOperandList(varOperands) || parser.parseRParen())
334     return failure();
335   return success();
336 }
337 static ParseResult
338 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
339                             Type &optOperandType,
340                             SmallVectorImpl<Type> &varOperandTypes) {
341   if (parser.parseColon())
342     return failure();
343 
344   if (parser.parseType(operandType))
345     return failure();
346   if (succeeded(parser.parseOptionalComma())) {
347     if (parser.parseType(optOperandType))
348       return failure();
349   }
350   if (parser.parseArrow() || parser.parseLParen() ||
351       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
352     return failure();
353   return success();
354 }
355 static ParseResult
356 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
357                                  Type optOperandType,
358                                  const SmallVectorImpl<Type> &varOperandTypes) {
359   if (parser.parseKeyword("type_refs_capture"))
360     return failure();
361 
362   Type operandType2, optOperandType2;
363   SmallVector<Type, 1> varOperandTypes2;
364   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
365                                   varOperandTypes2))
366     return failure();
367 
368   if (operandType != operandType2 || optOperandType != optOperandType2 ||
369       varOperandTypes != varOperandTypes2)
370     return failure();
371 
372   return success();
373 }
374 static ParseResult parseCustomDirectiveOperandsAndTypes(
375     OpAsmParser &parser, OpAsmParser::OperandType &operand,
376     Optional<OpAsmParser::OperandType> &optOperand,
377     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
378     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
379   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
380       parseCustomDirectiveResults(parser, operandType, optOperandType,
381                                   varOperandTypes))
382     return failure();
383   return success();
384 }
385 static ParseResult parseCustomDirectiveRegions(
386     OpAsmParser &parser, Region &region,
387     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
388   if (parser.parseRegion(region))
389     return failure();
390   if (failed(parser.parseOptionalComma()))
391     return success();
392   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
393   if (parser.parseRegion(*varRegion))
394     return failure();
395   varRegions.emplace_back(std::move(varRegion));
396   return success();
397 }
398 static ParseResult
399 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
400                                SmallVectorImpl<Block *> &varSuccessors) {
401   if (parser.parseSuccessor(successor))
402     return failure();
403   if (failed(parser.parseOptionalComma()))
404     return success();
405   Block *varSuccessor;
406   if (parser.parseSuccessor(varSuccessor))
407     return failure();
408   varSuccessors.append(2, varSuccessor);
409   return success();
410 }
411 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
412                                                   IntegerAttr &attr,
413                                                   IntegerAttr &optAttr) {
414   if (parser.parseAttribute(attr))
415     return failure();
416   if (succeeded(parser.parseOptionalComma())) {
417     if (parser.parseAttribute(optAttr))
418       return failure();
419   }
420   return success();
421 }
422 
423 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
424                                                 NamedAttrList &attrs) {
425   return parser.parseOptionalAttrDict(attrs);
426 }
427 
428 //===----------------------------------------------------------------------===//
429 // Printing
430 
431 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
432                                          Value operand, Value optOperand,
433                                          OperandRange varOperands) {
434   printer << operand;
435   if (optOperand)
436     printer << ", " << optOperand;
437   printer << " -> (" << varOperands << ")";
438 }
439 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
440                                         Type operandType, Type optOperandType,
441                                         TypeRange varOperandTypes) {
442   printer << " : " << operandType;
443   if (optOperandType)
444     printer << ", " << optOperandType;
445   printer << " -> (" << varOperandTypes << ")";
446 }
447 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
448                                              Operation *op, Type operandType,
449                                              Type optOperandType,
450                                              TypeRange varOperandTypes) {
451   printer << " type_refs_capture ";
452   printCustomDirectiveResults(printer, op, operandType, optOperandType,
453                               varOperandTypes);
454 }
455 static void printCustomDirectiveOperandsAndTypes(
456     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
457     OperandRange varOperands, Type operandType, Type optOperandType,
458     TypeRange varOperandTypes) {
459   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
460   printCustomDirectiveResults(printer, op, operandType, optOperandType,
461                               varOperandTypes);
462 }
463 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
464                                         Region &region,
465                                         MutableArrayRef<Region> varRegions) {
466   printer.printRegion(region);
467   if (!varRegions.empty()) {
468     printer << ", ";
469     for (Region &region : varRegions)
470       printer.printRegion(region);
471   }
472 }
473 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
474                                            Block *successor,
475                                            SuccessorRange varSuccessors) {
476   printer << successor;
477   if (!varSuccessors.empty())
478     printer << ", " << varSuccessors.front();
479 }
480 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
481                                            Attribute attribute,
482                                            Attribute optAttribute) {
483   printer << attribute;
484   if (optAttribute)
485     printer << ", " << optAttribute;
486 }
487 
488 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
489                                          MutableDictionaryAttr attrs) {
490   printer.printOptionalAttrDict(attrs.getAttrs());
491 }
492 //===----------------------------------------------------------------------===//
493 // Test IsolatedRegionOp - parse passthrough region arguments.
494 //===----------------------------------------------------------------------===//
495 
496 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
497                                          OperationState &result) {
498   OpAsmParser::OperandType argInfo;
499   Type argType = parser.getBuilder().getIndexType();
500 
501   // Parse the input operand.
502   if (parser.parseOperand(argInfo) ||
503       parser.resolveOperand(argInfo, argType, result.operands))
504     return failure();
505 
506   // Parse the body region, and reuse the operand info as the argument info.
507   Region *body = result.addRegion();
508   return parser.parseRegion(*body, argInfo, argType,
509                             /*enableNameShadowing=*/true);
510 }
511 
512 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
513   p << "test.isolated_region ";
514   p.printOperand(op.getOperand());
515   p.shadowRegionArgs(op.region(), op.getOperand());
516   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
517 }
518 
519 //===----------------------------------------------------------------------===//
520 // Test SSACFGRegionOp
521 //===----------------------------------------------------------------------===//
522 
523 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
524   return RegionKind::SSACFG;
525 }
526 
527 //===----------------------------------------------------------------------===//
528 // Test GraphRegionOp
529 //===----------------------------------------------------------------------===//
530 
531 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
532                                       OperationState &result) {
533   // Parse the body region, and reuse the operand info as the argument info.
534   Region *body = result.addRegion();
535   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
536 }
537 
538 static void print(OpAsmPrinter &p, GraphRegionOp op) {
539   p << "test.graph_region ";
540   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
541 }
542 
543 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
544   return RegionKind::Graph;
545 }
546 
547 //===----------------------------------------------------------------------===//
548 // Test AffineScopeOp
549 //===----------------------------------------------------------------------===//
550 
551 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
552                                       OperationState &result) {
553   // Parse the body region, and reuse the operand info as the argument info.
554   Region *body = result.addRegion();
555   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
556 }
557 
558 static void print(OpAsmPrinter &p, AffineScopeOp op) {
559   p << "test.affine_scope ";
560   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
561 }
562 
563 //===----------------------------------------------------------------------===//
564 // Test parser.
565 //===----------------------------------------------------------------------===//
566 
567 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
568                                               OperationState &result) {
569   if (parser.parseOptionalColon())
570     return success();
571   uint64_t numResults;
572   if (parser.parseInteger(numResults))
573     return failure();
574 
575   IndexType type = parser.getBuilder().getIndexType();
576   for (unsigned i = 0; i < numResults; ++i)
577     result.addTypes(type);
578   return success();
579 }
580 
581 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
582   p << ParseIntegerLiteralOp::getOperationName();
583   if (unsigned numResults = op->getNumResults())
584     p << " : " << numResults;
585 }
586 
587 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
588                                               OperationState &result) {
589   StringRef keyword;
590   if (parser.parseKeyword(&keyword))
591     return failure();
592   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
593   return success();
594 }
595 
596 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
597   p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword();
598 }
599 
600 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
602 
603 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
604                                          OperationState &result) {
605   if (parser.parseKeyword("wraps"))
606     return failure();
607 
608   // Parse the wrapped op in a region
609   Region &body = *result.addRegion();
610   body.push_back(new Block);
611   Block &block = body.back();
612   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
613   if (!wrapped_op)
614     return failure();
615 
616   // Create a return terminator in the inner region, pass as operand to the
617   // terminator the returned values from the wrapped operation.
618   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
619   OpBuilder builder(parser.getBuilder().getContext());
620   builder.setInsertionPointToEnd(&block);
621   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
622 
623   // Get the results type for the wrapping op from the terminator operands.
624   Operation &return_op = body.back().back();
625   result.types.append(return_op.operand_type_begin(),
626                       return_op.operand_type_end());
627 
628   // Use the location of the wrapped op for the "test.wrapping_region" op.
629   result.location = wrapped_op->getLoc();
630 
631   return success();
632 }
633 
634 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
635   p << op.getOperationName() << " wraps ";
636   p.printGenericOp(&op.region().front().front());
637 }
638 
639 //===----------------------------------------------------------------------===//
640 // Test PolyForOp - parse list of region arguments.
641 //===----------------------------------------------------------------------===//
642 
643 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
644   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
645   // Parse list of region arguments without a delimiter.
646   if (parser.parseRegionArgumentList(ivsInfo))
647     return failure();
648 
649   // Parse the body region.
650   Region *body = result.addRegion();
651   auto &builder = parser.getBuilder();
652   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
653   return parser.parseRegion(*body, ivsInfo, argTypes);
654 }
655 
656 //===----------------------------------------------------------------------===//
657 // Test removing op with inner ops.
658 //===----------------------------------------------------------------------===//
659 
660 namespace {
661 struct TestRemoveOpWithInnerOps
662     : public OpRewritePattern<TestOpWithRegionPattern> {
663   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
664 
665   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
666                                 PatternRewriter &rewriter) const override {
667     rewriter.eraseOp(op);
668     return success();
669   }
670 };
671 } // end anonymous namespace
672 
673 void TestOpWithRegionPattern::getCanonicalizationPatterns(
674     OwningRewritePatternList &results, MLIRContext *context) {
675   results.insert<TestRemoveOpWithInnerOps>(context);
676 }
677 
678 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
679   return operand();
680 }
681 
682 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
683   return getValue();
684 }
685 
686 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
687     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
688   for (Value input : this->operands()) {
689     results.push_back(input);
690   }
691   return success();
692 }
693 
694 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
695   assert(operands.size() == 1);
696   if (operands.front()) {
697     (*this)->setAttr("attr", operands.front());
698     return getResult();
699   }
700   return {};
701 }
702 
703 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
704     MLIRContext *, Optional<Location> location, ValueRange operands,
705     DictionaryAttr attributes, RegionRange regions,
706     SmallVectorImpl<Type> &inferredReturnTypes) {
707   if (operands[0].getType() != operands[1].getType()) {
708     return emitOptionalError(location, "operand type mismatch ",
709                              operands[0].getType(), " vs ",
710                              operands[1].getType());
711   }
712   inferredReturnTypes.assign({operands[0].getType()});
713   return success();
714 }
715 
716 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
717     MLIRContext *context, Optional<Location> location, ValueRange operands,
718     DictionaryAttr attributes, RegionRange regions,
719     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
720   // Create return type consisting of the last element of the first operand.
721   auto operandType = *operands.getTypes().begin();
722   auto sval = operandType.dyn_cast<ShapedType>();
723   if (!sval) {
724     return emitOptionalError(location, "only shaped type operands allowed");
725   }
726   int64_t dim =
727       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
728   auto type = IntegerType::get(17, context);
729   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
730   return success();
731 }
732 
733 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
734     OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
735   shapes = SmallVector<Value, 1>{
736       builder.createOrFold<DimOp>(getLoc(), getOperand(0), 0)};
737   return success();
738 }
739 
740 //===----------------------------------------------------------------------===//
741 // Test SideEffect interfaces
742 //===----------------------------------------------------------------------===//
743 
744 namespace {
745 /// A test resource for side effects.
746 struct TestResource : public SideEffects::Resource::Base<TestResource> {
747   StringRef getName() final { return "<Test>"; }
748 };
749 } // end anonymous namespace
750 
751 void SideEffectOp::getEffects(
752     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
753   // Check for an effects attribute on the op instance.
754   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
755   if (!effectsAttr)
756     return;
757 
758   // If there is one, it is an array of dictionary attributes that hold
759   // information on the effects of this operation.
760   for (Attribute element : effectsAttr) {
761     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
762 
763     // Get the specific memory effect.
764     MemoryEffects::Effect *effect =
765         StringSwitch<MemoryEffects::Effect *>(
766             effectElement.get("effect").cast<StringAttr>().getValue())
767             .Case("allocate", MemoryEffects::Allocate::get())
768             .Case("free", MemoryEffects::Free::get())
769             .Case("read", MemoryEffects::Read::get())
770             .Case("write", MemoryEffects::Write::get());
771 
772     // Check for a non-default resource to use.
773     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
774     if (effectElement.get("test_resource"))
775       resource = TestResource::get();
776 
777     // Check for a result to affect.
778     if (effectElement.get("on_result"))
779       effects.emplace_back(effect, getResult(), resource);
780     else if (Attribute ref = effectElement.get("on_reference"))
781       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
782     else
783       effects.emplace_back(effect, resource);
784   }
785 }
786 
787 void SideEffectOp::getEffects(
788     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
789   auto effectsAttr = (*this)->getAttrOfType<AffineMapAttr>("effect_parameter");
790   if (!effectsAttr)
791     return;
792 
793   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
794 }
795 
796 //===----------------------------------------------------------------------===//
797 // StringAttrPrettyNameOp
798 //===----------------------------------------------------------------------===//
799 
800 // This op has fancy handling of its SSA result name.
801 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
802                                                OperationState &result) {
803   // Add the result types.
804   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
805     result.addTypes(parser.getBuilder().getIntegerType(32));
806 
807   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
808     return failure();
809 
810   // If the attribute dictionary contains no 'names' attribute, infer it from
811   // the SSA name (if specified).
812   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
813     return attr.first == "names";
814   });
815 
816   // If there was no name specified, check to see if there was a useful name
817   // specified in the asm file.
818   if (hadNames || parser.getNumResults() == 0)
819     return success();
820 
821   SmallVector<StringRef, 4> names;
822   auto *context = result.getContext();
823 
824   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
825     auto resultName = parser.getResultName(i);
826     StringRef nameStr;
827     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
828       nameStr = resultName.first;
829 
830     names.push_back(nameStr);
831   }
832 
833   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
834   result.attributes.push_back({Identifier::get("names", context), namesAttr});
835   return success();
836 }
837 
838 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
839   p << "test.string_attr_pretty_name";
840 
841   // Note that we only need to print the "name" attribute if the asmprinter
842   // result name disagrees with it.  This can happen in strange cases, e.g.
843   // when there are conflicts.
844   bool namesDisagree = op.names().size() != op.getNumResults();
845 
846   SmallString<32> resultNameStr;
847   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
848     resultNameStr.clear();
849     llvm::raw_svector_ostream tmpStream(resultNameStr);
850     p.printOperand(op.getResult(i), tmpStream);
851 
852     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
853     if (!expectedName ||
854         tmpStream.str().drop_front() != expectedName.getValue()) {
855       namesDisagree = true;
856     }
857   }
858 
859   if (namesDisagree)
860     p.printOptionalAttrDictWithKeyword(op.getAttrs());
861   else
862     p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"});
863 }
864 
865 // We set the SSA name in the asm syntax to the contents of the name
866 // attribute.
867 void StringAttrPrettyNameOp::getAsmResultNames(
868     function_ref<void(Value, StringRef)> setNameFn) {
869 
870   auto value = names();
871   for (size_t i = 0, e = value.size(); i != e; ++i)
872     if (auto str = value[i].dyn_cast<StringAttr>())
873       if (!str.getValue().empty())
874         setNameFn(getResult(i), str.getValue());
875 }
876 
877 //===----------------------------------------------------------------------===//
878 // RegionIfOp
879 //===----------------------------------------------------------------------===//
880 
881 static void print(OpAsmPrinter &p, RegionIfOp op) {
882   p << RegionIfOp::getOperationName() << " ";
883   p.printOperands(op.getOperands());
884   p << ": " << op.getOperandTypes();
885   p.printArrowTypeList(op.getResultTypes());
886   p << " then";
887   p.printRegion(op.thenRegion(),
888                 /*printEntryBlockArgs=*/true,
889                 /*printBlockTerminators=*/true);
890   p << " else";
891   p.printRegion(op.elseRegion(),
892                 /*printEntryBlockArgs=*/true,
893                 /*printBlockTerminators=*/true);
894   p << " join";
895   p.printRegion(op.joinRegion(),
896                 /*printEntryBlockArgs=*/true,
897                 /*printBlockTerminators=*/true);
898 }
899 
900 static ParseResult parseRegionIfOp(OpAsmParser &parser,
901                                    OperationState &result) {
902   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
903   SmallVector<Type, 2> operandTypes;
904 
905   result.regions.reserve(3);
906   Region *thenRegion = result.addRegion();
907   Region *elseRegion = result.addRegion();
908   Region *joinRegion = result.addRegion();
909 
910   // Parse operand, type and arrow type lists.
911   if (parser.parseOperandList(operandInfos) ||
912       parser.parseColonTypeList(operandTypes) ||
913       parser.parseArrowTypeList(result.types))
914     return failure();
915 
916   // Parse all attached regions.
917   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
918       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
919       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
920     return failure();
921 
922   return parser.resolveOperands(operandInfos, operandTypes,
923                                 parser.getCurrentLocation(), result.operands);
924 }
925 
/// Returns the values forwarded into entry region `index`. Only the then (0)
/// and else (1) regions are entry regions, and both receive the op's full
/// operand list.
OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
  assert(index < 2 && "invalid region index");
  return getOperands();
}
930 
931 void RegionIfOp::getSuccessorRegions(
932     Optional<unsigned> index, ArrayRef<Attribute> operands,
933     SmallVectorImpl<RegionSuccessor> &regions) {
934   // We always branch to the join region.
935   if (index.hasValue()) {
936     if (index.getValue() < 2)
937       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
938     else
939       regions.push_back(RegionSuccessor(getResults()));
940     return;
941   }
942 
943   // The then and else regions are the entry regions of this op.
944   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
945   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
946 }
947 
948 #include "TestOpEnums.cpp.inc"
949 #include "TestOpInterfaces.cpp.inc"
950 #include "TestOpStructs.cpp.inc"
951 #include "TestTypeInterfaces.cpp.inc"
952 
953 #define GET_OP_CLASSES
954 #include "TestOps.cpp.inc"
955