1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/DLTI/DLTI.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/Reducer/ReductionPatternInterface.h"
22 #include "mlir/Transforms/FoldUtils.h"
23 #include "mlir/Transforms/InliningUtils.h"
24 #include "llvm/ADT/StringSwitch.h"
25 
26 // Include this before the using namespace lines below to
27 // test that we don't have namespace dependencies.
28 #include "TestOpsDialect.cpp.inc"
29 
30 using namespace mlir;
31 using namespace test;
32 
/// Adds TestDialect to `registry`.
void test::registerTestDialect(DialectRegistry &registry) {
  registry.insert<TestDialect>();
}
36 
37 //===----------------------------------------------------------------------===//
38 // TestDialect Interfaces
39 //===----------------------------------------------------------------------===//
40 
41 namespace {
42 
/// Testing the correctness of some traits.
/// Both checks are evaluated at compile time: the `has_implicit_terminator_t`
/// detector and `hasSingleBlockImplicitTerminator` must each recognize
/// SingleBlockImplicitTerminatorOp, otherwise the build fails.
static_assert(
    llvm::is_detected<OpTrait::has_implicit_terminator_t,
                      SingleBlockImplicitTerminatorOp>::value,
    "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
static_assert(OpTrait::hasSingleBlockImplicitTerminator<
                  SingleBlockImplicitTerminatorOp>::value,
              "hasSingleBlockImplicitTerminator does not match "
              "SingleBlockImplicitTerminatorOp");
52 
53 // Test support for interacting with the AsmPrinter.
54 struct TestOpAsmInterface : public OpAsmDialectInterface {
55   using OpAsmDialectInterface::OpAsmDialectInterface;
56 
57   AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
58     StringAttr strAttr = attr.dyn_cast<StringAttr>();
59     if (!strAttr)
60       return AliasResult::NoAlias;
61 
62     // Check the contents of the string attribute to see what the test alias
63     // should be named.
64     Optional<StringRef> aliasName =
65         StringSwitch<Optional<StringRef>>(strAttr.getValue())
66             .Case("alias_test:dot_in_name", StringRef("test.alias"))
67             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
68             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
69             .Case("alias_test:sanitize_conflict_a",
70                   StringRef("test_alias_conflict0"))
71             .Case("alias_test:sanitize_conflict_b",
72                   StringRef("test_alias_conflict0_"))
73             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
74             .Default(llvm::None);
75     if (!aliasName)
76       return AliasResult::NoAlias;
77 
78     os << *aliasName;
79     return AliasResult::FinalAlias;
80   }
81 
82   AliasResult getAlias(Type type, raw_ostream &os) const final {
83     if (auto tupleType = type.dyn_cast<TupleType>()) {
84       if (tupleType.size() > 0 &&
85           llvm::all_of(tupleType.getTypes(), [](Type elemType) {
86             return elemType.isa<SimpleAType>();
87           })) {
88         os << "test_tuple";
89         return AliasResult::FinalAlias;
90       }
91     }
92     if (auto intType = type.dyn_cast<TestIntegerType>()) {
93       if (intType.getSignedness() ==
94               TestIntegerType::SignednessSemantics::Unsigned &&
95           intType.getWidth() == 8) {
96         os << "test_ui8";
97         return AliasResult::FinalAlias;
98       }
99     }
100     return AliasResult::NoAlias;
101   }
102 };
103 
104 struct TestDialectFoldInterface : public DialectFoldInterface {
105   using DialectFoldInterface::DialectFoldInterface;
106 
107   /// Registered hook to check if the given region, which is attached to an
108   /// operation that is *not* isolated from above, should be used when
109   /// materializing constants.
110   bool shouldMaterializeInto(Region *region) const final {
111     // If this is a one region operation, then insert into it.
112     return isa<OneRegionOp>(region->getParentOp());
113   }
114 };
115 
116 /// This class defines the interface for handling inlining with standard
117 /// operations.
118 struct TestInlinerInterface : public DialectInlinerInterface {
119   using DialectInlinerInterface::DialectInlinerInterface;
120 
121   //===--------------------------------------------------------------------===//
122   // Analysis Hooks
123   //===--------------------------------------------------------------------===//
124 
125   bool isLegalToInline(Operation *call, Operation *callable,
126                        bool wouldBeCloned) const final {
127     // Don't allow inlining calls that are marked `noinline`.
128     return !call->hasAttr("noinline");
129   }
130   bool isLegalToInline(Region *, Region *, bool,
131                        BlockAndValueMapping &) const final {
132     // Inlining into test dialect regions is legal.
133     return true;
134   }
135   bool isLegalToInline(Operation *, Region *, bool,
136                        BlockAndValueMapping &) const final {
137     return true;
138   }
139 
140   bool shouldAnalyzeRecursively(Operation *op) const final {
141     // Analyze recursively if this is not a functional region operation, it
142     // froms a separate functional scope.
143     return !isa<FunctionalRegionOp>(op);
144   }
145 
146   //===--------------------------------------------------------------------===//
147   // Transformation Hooks
148   //===--------------------------------------------------------------------===//
149 
150   /// Handle the given inlined terminator by replacing it with a new operation
151   /// as necessary.
152   void handleTerminator(Operation *op,
153                         ArrayRef<Value> valuesToRepl) const final {
154     // Only handle "test.return" here.
155     auto returnOp = dyn_cast<TestReturnOp>(op);
156     if (!returnOp)
157       return;
158 
159     // Replace the values directly with the return operands.
160     assert(returnOp.getNumOperands() == valuesToRepl.size());
161     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
162       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
163   }
164 
165   /// Attempt to materialize a conversion for a type mismatch between a call
166   /// from this dialect, and a callable region. This method should generate an
167   /// operation that takes 'input' as the only operand, and produces a single
168   /// result of 'resultType'. If a conversion can not be generated, nullptr
169   /// should be returned.
170   Operation *materializeCallConversion(OpBuilder &builder, Value input,
171                                        Type resultType,
172                                        Location conversionLoc) const final {
173     // Only allow conversion for i16/i32 types.
174     if (!(resultType.isSignlessInteger(16) ||
175           resultType.isSignlessInteger(32)) ||
176         !(input.getType().isSignlessInteger(16) ||
177           input.getType().isSignlessInteger(32)))
178       return nullptr;
179     return builder.create<TestCastOp>(conversionLoc, resultType, input);
180   }
181 
182   void processInlinedCallBlocks(
183       Operation *call,
184       iterator_range<Region::iterator> inlinedBlocks) const final {
185     if (!isa<ConversionCallOp>(call))
186       return;
187 
188     // Set attributed on all ops in the inlined blocks.
189     for (Block &block : inlinedBlocks) {
190       block.walk([&](Operation *op) {
191         op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
192       });
193     }
194   }
195 };
196 
197 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
198 public:
199   TestReductionPatternInterface(Dialect *dialect)
200       : DialectReductionPatternInterface(dialect) {}
201 
202   void populateReductionPatterns(RewritePatternSet &patterns) const final {
203     populateTestReductionPatterns(patterns);
204   }
205 };
206 
207 } // namespace
208 
209 //===----------------------------------------------------------------------===//
210 // TestDialect
211 //===----------------------------------------------------------------------===//
212 
213 static void testSideEffectOpGetEffect(
214     Operation *op,
215     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
216 
217 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
218 struct TestOpEffectInterfaceFallback
219     : public TestEffectOpInterface::FallbackModel<
220           TestOpEffectInterfaceFallback> {
221   static bool classof(Operation *op) {
222     bool isSupportedOp =
223         op->getName().getStringRef() == "test.unregistered_side_effect_op";
224     assert(isSupportedOp && "Unexpected dispatch");
225     return isSupportedOp;
226   }
227 
228   void
229   getEffects(Operation *op,
230              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
231                  &effects) const {
232     testSideEffectOpGetEffect(op, effects);
233   }
234 };
235 
/// Registers the dialect's attributes, types, generated operations, and the
/// dialect interfaces defined in this file. It also allows unknown operations
/// and heap-allocates the fallback `TestEffectOpInterface` model, which the
/// destructor deletes.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  // With GET_OP_LIST defined, the generated include supplies the op classes
  // used as the template arguments of addOperations.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface, TestReductionPatternInterface>();
  // Permit ops in this dialect's namespace that were not registered above
  // (e.g. "test.unregistered_side_effect_op").
  allowUnknownOperations();

  // Instantiate our fallback op interface that we'll use on specific
  // unregistered op.
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
/// Deletes the fallback interface model allocated in initialize(). The member
/// is cast back to its concrete type before deletion. Presumably it is stored
/// type-erased (e.g. as `void *`); confirm in TestDialect.h.
TestDialect::~TestDialect() {
  delete static_cast<TestOpEffectInterfaceFallback *>(
      fallbackEffectOpInterfaces);
}
255 
/// Dialect hook for constant materialization: builds a TestOpConstant of
/// `type` holding `value` at `loc`.
Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  return builder.create<TestOpConstant>(loc, type, value);
}
260 
261 ::mlir::LogicalResult FormatInferType2Op::inferReturnTypes(
262     ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location,
263     ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
264     ::mlir::RegionRange regions,
265     ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
266   inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
267   return ::mlir::success();
268 }
269 
270 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
271                                                OperationName opName) {
272   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
273       typeID == TypeID::get<TestEffectOpInterface>())
274     return fallbackEffectOpInterfaces;
275   return nullptr;
276 }
277 
278 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
279                                                     NamedAttribute namedAttr) {
280   if (namedAttr.getName() == "test.invalid_attr")
281     return op->emitError() << "invalid to use 'test.invalid_attr'";
282   return success();
283 }
284 
285 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
286                                                     unsigned regionIndex,
287                                                     unsigned argIndex,
288                                                     NamedAttribute namedAttr) {
289   if (namedAttr.getName() == "test.invalid_attr")
290     return op->emitError() << "invalid to use 'test.invalid_attr'";
291   return success();
292 }
293 
294 LogicalResult
295 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
296                                          unsigned resultIndex,
297                                          NamedAttribute namedAttr) {
298   if (namedAttr.getName() == "test.invalid_attr")
299     return op->emitError() << "invalid to use 'test.invalid_attr'";
300   return success();
301 }
302 
303 Optional<Dialect::ParseOpHook>
304 TestDialect::getParseOperationHook(StringRef opName) const {
305   if (opName == "test.dialect_custom_printer") {
306     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
307       return parser.parseKeyword("custom_format");
308     }};
309   }
310   if (opName == "test.dialect_custom_format_fallback") {
311     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
312       return parser.parseKeyword("custom_format_fallback");
313     }};
314   }
315   return None;
316 }
317 
318 llvm::unique_function<void(Operation *, OpAsmPrinter &)>
319 TestDialect::getOperationPrinter(Operation *op) const {
320   StringRef opName = op->getName().getStringRef();
321   if (opName == "test.dialect_custom_printer") {
322     return [](Operation *op, OpAsmPrinter &printer) {
323       printer.getStream() << " custom_format";
324     };
325   }
326   if (opName == "test.dialect_custom_format_fallback") {
327     return [](Operation *op, OpAsmPrinter &printer) {
328       printer.getStream() << " custom_format_fallback";
329     };
330   }
331   return {};
332 }
333 
334 //===----------------------------------------------------------------------===//
335 // TestBranchOp
336 //===----------------------------------------------------------------------===//
337 
/// Returns the operands forwarded to the op's only successor (index 0).
Optional<MutableOperandRange>
TestBranchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index == 0 && "invalid successor index");
  return getTargetOperandsMutable();
}
343 
344 //===----------------------------------------------------------------------===//
345 // TestProducingBranchOp
346 //===----------------------------------------------------------------------===//
347 
348 Optional<MutableOperandRange>
349 TestProducingBranchOp::getMutableSuccessorOperands(unsigned index) {
350   assert(index <= 1 && "invalid successor index");
351   if (index == 1)
352     return getFirstOperandsMutable();
353   return getSecondOperandsMutable();
354 }
355 
356 //===----------------------------------------------------------------------===//
357 // TestDialectCanonicalizerOp
358 //===----------------------------------------------------------------------===//
359 
360 static LogicalResult
361 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
362                                PatternRewriter &rewriter) {
363   rewriter.replaceOpWithNewOp<arith::ConstantOp>(
364       op, rewriter.getI32IntegerAttr(42));
365   return success();
366 }
367 
368 void TestDialect::getCanonicalizationPatterns(
369     RewritePatternSet &results) const {
370   results.add(&dialectCanonicalizationPattern);
371 }
372 
373 //===----------------------------------------------------------------------===//
374 // TestCallOp
375 //===----------------------------------------------------------------------===//
376 
377 LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
378   // Check that the callee attribute was specified.
379   auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
380   if (!fnAttr)
381     return emitOpError("requires a 'callee' symbol reference attribute");
382   if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr))
383     return emitOpError() << "'" << fnAttr.getValue()
384                          << "' does not reference a valid function";
385   return success();
386 }
387 
388 //===----------------------------------------------------------------------===//
389 // TestFoldToCallOp
390 //===----------------------------------------------------------------------===//
391 
392 namespace {
393 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
394   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
395 
396   LogicalResult matchAndRewrite(FoldToCallOp op,
397                                 PatternRewriter &rewriter) const override {
398     rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(),
399                                               op.getCalleeAttr(), ValueRange());
400     return success();
401   }
402 };
403 } // namespace
404 
/// Registers FoldToCallOpPattern as this op's canonicalization.
void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<FoldToCallOpPattern>(context);
}
409 
410 //===----------------------------------------------------------------------===//
411 // Test Format* operations
412 //===----------------------------------------------------------------------===//
413 
414 //===----------------------------------------------------------------------===//
415 // Parsing
416 
417 static ParseResult parseCustomOptionalOperand(
418     OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) {
419   if (succeeded(parser.parseOptionalLParen())) {
420     optOperand.emplace();
421     if (parser.parseOperand(*optOperand) || parser.parseRParen())
422       return failure();
423   }
424   return success();
425 }
426 
427 static ParseResult parseCustomDirectiveOperands(
428     OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
429     Optional<OpAsmParser::UnresolvedOperand> &optOperand,
430     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) {
431   if (parser.parseOperand(operand))
432     return failure();
433   if (succeeded(parser.parseOptionalComma())) {
434     optOperand.emplace();
435     if (parser.parseOperand(*optOperand))
436       return failure();
437   }
438   if (parser.parseArrow() || parser.parseLParen() ||
439       parser.parseOperandList(varOperands) || parser.parseRParen())
440     return failure();
441   return success();
442 }
443 static ParseResult
444 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
445                             Type &optOperandType,
446                             SmallVectorImpl<Type> &varOperandTypes) {
447   if (parser.parseColon())
448     return failure();
449 
450   if (parser.parseType(operandType))
451     return failure();
452   if (succeeded(parser.parseOptionalComma())) {
453     if (parser.parseType(optOperandType))
454       return failure();
455   }
456   if (parser.parseArrow() || parser.parseLParen() ||
457       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
458     return failure();
459   return success();
460 }
461 static ParseResult
462 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
463                                  Type optOperandType,
464                                  const SmallVectorImpl<Type> &varOperandTypes) {
465   if (parser.parseKeyword("type_refs_capture"))
466     return failure();
467 
468   Type operandType2, optOperandType2;
469   SmallVector<Type, 1> varOperandTypes2;
470   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
471                                   varOperandTypes2))
472     return failure();
473 
474   if (operandType != operandType2 || optOperandType != optOperandType2 ||
475       varOperandTypes != varOperandTypes2)
476     return failure();
477 
478   return success();
479 }
480 static ParseResult parseCustomDirectiveOperandsAndTypes(
481     OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
482     Optional<OpAsmParser::UnresolvedOperand> &optOperand,
483     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands,
484     Type &operandType, Type &optOperandType,
485     SmallVectorImpl<Type> &varOperandTypes) {
486   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
487       parseCustomDirectiveResults(parser, operandType, optOperandType,
488                                   varOperandTypes))
489     return failure();
490   return success();
491 }
492 static ParseResult parseCustomDirectiveRegions(
493     OpAsmParser &parser, Region &region,
494     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
495   if (parser.parseRegion(region))
496     return failure();
497   if (failed(parser.parseOptionalComma()))
498     return success();
499   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
500   if (parser.parseRegion(*varRegion))
501     return failure();
502   varRegions.emplace_back(std::move(varRegion));
503   return success();
504 }
505 static ParseResult
506 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
507                                SmallVectorImpl<Block *> &varSuccessors) {
508   if (parser.parseSuccessor(successor))
509     return failure();
510   if (failed(parser.parseOptionalComma()))
511     return success();
512   Block *varSuccessor;
513   if (parser.parseSuccessor(varSuccessor))
514     return failure();
515   varSuccessors.append(2, varSuccessor);
516   return success();
517 }
518 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
519                                                   IntegerAttr &attr,
520                                                   IntegerAttr &optAttr) {
521   if (parser.parseAttribute(attr))
522     return failure();
523   if (succeeded(parser.parseOptionalComma())) {
524     if (parser.parseAttribute(optAttr))
525       return failure();
526   }
527   return success();
528 }
529 
/// Parses an optional attribute dictionary into `attrs`.
static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
                                                NamedAttrList &attrs) {
  return parser.parseOptionalAttrDict(attrs);
}
534 static ParseResult parseCustomDirectiveOptionalOperandRef(
535     OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) {
536   int64_t operandCount = 0;
537   if (parser.parseInteger(operandCount))
538     return failure();
539   bool expectedOptionalOperand = operandCount == 0;
540   return success(expectedOptionalOperand != optOperand.hasValue());
541 }
542 
543 //===----------------------------------------------------------------------===//
544 // Printing
545 
546 static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *,
547                                        Value optOperand) {
548   if (optOperand)
549     printer << "(" << optOperand << ") ";
550 }
551 
552 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
553                                          Value operand, Value optOperand,
554                                          OperandRange varOperands) {
555   printer << operand;
556   if (optOperand)
557     printer << ", " << optOperand;
558   printer << " -> (" << varOperands << ")";
559 }
560 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
561                                         Type operandType, Type optOperandType,
562                                         TypeRange varOperandTypes) {
563   printer << " : " << operandType;
564   if (optOperandType)
565     printer << ", " << optOperandType;
566   printer << " -> (" << varOperandTypes << ")";
567 }
/// Prints the `type_refs_capture` keyword followed by the type list printed by
/// printCustomDirectiveResults.
static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
                                             Operation *op, Type operandType,
                                             Type optOperandType,
                                             TypeRange varOperandTypes) {
  printer << " type_refs_capture ";
  printCustomDirectiveResults(printer, op, operandType, optOperandType,
                              varOperandTypes);
}
/// Prints the custom-directive operand list followed by its type list.
static void printCustomDirectiveOperandsAndTypes(
    OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
    OperandRange varOperands, Type operandType, Type optOperandType,
    TypeRange varOperandTypes) {
  printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
  printCustomDirectiveResults(printer, op, operandType, optOperandType,
                              varOperandTypes);
}
584 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
585                                         Region &region,
586                                         MutableArrayRef<Region> varRegions) {
587   printer.printRegion(region);
588   if (!varRegions.empty()) {
589     printer << ", ";
590     for (Region &region : varRegions)
591       printer.printRegion(region);
592   }
593 }
594 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
595                                            Block *successor,
596                                            SuccessorRange varSuccessors) {
597   printer << successor;
598   if (!varSuccessors.empty())
599     printer << ", " << varSuccessors.front();
600 }
601 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
602                                            Attribute attribute,
603                                            Attribute optAttribute) {
604   printer << attribute;
605   if (optAttribute)
606     printer << ", " << optAttribute;
607 }
608 
/// Prints the entries of `attrs` as an optional attribute dictionary.
static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
                                         DictionaryAttr attrs) {
  printer.printOptionalAttrDict(attrs.getValue());
}
613 
614 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
615                                                    Operation *op,
616                                                    Value optOperand) {
617   printer << (optOperand ? "1" : "0");
618 }
619 
620 //===----------------------------------------------------------------------===//
621 // Test IsolatedRegionOp - parse passthrough region arguments.
622 //===----------------------------------------------------------------------===//
623 
624 ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
625                                     OperationState &result) {
626   OpAsmParser::UnresolvedOperand argInfo;
627   Type argType = parser.getBuilder().getIndexType();
628 
629   // Parse the input operand.
630   if (parser.parseOperand(argInfo) ||
631       parser.resolveOperand(argInfo, argType, result.operands))
632     return failure();
633 
634   // Parse the body region, and reuse the operand info as the argument info.
635   Region *body = result.addRegion();
636   return parser.parseRegion(*body, argInfo, argType, /*argLocations=*/{},
637                             /*enableNameShadowing=*/true);
638 }
639 
640 void IsolatedRegionOp::print(OpAsmPrinter &p) {
641   p << "test.isolated_region ";
642   p.printOperand(getOperand());
643   p.shadowRegionArgs(getRegion(), getOperand());
644   p << ' ';
645   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
646 }
647 
648 //===----------------------------------------------------------------------===//
649 // Test SSACFGRegionOp
650 //===----------------------------------------------------------------------===//
651 
/// Every region of this op, regardless of `index`, is an SSACFG region.
RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
  return RegionKind::SSACFG;
}
655 
656 //===----------------------------------------------------------------------===//
657 // Test GraphRegionOp
658 //===----------------------------------------------------------------------===//
659 
660 ParseResult GraphRegionOp::parse(OpAsmParser &parser, OperationState &result) {
661   // Parse the body region, and reuse the operand info as the argument info.
662   Region *body = result.addRegion();
663   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
664 }
665 
666 void GraphRegionOp::print(OpAsmPrinter &p) {
667   p << "test.graph_region ";
668   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
669 }
670 
671 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
672   return RegionKind::Graph;
673 }
674 
675 //===----------------------------------------------------------------------===//
676 // Test AffineScopeOp
677 //===----------------------------------------------------------------------===//
678 
679 ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
680   // Parse the body region, and reuse the operand info as the argument info.
681   Region *body = result.addRegion();
682   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
683 }
684 
685 void AffineScopeOp::print(OpAsmPrinter &p) {
686   p << "test.affine_scope ";
687   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
688 }
689 
690 //===----------------------------------------------------------------------===//
691 // Test parser.
692 //===----------------------------------------------------------------------===//
693 
694 ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser,
695                                          OperationState &result) {
696   if (parser.parseOptionalColon())
697     return success();
698   uint64_t numResults;
699   if (parser.parseInteger(numResults))
700     return failure();
701 
702   IndexType type = parser.getBuilder().getIndexType();
703   for (unsigned i = 0; i < numResults; ++i)
704     result.addTypes(type);
705   return success();
706 }
707 
708 void ParseIntegerLiteralOp::print(OpAsmPrinter &p) {
709   if (unsigned numResults = getNumResults())
710     p << " : " << numResults;
711 }
712 
713 ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser,
714                                          OperationState &result) {
715   StringRef keyword;
716   if (parser.parseKeyword(&keyword))
717     return failure();
718   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
719   return success();
720 }
721 
722 void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); }
723 
724 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
726 
727 ParseResult WrappingRegionOp::parse(OpAsmParser &parser,
728                                     OperationState &result) {
729   if (parser.parseKeyword("wraps"))
730     return failure();
731 
732   // Parse the wrapped op in a region
733   Region &body = *result.addRegion();
734   body.push_back(new Block);
735   Block &block = body.back();
736   Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
737   if (!wrappedOp)
738     return failure();
739 
740   // Create a return terminator in the inner region, pass as operand to the
741   // terminator the returned values from the wrapped operation.
742   SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
743   OpBuilder builder(parser.getContext());
744   builder.setInsertionPointToEnd(&block);
745   builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);
746 
747   // Get the results type for the wrapping op from the terminator operands.
748   Operation &returnOp = body.back().back();
749   result.types.append(returnOp.operand_type_begin(),
750                       returnOp.operand_type_end());
751 
752   // Use the location of the wrapped op for the "test.wrapping_region" op.
753   result.location = wrappedOp->getLoc();
754 
755   return success();
756 }
757 
758 void WrappingRegionOp::print(OpAsmPrinter &p) {
759   p << " wraps ";
760   p.printGenericOp(&getRegion().front().front());
761 }
762 
763 //===----------------------------------------------------------------------===//
// Test PrettyPrintedRegionOp - exercising the following parser APIs
765 //   parseGenericOperationAfterOpName
766 //   parseCustomOperationName
767 //===----------------------------------------------------------------------===//
768 
769 ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser,
770                                          OperationState &result) {
771 
772   SMLoc loc = parser.getCurrentLocation();
773   Location currLocation = parser.getEncodedSourceLoc(loc);
774 
775   // Parse the operands.
776   SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
777   if (parser.parseOperandList(operands))
778     return failure();
779 
780   // Check if we are parsing the pretty-printed version
781   //  test.pretty_printed_region start <inner-op> end : <functional-type>
782   // Else fallback to parsing the "non pretty-printed" version.
783   if (!succeeded(parser.parseOptionalKeyword("start")))
784     return parser.parseGenericOperationAfterOpName(
785         result, llvm::makeArrayRef(operands));
786 
787   FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName();
788   if (failed(parseOpNameInfo))
789     return failure();
790 
791   StringAttr innerOpName = parseOpNameInfo->getIdentifier();
792 
793   FunctionType opFntype;
794   Optional<Location> explicitLoc;
795   if (parser.parseKeyword("end") || parser.parseColon() ||
796       parser.parseType(opFntype) ||
797       parser.parseOptionalLocationSpecifier(explicitLoc))
798     return failure();
799 
800   // If location of the op is explicitly provided, then use it; Else use
801   // the parser's current location.
802   Location opLoc = explicitLoc.getValueOr(currLocation);
803 
804   // Derive the SSA-values for op's operands.
805   if (parser.resolveOperands(operands, opFntype.getInputs(), loc,
806                              result.operands))
807     return failure();
808 
809   // Add a region for op.
810   Region &region = *result.addRegion();
811 
812   // Create a basic-block inside op's region.
813   Block &block = region.emplaceBlock();
814 
815   // Create and insert an "inner-op" operation in the block.
816   // Just for testing purposes, we can assume that inner op is a binary op with
817   // result and operand types all same as the test-op's first operand.
818   Type innerOpType = opFntype.getInput(0);
819   Value lhs = block.addArgument(innerOpType, opLoc);
820   Value rhs = block.addArgument(innerOpType, opLoc);
821 
822   OpBuilder builder(parser.getBuilder().getContext());
823   builder.setInsertionPointToStart(&block);
824 
825   Operation *innerOp =
826       builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType);
827 
828   // Insert a return statement in the block returning the inner-op's result.
829   builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());
830 
831   // Populate the op operation-state with result-type and location.
832   result.addTypes(opFntype.getResults());
833   result.location = innerOp->getLoc();
834 
835   return success();
836 }
837 
838 void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
839   p << ' ';
840   p.printOperands(getOperands());
841 
842   Operation &innerOp = getRegion().front().front();
843   // Assuming that region has a single non-terminator inner-op, if the inner-op
844   // meets some criteria (which in this case is a simple one  based on the name
845   // of inner-op), then we can print the entire region in a succinct way.
846   // Here we assume that the prototype of "special.op" can be trivially derived
847   // while parsing it back.
848   if (innerOp.getName().getStringRef().equals("special.op")) {
849     p << " start special.op end";
850   } else {
851     p << " (";
852     p.printRegion(getRegion());
853     p << ")";
854   }
855 
856   p << " : ";
857   p.printFunctionalType(*this);
858 }
859 
860 //===----------------------------------------------------------------------===//
861 // Test PolyForOp - parse list of region arguments.
862 //===----------------------------------------------------------------------===//
863 
864 ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) {
865   SmallVector<OpAsmParser::UnresolvedOperand, 4> ivsInfo;
866   // Parse list of region arguments without a delimiter.
867   if (parser.parseRegionArgumentList(ivsInfo))
868     return failure();
869 
870   // Parse the body region.
871   Region *body = result.addRegion();
872   auto &builder = parser.getBuilder();
873   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
874   return parser.parseRegion(*body, ivsInfo, argTypes);
875 }
876 
877 void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); }
878 
879 void PolyForOp::getAsmBlockArgumentNames(Region &region,
880                                          OpAsmSetValueNameFn setNameFn) {
881   auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
882   if (!arrayAttr)
883     return;
884   auto args = getRegion().front().getArguments();
885   auto e = std::min(arrayAttr.size(), args.size());
886   for (unsigned i = 0; i < e; ++i) {
887     if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
888       setNameFn(args[i], strAttr.getValue());
889   }
890 }
891 
892 //===----------------------------------------------------------------------===//
893 // Test removing op with inner ops.
894 //===----------------------------------------------------------------------===//
895 
896 namespace {
897 struct TestRemoveOpWithInnerOps
898     : public OpRewritePattern<TestOpWithRegionPattern> {
899   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
900 
901   void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
902 
903   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
904                                 PatternRewriter &rewriter) const override {
905     rewriter.eraseOp(op);
906     return success();
907   }
908 };
909 } // namespace
910 
911 void TestOpWithRegionPattern::getCanonicalizationPatterns(
912     RewritePatternSet &results, MLIRContext *context) {
913   results.add<TestRemoveOpWithInnerOps>(context);
914 }
915 
916 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
917   return getOperand();
918 }
919 
920 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
921   return getValue();
922 }
923 
924 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
925     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
926   for (Value input : this->getOperands()) {
927     results.push_back(input);
928   }
929   return success();
930 }
931 
932 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
933   assert(operands.size() == 1);
934   if (operands.front()) {
935     (*this)->setAttr("attr", operands.front());
936     return getResult();
937   }
938   return {};
939 }
940 
941 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
942   return getOperand();
943 }
944 
945 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
946     MLIRContext *, Optional<Location> location, ValueRange operands,
947     DictionaryAttr attributes, RegionRange regions,
948     SmallVectorImpl<Type> &inferredReturnTypes) {
949   if (operands[0].getType() != operands[1].getType()) {
950     return emitOptionalError(location, "operand type mismatch ",
951                              operands[0].getType(), " vs ",
952                              operands[1].getType());
953   }
954   inferredReturnTypes.assign({operands[0].getType()});
955   return success();
956 }
957 
958 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
959     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
960     DictionaryAttr attributes, RegionRange regions,
961     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
962   // Create return type consisting of the last element of the first operand.
963   auto operandType = operands.front().getType();
964   auto sval = operandType.dyn_cast<ShapedType>();
965   if (!sval) {
966     return emitOptionalError(location, "only shaped type operands allowed");
967   }
968   int64_t dim =
969       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
970   auto type = IntegerType::get(context, 17);
971   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
972   return success();
973 }
974 
975 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
976     OpBuilder &builder, ValueRange operands,
977     llvm::SmallVectorImpl<Value> &shapes) {
978   shapes = SmallVector<Value, 1>{
979       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
980   return success();
981 }
982 
983 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
984     OpBuilder &builder, ValueRange operands,
985     llvm::SmallVectorImpl<Value> &shapes) {
986   Location loc = getLoc();
987   shapes.reserve(operands.size());
988   for (Value operand : llvm::reverse(operands)) {
989     auto rank = operand.getType().cast<RankedTensorType>().getRank();
990     auto currShape = llvm::to_vector<4>(
991         llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
992           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
993         }));
994     shapes.push_back(builder.create<tensor::FromElementsOp>(
995         getLoc(), RankedTensorType::get({rank}, builder.getIndexType()),
996         currShape));
997   }
998   return success();
999 }
1000 
1001 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
1002     OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
1003   Location loc = getLoc();
1004   shapes.reserve(getNumOperands());
1005   for (Value operand : llvm::reverse(getOperands())) {
1006     auto currShape = llvm::to_vector<4>(llvm::map_range(
1007         llvm::seq<int64_t>(
1008             0, operand.getType().cast<RankedTensorType>().getRank()),
1009         [&](int64_t dim) -> Value {
1010           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
1011         }));
1012     shapes.emplace_back(std::move(currShape));
1013   }
1014   return success();
1015 }
1016 
1017 //===----------------------------------------------------------------------===//
1018 // Test SideEffect interfaces
1019 //===----------------------------------------------------------------------===//
1020 
1021 namespace {
1022 /// A test resource for side effects.
1023 struct TestResource : public SideEffects::Resource::Base<TestResource> {
1024   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource)
1025 
1026   StringRef getName() final { return "<Test>"; }
1027 };
1028 } // namespace
1029 
1030 static void testSideEffectOpGetEffect(
1031     Operation *op,
1032     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
1033         &effects) {
1034   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
1035   if (!effectsAttr)
1036     return;
1037 
1038   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
1039 }
1040 
1041 void SideEffectOp::getEffects(
1042     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1043   // Check for an effects attribute on the op instance.
1044   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
1045   if (!effectsAttr)
1046     return;
1047 
1048   // If there is one, it is an array of dictionary attributes that hold
1049   // information on the effects of this operation.
1050   for (Attribute element : effectsAttr) {
1051     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
1052 
1053     // Get the specific memory effect.
1054     MemoryEffects::Effect *effect =
1055         StringSwitch<MemoryEffects::Effect *>(
1056             effectElement.get("effect").cast<StringAttr>().getValue())
1057             .Case("allocate", MemoryEffects::Allocate::get())
1058             .Case("free", MemoryEffects::Free::get())
1059             .Case("read", MemoryEffects::Read::get())
1060             .Case("write", MemoryEffects::Write::get());
1061 
1062     // Check for a non-default resource to use.
1063     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
1064     if (effectElement.get("test_resource"))
1065       resource = TestResource::get();
1066 
1067     // Check for a result to affect.
1068     if (effectElement.get("on_result"))
1069       effects.emplace_back(effect, getResult(), resource);
1070     else if (Attribute ref = effectElement.get("on_reference"))
1071       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
1072     else
1073       effects.emplace_back(effect, resource);
1074   }
1075 }
1076 
1077 void SideEffectOp::getEffects(
1078     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
1079   testSideEffectOpGetEffect(getOperation(), effects);
1080 }
1081 
1082 //===----------------------------------------------------------------------===//
1083 // StringAttrPrettyNameOp
1084 //===----------------------------------------------------------------------===//
1085 
1086 // This op has fancy handling of its SSA result name.
1087 ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
1088                                           OperationState &result) {
1089   // Add the result types.
1090   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
1091     result.addTypes(parser.getBuilder().getIntegerType(32));
1092 
1093   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1094     return failure();
1095 
1096   // If the attribute dictionary contains no 'names' attribute, infer it from
1097   // the SSA name (if specified).
1098   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
1099     return attr.getName() == "names";
1100   });
1101 
1102   // If there was no name specified, check to see if there was a useful name
1103   // specified in the asm file.
1104   if (hadNames || parser.getNumResults() == 0)
1105     return success();
1106 
1107   SmallVector<StringRef, 4> names;
1108   auto *context = result.getContext();
1109 
1110   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
1111     auto resultName = parser.getResultName(i);
1112     StringRef nameStr;
1113     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
1114       nameStr = resultName.first;
1115 
1116     names.push_back(nameStr);
1117   }
1118 
1119   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
1120   result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
1121   return success();
1122 }
1123 
1124 void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
1125   // Note that we only need to print the "name" attribute if the asmprinter
1126   // result name disagrees with it.  This can happen in strange cases, e.g.
1127   // when there are conflicts.
1128   bool namesDisagree = getNames().size() != getNumResults();
1129 
1130   SmallString<32> resultNameStr;
1131   for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) {
1132     resultNameStr.clear();
1133     llvm::raw_svector_ostream tmpStream(resultNameStr);
1134     p.printOperand(getResult(i), tmpStream);
1135 
1136     auto expectedName = getNames()[i].dyn_cast<StringAttr>();
1137     if (!expectedName ||
1138         tmpStream.str().drop_front() != expectedName.getValue()) {
1139       namesDisagree = true;
1140     }
1141   }
1142 
1143   if (namesDisagree)
1144     p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
1145   else
1146     p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
1147 }
1148 
1149 // We set the SSA name in the asm syntax to the contents of the name
1150 // attribute.
1151 void StringAttrPrettyNameOp::getAsmResultNames(
1152     function_ref<void(Value, StringRef)> setNameFn) {
1153 
1154   auto value = getNames();
1155   for (size_t i = 0, e = value.size(); i != e; ++i)
1156     if (auto str = value[i].dyn_cast<StringAttr>())
1157       if (!str.getValue().empty())
1158         setNameFn(getResult(i), str.getValue());
1159 }
1160 
1161 //===----------------------------------------------------------------------===//
1162 // ResultTypeWithTraitOp
1163 //===----------------------------------------------------------------------===//
1164 
1165 LogicalResult ResultTypeWithTraitOp::verify() {
1166   if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
1167     return success();
1168   return emitError("result type should have trait 'TestTypeTrait'");
1169 }
1170 
1171 //===----------------------------------------------------------------------===//
1172 // AttrWithTraitOp
1173 //===----------------------------------------------------------------------===//
1174 
1175 LogicalResult AttrWithTraitOp::verify() {
1176   if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
1177     return success();
1178   return emitError("'attr' attribute should have trait 'TestAttrTrait'");
1179 }
1180 
1181 //===----------------------------------------------------------------------===//
1182 // RegionIfOp
1183 //===----------------------------------------------------------------------===//
1184 
1185 void RegionIfOp::print(OpAsmPrinter &p) {
1186   p << " ";
1187   p.printOperands(getOperands());
1188   p << ": " << getOperandTypes();
1189   p.printArrowTypeList(getResultTypes());
1190   p << " then ";
1191   p.printRegion(getThenRegion(),
1192                 /*printEntryBlockArgs=*/true,
1193                 /*printBlockTerminators=*/true);
1194   p << " else ";
1195   p.printRegion(getElseRegion(),
1196                 /*printEntryBlockArgs=*/true,
1197                 /*printBlockTerminators=*/true);
1198   p << " join ";
1199   p.printRegion(getJoinRegion(),
1200                 /*printEntryBlockArgs=*/true,
1201                 /*printBlockTerminators=*/true);
1202 }
1203 
1204 ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
1205   SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos;
1206   SmallVector<Type, 2> operandTypes;
1207 
1208   result.regions.reserve(3);
1209   Region *thenRegion = result.addRegion();
1210   Region *elseRegion = result.addRegion();
1211   Region *joinRegion = result.addRegion();
1212 
1213   // Parse operand, type and arrow type lists.
1214   if (parser.parseOperandList(operandInfos) ||
1215       parser.parseColonTypeList(operandTypes) ||
1216       parser.parseArrowTypeList(result.types))
1217     return failure();
1218 
1219   // Parse all attached regions.
1220   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1221       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1222       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1223     return failure();
1224 
1225   return parser.resolveOperands(operandInfos, operandTypes,
1226                                 parser.getCurrentLocation(), result.operands);
1227 }
1228 
1229 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
1230   assert(index < 2 && "invalid region index");
1231   return getOperands();
1232 }
1233 
1234 void RegionIfOp::getSuccessorRegions(
1235     Optional<unsigned> index, ArrayRef<Attribute> operands,
1236     SmallVectorImpl<RegionSuccessor> &regions) {
1237   // We always branch to the join region.
1238   if (index.hasValue()) {
1239     if (index.getValue() < 2)
1240       regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
1241     else
1242       regions.push_back(RegionSuccessor(getResults()));
1243     return;
1244   }
1245 
1246   // The then and else regions are the entry regions of this op.
1247   regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
1248   regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
1249 }
1250 
1251 void RegionIfOp::getRegionInvocationBounds(
1252     ArrayRef<Attribute> operands,
1253     SmallVectorImpl<InvocationBounds> &invocationBounds) {
1254   // Each region is invoked at most once.
1255   invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1});
1256 }
1257 
1258 //===----------------------------------------------------------------------===//
1259 // AnyCondOp
1260 //===----------------------------------------------------------------------===//
1261 
1262 void AnyCondOp::getSuccessorRegions(Optional<unsigned> index,
1263                                     ArrayRef<Attribute> operands,
1264                                     SmallVectorImpl<RegionSuccessor> &regions) {
1265   // The parent op branches into the only region, and the region branches back
1266   // to the parent op.
1267   if (index)
1268     regions.emplace_back(&getRegion());
1269   else
1270     regions.emplace_back(getResults());
1271 }
1272 
1273 void AnyCondOp::getRegionInvocationBounds(
1274     ArrayRef<Attribute> operands,
1275     SmallVectorImpl<InvocationBounds> &invocationBounds) {
1276   invocationBounds.emplace_back(1, 1);
1277 }
1278 
1279 //===----------------------------------------------------------------------===//
1280 // SingleNoTerminatorCustomAsmOp
1281 //===----------------------------------------------------------------------===//
1282 
1283 ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
1284                                                  OperationState &state) {
1285   Region *body = state.addRegion();
1286   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1287     return failure();
1288   return success();
1289 }
1290 
1291 void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
1292   printer.printRegion(
1293       getRegion(), /*printEntryBlockArgs=*/false,
1294       // This op has a single block without terminators. But explicitly mark
1295       // as not printing block terminators for testing.
1296       /*printBlockTerminators=*/false);
1297 }
1298 
1299 #include "TestOpEnums.cpp.inc"
1300 #include "TestOpInterfaces.cpp.inc"
1301 #include "TestOpStructs.cpp.inc"
1302 #include "TestTypeInterfaces.cpp.inc"
1303 
1304 #define GET_OP_CLASSES
1305 #include "TestOps.cpp.inc"
1306