1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/DLTI/DLTI.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/ExtensibleDialect.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include "mlir/Reducer/ReductionPatternInterface.h"
23 #include "mlir/Transforms/FoldUtils.h"
24 #include "mlir/Transforms/InliningUtils.h"
25 #include "llvm/ADT/StringSwitch.h"
26 
27 // Include this before the using namespace lines below to
28 // test that we don't have namespace dependencies.
29 #include "TestOpsDialect.cpp.inc"
30 
31 using namespace mlir;
32 using namespace test;
33 
34 void test::registerTestDialect(DialectRegistry &registry) {
35   registry.insert<TestDialect>();
36 }
37 
38 //===----------------------------------------------------------------------===//
39 // TestDialect Interfaces
40 //===----------------------------------------------------------------------===//
41 
42 namespace {
43 
44 /// Testing the correctness of some traits.
45 static_assert(
46     llvm::is_detected<OpTrait::has_implicit_terminator_t,
47                       SingleBlockImplicitTerminatorOp>::value,
48     "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
49 static_assert(OpTrait::hasSingleBlockImplicitTerminator<
50                   SingleBlockImplicitTerminatorOp>::value,
51               "hasSingleBlockImplicitTerminator does not match "
52               "SingleBlockImplicitTerminatorOp");
53 
54 // Test support for interacting with the AsmPrinter.
55 struct TestOpAsmInterface : public OpAsmDialectInterface {
56   using OpAsmDialectInterface::OpAsmDialectInterface;
57 
58   AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
59     StringAttr strAttr = attr.dyn_cast<StringAttr>();
60     if (!strAttr)
61       return AliasResult::NoAlias;
62 
63     // Check the contents of the string attribute to see what the test alias
64     // should be named.
65     Optional<StringRef> aliasName =
66         StringSwitch<Optional<StringRef>>(strAttr.getValue())
67             .Case("alias_test:dot_in_name", StringRef("test.alias"))
68             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
69             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
70             .Case("alias_test:sanitize_conflict_a",
71                   StringRef("test_alias_conflict0"))
72             .Case("alias_test:sanitize_conflict_b",
73                   StringRef("test_alias_conflict0_"))
74             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
75             .Default(llvm::None);
76     if (!aliasName)
77       return AliasResult::NoAlias;
78 
79     os << *aliasName;
80     return AliasResult::FinalAlias;
81   }
82 
83   AliasResult getAlias(Type type, raw_ostream &os) const final {
84     if (auto tupleType = type.dyn_cast<TupleType>()) {
85       if (tupleType.size() > 0 &&
86           llvm::all_of(tupleType.getTypes(), [](Type elemType) {
87             return elemType.isa<SimpleAType>();
88           })) {
89         os << "test_tuple";
90         return AliasResult::FinalAlias;
91       }
92     }
93     if (auto intType = type.dyn_cast<TestIntegerType>()) {
94       if (intType.getSignedness() ==
95               TestIntegerType::SignednessSemantics::Unsigned &&
96           intType.getWidth() == 8) {
97         os << "test_ui8";
98         return AliasResult::FinalAlias;
99       }
100     }
101     return AliasResult::NoAlias;
102   }
103 };
104 
105 struct TestDialectFoldInterface : public DialectFoldInterface {
106   using DialectFoldInterface::DialectFoldInterface;
107 
108   /// Registered hook to check if the given region, which is attached to an
109   /// operation that is *not* isolated from above, should be used when
110   /// materializing constants.
111   bool shouldMaterializeInto(Region *region) const final {
112     // If this is a one region operation, then insert into it.
113     return isa<OneRegionOp>(region->getParentOp());
114   }
115 };
116 
117 /// This class defines the interface for handling inlining with standard
118 /// operations.
119 struct TestInlinerInterface : public DialectInlinerInterface {
120   using DialectInlinerInterface::DialectInlinerInterface;
121 
122   //===--------------------------------------------------------------------===//
123   // Analysis Hooks
124   //===--------------------------------------------------------------------===//
125 
126   bool isLegalToInline(Operation *call, Operation *callable,
127                        bool wouldBeCloned) const final {
128     // Don't allow inlining calls that are marked `noinline`.
129     return !call->hasAttr("noinline");
130   }
131   bool isLegalToInline(Region *, Region *, bool,
132                        BlockAndValueMapping &) const final {
133     // Inlining into test dialect regions is legal.
134     return true;
135   }
136   bool isLegalToInline(Operation *, Region *, bool,
137                        BlockAndValueMapping &) const final {
138     return true;
139   }
140 
141   bool shouldAnalyzeRecursively(Operation *op) const final {
142     // Analyze recursively if this is not a functional region operation, it
143     // froms a separate functional scope.
144     return !isa<FunctionalRegionOp>(op);
145   }
146 
147   //===--------------------------------------------------------------------===//
148   // Transformation Hooks
149   //===--------------------------------------------------------------------===//
150 
151   /// Handle the given inlined terminator by replacing it with a new operation
152   /// as necessary.
153   void handleTerminator(Operation *op,
154                         ArrayRef<Value> valuesToRepl) const final {
155     // Only handle "test.return" here.
156     auto returnOp = dyn_cast<TestReturnOp>(op);
157     if (!returnOp)
158       return;
159 
160     // Replace the values directly with the return operands.
161     assert(returnOp.getNumOperands() == valuesToRepl.size());
162     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
163       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
164   }
165 
166   /// Attempt to materialize a conversion for a type mismatch between a call
167   /// from this dialect, and a callable region. This method should generate an
168   /// operation that takes 'input' as the only operand, and produces a single
169   /// result of 'resultType'. If a conversion can not be generated, nullptr
170   /// should be returned.
171   Operation *materializeCallConversion(OpBuilder &builder, Value input,
172                                        Type resultType,
173                                        Location conversionLoc) const final {
174     // Only allow conversion for i16/i32 types.
175     if (!(resultType.isSignlessInteger(16) ||
176           resultType.isSignlessInteger(32)) ||
177         !(input.getType().isSignlessInteger(16) ||
178           input.getType().isSignlessInteger(32)))
179       return nullptr;
180     return builder.create<TestCastOp>(conversionLoc, resultType, input);
181   }
182 
183   void processInlinedCallBlocks(
184       Operation *call,
185       iterator_range<Region::iterator> inlinedBlocks) const final {
186     if (!isa<ConversionCallOp>(call))
187       return;
188 
189     // Set attributed on all ops in the inlined blocks.
190     for (Block &block : inlinedBlocks) {
191       block.walk([&](Operation *op) {
192         op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
193       });
194     }
195   }
196 };
197 
198 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
199 public:
200   TestReductionPatternInterface(Dialect *dialect)
201       : DialectReductionPatternInterface(dialect) {}
202 
203   void populateReductionPatterns(RewritePatternSet &patterns) const final {
204     populateTestReductionPatterns(patterns);
205   }
206 };
207 
208 } // namespace
209 
210 //===----------------------------------------------------------------------===//
211 // Dynamic operations
212 //===----------------------------------------------------------------------===//
213 
214 std::unique_ptr<DynamicOpDefinition> getGenericDynamicOp(TestDialect *dialect) {
215   return DynamicOpDefinition::get(
216       "generic_dynamic_op", dialect, [](Operation *op) { return success(); },
217       [](Operation *op) { return success(); });
218 }
219 
220 std::unique_ptr<DynamicOpDefinition>
221 getOneOperandTwoResultsDynamicOp(TestDialect *dialect) {
222   return DynamicOpDefinition::get(
223       "one_operand_two_results", dialect,
224       [](Operation *op) {
225         if (op->getNumOperands() != 1) {
226           op->emitOpError()
227               << "expected 1 operand, but had " << op->getNumOperands();
228           return failure();
229         }
230         if (op->getNumResults() != 2) {
231           op->emitOpError()
232               << "expected 2 results, but had " << op->getNumResults();
233           return failure();
234         }
235         return success();
236       },
237       [](Operation *op) { return success(); });
238 }
239 
240 std::unique_ptr<DynamicOpDefinition>
241 getCustomParserPrinterDynamicOp(TestDialect *dialect) {
242   auto verifier = [](Operation *op) {
243     if (op->getNumOperands() == 0 && op->getNumResults() == 0)
244       return success();
245     op->emitError() << "operation should have no operands and no results";
246     return failure();
247   };
248   auto regionVerifier = [](Operation *op) { return success(); };
249 
250   auto parser = [](OpAsmParser &parser, OperationState &state) {
251     return parser.parseKeyword("custom_keyword");
252   };
253 
254   auto printer = [](Operation *op, OpAsmPrinter &printer, llvm::StringRef) {
255     printer << op->getName() << " custom_keyword";
256   };
257 
258   return DynamicOpDefinition::get("custom_parser_printer_dynamic_op", dialect,
259                                   verifier, regionVerifier, parser, printer);
260 }
261 
262 //===----------------------------------------------------------------------===//
263 // TestDialect
264 //===----------------------------------------------------------------------===//
265 
266 static void testSideEffectOpGetEffect(
267     Operation *op,
268     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
269 
270 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
271 struct TestOpEffectInterfaceFallback
272     : public TestEffectOpInterface::FallbackModel<
273           TestOpEffectInterfaceFallback> {
274   static bool classof(Operation *op) {
275     bool isSupportedOp =
276         op->getName().getStringRef() == "test.unregistered_side_effect_op";
277     assert(isSupportedOp && "Unexpected dispatch");
278     return isSupportedOp;
279   }
280 
281   void
282   getEffects(Operation *op,
283              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
284                  &effects) const {
285     testSideEffectOpGetEffect(op, effects);
286   }
287 };
288 
/// Populates the test dialect. It registers the dialect's attributes and
/// types, the ops listed in the generated TestOps.cpp.inc, the three dynamic
/// ops built by the factories above, and the dialect interfaces defined in
/// this file.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  // The op list is expanded from the generated TestOps.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  // Ops defined at runtime through DynamicOpDefinition.
  registerDynamicOp(getGenericDynamicOp(this));
  registerDynamicOp(getOneOperandTwoResultsDynamicOp(this));
  registerDynamicOp(getCustomParserPrinterDynamicOp(this));

  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface, TestReductionPatternInterface>();
  // Presumably lets unregistered ops in this dialect (e.g.
  // "test.unregistered_side_effect_op") be accepted; confirm against
  // Dialect::allowUnknownOperations.
  allowUnknownOperations();

  // Instantiate our fallback op interface that we'll use on specific
  // unregistered op. Ownership is released in ~TestDialect().
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
308 TestDialect::~TestDialect() {
309   delete static_cast<TestOpEffectInterfaceFallback *>(
310       fallbackEffectOpInterfaces);
311 }
312 
313 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
314                                             Type type, Location loc) {
315   return builder.create<TestOpConstant>(loc, type, value);
316 }
317 
318 ::mlir::LogicalResult FormatInferType2Op::inferReturnTypes(
319     ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location,
320     ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
321     ::mlir::RegionRange regions,
322     ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
323   inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
324   return ::mlir::success();
325 }
326 
327 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
328                                                OperationName opName) {
329   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
330       typeID == TypeID::get<TestEffectOpInterface>())
331     return fallbackEffectOpInterfaces;
332   return nullptr;
333 }
334 
335 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
336                                                     NamedAttribute namedAttr) {
337   if (namedAttr.getName() == "test.invalid_attr")
338     return op->emitError() << "invalid to use 'test.invalid_attr'";
339   return success();
340 }
341 
342 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
343                                                     unsigned regionIndex,
344                                                     unsigned argIndex,
345                                                     NamedAttribute namedAttr) {
346   if (namedAttr.getName() == "test.invalid_attr")
347     return op->emitError() << "invalid to use 'test.invalid_attr'";
348   return success();
349 }
350 
351 LogicalResult
352 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
353                                          unsigned resultIndex,
354                                          NamedAttribute namedAttr) {
355   if (namedAttr.getName() == "test.invalid_attr")
356     return op->emitError() << "invalid to use 'test.invalid_attr'";
357   return success();
358 }
359 
360 Optional<Dialect::ParseOpHook>
361 TestDialect::getParseOperationHook(StringRef opName) const {
362   if (opName == "test.dialect_custom_printer") {
363     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
364       return parser.parseKeyword("custom_format");
365     }};
366   }
367   if (opName == "test.dialect_custom_format_fallback") {
368     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
369       return parser.parseKeyword("custom_format_fallback");
370     }};
371   }
372   return None;
373 }
374 
375 llvm::unique_function<void(Operation *, OpAsmPrinter &)>
376 TestDialect::getOperationPrinter(Operation *op) const {
377   StringRef opName = op->getName().getStringRef();
378   if (opName == "test.dialect_custom_printer") {
379     return [](Operation *op, OpAsmPrinter &printer) {
380       printer.getStream() << " custom_format";
381     };
382   }
383   if (opName == "test.dialect_custom_format_fallback") {
384     return [](Operation *op, OpAsmPrinter &printer) {
385       printer.getStream() << " custom_format_fallback";
386     };
387   }
388   return {};
389 }
390 
391 //===----------------------------------------------------------------------===//
392 // TestBranchOp
393 //===----------------------------------------------------------------------===//
394 
395 Optional<MutableOperandRange>
396 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
397   assert(index == 0 && "invalid successor index");
398   return getTargetOperandsMutable();
399 }
400 
401 //===----------------------------------------------------------------------===//
402 // TestDialectCanonicalizerOp
403 //===----------------------------------------------------------------------===//
404 
405 static LogicalResult
406 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
407                                PatternRewriter &rewriter) {
408   rewriter.replaceOpWithNewOp<arith::ConstantOp>(
409       op, rewriter.getI32IntegerAttr(42));
410   return success();
411 }
412 
413 void TestDialect::getCanonicalizationPatterns(
414     RewritePatternSet &results) const {
415   results.add(&dialectCanonicalizationPattern);
416 }
417 
418 //===----------------------------------------------------------------------===//
419 // TestFoldToCallOp
420 //===----------------------------------------------------------------------===//
421 
422 namespace {
423 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
424   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
425 
426   LogicalResult matchAndRewrite(FoldToCallOp op,
427                                 PatternRewriter &rewriter) const override {
428     rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(),
429                                               op.getCalleeAttr(), ValueRange());
430     return success();
431   }
432 };
433 } // namespace
434 
435 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
436                                                MLIRContext *context) {
437   results.add<FoldToCallOpPattern>(context);
438 }
439 
440 //===----------------------------------------------------------------------===//
441 // Test Format* operations
442 //===----------------------------------------------------------------------===//
443 
444 //===----------------------------------------------------------------------===//
445 // Parsing
446 
447 static ParseResult
448 parseCustomOptionalOperand(OpAsmParser &parser,
449                            Optional<OpAsmParser::OperandType> &optOperand) {
450   if (succeeded(parser.parseOptionalLParen())) {
451     optOperand.emplace();
452     if (parser.parseOperand(*optOperand) || parser.parseRParen())
453       return failure();
454   }
455   return success();
456 }
457 
458 static ParseResult parseCustomDirectiveOperands(
459     OpAsmParser &parser, OpAsmParser::OperandType &operand,
460     Optional<OpAsmParser::OperandType> &optOperand,
461     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
462   if (parser.parseOperand(operand))
463     return failure();
464   if (succeeded(parser.parseOptionalComma())) {
465     optOperand.emplace();
466     if (parser.parseOperand(*optOperand))
467       return failure();
468   }
469   if (parser.parseArrow() || parser.parseLParen() ||
470       parser.parseOperandList(varOperands) || parser.parseRParen())
471     return failure();
472   return success();
473 }
474 static ParseResult
475 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
476                             Type &optOperandType,
477                             SmallVectorImpl<Type> &varOperandTypes) {
478   if (parser.parseColon())
479     return failure();
480 
481   if (parser.parseType(operandType))
482     return failure();
483   if (succeeded(parser.parseOptionalComma())) {
484     if (parser.parseType(optOperandType))
485       return failure();
486   }
487   if (parser.parseArrow() || parser.parseLParen() ||
488       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
489     return failure();
490   return success();
491 }
492 static ParseResult
493 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
494                                  Type optOperandType,
495                                  const SmallVectorImpl<Type> &varOperandTypes) {
496   if (parser.parseKeyword("type_refs_capture"))
497     return failure();
498 
499   Type operandType2, optOperandType2;
500   SmallVector<Type, 1> varOperandTypes2;
501   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
502                                   varOperandTypes2))
503     return failure();
504 
505   if (operandType != operandType2 || optOperandType != optOperandType2 ||
506       varOperandTypes != varOperandTypes2)
507     return failure();
508 
509   return success();
510 }
511 static ParseResult parseCustomDirectiveOperandsAndTypes(
512     OpAsmParser &parser, OpAsmParser::OperandType &operand,
513     Optional<OpAsmParser::OperandType> &optOperand,
514     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
515     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
516   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
517       parseCustomDirectiveResults(parser, operandType, optOperandType,
518                                   varOperandTypes))
519     return failure();
520   return success();
521 }
522 static ParseResult parseCustomDirectiveRegions(
523     OpAsmParser &parser, Region &region,
524     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
525   if (parser.parseRegion(region))
526     return failure();
527   if (failed(parser.parseOptionalComma()))
528     return success();
529   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
530   if (parser.parseRegion(*varRegion))
531     return failure();
532   varRegions.emplace_back(std::move(varRegion));
533   return success();
534 }
535 static ParseResult
536 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
537                                SmallVectorImpl<Block *> &varSuccessors) {
538   if (parser.parseSuccessor(successor))
539     return failure();
540   if (failed(parser.parseOptionalComma()))
541     return success();
542   Block *varSuccessor;
543   if (parser.parseSuccessor(varSuccessor))
544     return failure();
545   varSuccessors.append(2, varSuccessor);
546   return success();
547 }
548 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
549                                                   IntegerAttr &attr,
550                                                   IntegerAttr &optAttr) {
551   if (parser.parseAttribute(attr))
552     return failure();
553   if (succeeded(parser.parseOptionalComma())) {
554     if (parser.parseAttribute(optAttr))
555       return failure();
556   }
557   return success();
558 }
559 
560 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
561                                                 NamedAttrList &attrs) {
562   return parser.parseOptionalAttrDict(attrs);
563 }
564 static ParseResult parseCustomDirectiveOptionalOperandRef(
565     OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
566   int64_t operandCount = 0;
567   if (parser.parseInteger(operandCount))
568     return failure();
569   bool expectedOptionalOperand = operandCount == 0;
570   return success(expectedOptionalOperand != optOperand.hasValue());
571 }
572 
573 //===----------------------------------------------------------------------===//
574 // Printing
575 
576 static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *,
577                                        Value optOperand) {
578   if (optOperand)
579     printer << "(" << optOperand << ") ";
580 }
581 
582 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
583                                          Value operand, Value optOperand,
584                                          OperandRange varOperands) {
585   printer << operand;
586   if (optOperand)
587     printer << ", " << optOperand;
588   printer << " -> (" << varOperands << ")";
589 }
590 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
591                                         Type operandType, Type optOperandType,
592                                         TypeRange varOperandTypes) {
593   printer << " : " << operandType;
594   if (optOperandType)
595     printer << ", " << optOperandType;
596   printer << " -> (" << varOperandTypes << ")";
597 }
598 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
599                                              Operation *op, Type operandType,
600                                              Type optOperandType,
601                                              TypeRange varOperandTypes) {
602   printer << " type_refs_capture ";
603   printCustomDirectiveResults(printer, op, operandType, optOperandType,
604                               varOperandTypes);
605 }
606 static void printCustomDirectiveOperandsAndTypes(
607     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
608     OperandRange varOperands, Type operandType, Type optOperandType,
609     TypeRange varOperandTypes) {
610   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
611   printCustomDirectiveResults(printer, op, operandType, optOperandType,
612                               varOperandTypes);
613 }
614 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
615                                         Region &region,
616                                         MutableArrayRef<Region> varRegions) {
617   printer.printRegion(region);
618   if (!varRegions.empty()) {
619     printer << ", ";
620     for (Region &region : varRegions)
621       printer.printRegion(region);
622   }
623 }
624 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
625                                            Block *successor,
626                                            SuccessorRange varSuccessors) {
627   printer << successor;
628   if (!varSuccessors.empty())
629     printer << ", " << varSuccessors.front();
630 }
631 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
632                                            Attribute attribute,
633                                            Attribute optAttribute) {
634   printer << attribute;
635   if (optAttribute)
636     printer << ", " << optAttribute;
637 }
638 
639 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
640                                          DictionaryAttr attrs) {
641   printer.printOptionalAttrDict(attrs.getValue());
642 }
643 
644 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
645                                                    Operation *op,
646                                                    Value optOperand) {
647   printer << (optOperand ? "1" : "0");
648 }
649 
650 //===----------------------------------------------------------------------===//
651 // Test IsolatedRegionOp - parse passthrough region arguments.
652 //===----------------------------------------------------------------------===//
653 
654 ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
655                                     OperationState &result) {
656   OpAsmParser::OperandType argInfo;
657   Type argType = parser.getBuilder().getIndexType();
658 
659   // Parse the input operand.
660   if (parser.parseOperand(argInfo) ||
661       parser.resolveOperand(argInfo, argType, result.operands))
662     return failure();
663 
664   // Parse the body region, and reuse the operand info as the argument info.
665   Region *body = result.addRegion();
666   return parser.parseRegion(*body, argInfo, argType, /*argLocations=*/{},
667                             /*enableNameShadowing=*/true);
668 }
669 
670 void IsolatedRegionOp::print(OpAsmPrinter &p) {
671   p << "test.isolated_region ";
672   p.printOperand(getOperand());
673   p.shadowRegionArgs(getRegion(), getOperand());
674   p << ' ';
675   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
676 }
677 
678 //===----------------------------------------------------------------------===//
679 // Test SSACFGRegionOp
680 //===----------------------------------------------------------------------===//
681 
682 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
683   return RegionKind::SSACFG;
684 }
685 
686 //===----------------------------------------------------------------------===//
687 // Test GraphRegionOp
688 //===----------------------------------------------------------------------===//
689 
690 ParseResult GraphRegionOp::parse(OpAsmParser &parser, OperationState &result) {
691   // Parse the body region, and reuse the operand info as the argument info.
692   Region *body = result.addRegion();
693   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
694 }
695 
696 void GraphRegionOp::print(OpAsmPrinter &p) {
697   p << "test.graph_region ";
698   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
699 }
700 
701 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
702   return RegionKind::Graph;
703 }
704 
705 //===----------------------------------------------------------------------===//
706 // Test AffineScopeOp
707 //===----------------------------------------------------------------------===//
708 
709 ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
710   // Parse the body region, and reuse the operand info as the argument info.
711   Region *body = result.addRegion();
712   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
713 }
714 
715 void AffineScopeOp::print(OpAsmPrinter &p) {
716   p << "test.affine_scope ";
717   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
718 }
719 
720 //===----------------------------------------------------------------------===//
721 // Test parser.
722 //===----------------------------------------------------------------------===//
723 
724 ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser,
725                                          OperationState &result) {
726   if (parser.parseOptionalColon())
727     return success();
728   uint64_t numResults;
729   if (parser.parseInteger(numResults))
730     return failure();
731 
732   IndexType type = parser.getBuilder().getIndexType();
733   for (unsigned i = 0; i < numResults; ++i)
734     result.addTypes(type);
735   return success();
736 }
737 
738 void ParseIntegerLiteralOp::print(OpAsmPrinter &p) {
739   if (unsigned numResults = getNumResults())
740     p << " : " << numResults;
741 }
742 
743 ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser,
744                                          OperationState &result) {
745   StringRef keyword;
746   if (parser.parseKeyword(&keyword))
747     return failure();
748   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
749   return success();
750 }
751 
752 void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); }
753 
754 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
756 
757 ParseResult WrappingRegionOp::parse(OpAsmParser &parser,
758                                     OperationState &result) {
759   if (parser.parseKeyword("wraps"))
760     return failure();
761 
762   // Parse the wrapped op in a region
763   Region &body = *result.addRegion();
764   body.push_back(new Block);
765   Block &block = body.back();
766   Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
767   if (!wrappedOp)
768     return failure();
769 
770   // Create a return terminator in the inner region, pass as operand to the
771   // terminator the returned values from the wrapped operation.
772   SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
773   OpBuilder builder(parser.getContext());
774   builder.setInsertionPointToEnd(&block);
775   builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);
776 
777   // Get the results type for the wrapping op from the terminator operands.
778   Operation &returnOp = body.back().back();
779   result.types.append(returnOp.operand_type_begin(),
780                       returnOp.operand_type_end());
781 
782   // Use the location of the wrapped op for the "test.wrapping_region" op.
783   result.location = wrappedOp->getLoc();
784 
785   return success();
786 }
787 
788 void WrappingRegionOp::print(OpAsmPrinter &p) {
789   p << " wraps ";
790   p.printGenericOp(&getRegion().front().front());
791 }
792 
793 //===----------------------------------------------------------------------===//
// Test PrettyPrintedRegionOp - exercising the following parser APIs
795 //   parseGenericOperationAfterOpName
796 //   parseCustomOperationName
797 //===----------------------------------------------------------------------===//
798 
799 ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser,
800                                          OperationState &result) {
801 
802   SMLoc loc = parser.getCurrentLocation();
803   Location currLocation = parser.getEncodedSourceLoc(loc);
804 
805   // Parse the operands.
806   SmallVector<OpAsmParser::OperandType, 2> operands;
807   if (parser.parseOperandList(operands))
808     return failure();
809 
810   // Check if we are parsing the pretty-printed version
811   //  test.pretty_printed_region start <inner-op> end : <functional-type>
812   // Else fallback to parsing the "non pretty-printed" version.
813   if (!succeeded(parser.parseOptionalKeyword("start")))
814     return parser.parseGenericOperationAfterOpName(
815         result, llvm::makeArrayRef(operands));
816 
817   FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName();
818   if (failed(parseOpNameInfo))
819     return failure();
820 
821   StringRef innerOpName = parseOpNameInfo->getStringRef();
822 
823   FunctionType opFntype;
824   Optional<Location> explicitLoc;
825   if (parser.parseKeyword("end") || parser.parseColon() ||
826       parser.parseType(opFntype) ||
827       parser.parseOptionalLocationSpecifier(explicitLoc))
828     return failure();
829 
830   // If location of the op is explicitly provided, then use it; Else use
831   // the parser's current location.
832   Location opLoc = explicitLoc.getValueOr(currLocation);
833 
834   // Derive the SSA-values for op's operands.
835   if (parser.resolveOperands(operands, opFntype.getInputs(), loc,
836                              result.operands))
837     return failure();
838 
839   // Add a region for op.
840   Region &region = *result.addRegion();
841 
842   // Create a basic-block inside op's region.
843   Block &block = region.emplaceBlock();
844 
845   // Create and insert an "inner-op" operation in the block.
846   // Just for testing purposes, we can assume that inner op is a binary op with
847   // result and operand types all same as the test-op's first operand.
848   Type innerOpType = opFntype.getInput(0);
849   Value lhs = block.addArgument(innerOpType, opLoc);
850   Value rhs = block.addArgument(innerOpType, opLoc);
851 
852   OpBuilder builder(parser.getBuilder().getContext());
853   builder.setInsertionPointToStart(&block);
854 
855   OperationState innerOpState(opLoc, innerOpName);
856   innerOpState.operands.push_back(lhs);
857   innerOpState.operands.push_back(rhs);
858   innerOpState.addTypes(innerOpType);
859 
860   Operation *innerOp = builder.createOperation(innerOpState);
861 
862   // Insert a return statement in the block returning the inner-op's result.
863   builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());
864 
865   // Populate the op operation-state with result-type and location.
866   result.addTypes(opFntype.getResults());
867   result.location = innerOp->getLoc();
868 
869   return success();
870 }
871 
872 void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
873   p << ' ';
874   p.printOperands(getOperands());
875 
876   Operation &innerOp = getRegion().front().front();
877   // Assuming that region has a single non-terminator inner-op, if the inner-op
878   // meets some criteria (which in this case is a simple one  based on the name
879   // of inner-op), then we can print the entire region in a succinct way.
880   // Here we assume that the prototype of "special.op" can be trivially derived
881   // while parsing it back.
882   if (innerOp.getName().getStringRef().equals("special.op")) {
883     p << " start special.op end";
884   } else {
885     p << " (";
886     p.printRegion(getRegion());
887     p << ")";
888   }
889 
890   p << " : ";
891   p.printFunctionalType(*this);
892 }
893 
894 //===----------------------------------------------------------------------===//
895 // Test PolyForOp - parse list of region arguments.
896 //===----------------------------------------------------------------------===//
897 
898 ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) {
899   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
900   // Parse list of region arguments without a delimiter.
901   if (parser.parseRegionArgumentList(ivsInfo))
902     return failure();
903 
904   // Parse the body region.
905   Region *body = result.addRegion();
906   auto &builder = parser.getBuilder();
907   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
908   return parser.parseRegion(*body, ivsInfo, argTypes);
909 }
910 
911 void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); }
912 
913 void PolyForOp::getAsmBlockArgumentNames(Region &region,
914                                          OpAsmSetValueNameFn setNameFn) {
915   auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
916   if (!arrayAttr)
917     return;
918   auto args = getRegion().front().getArguments();
919   auto e = std::min(arrayAttr.size(), args.size());
920   for (unsigned i = 0; i < e; ++i) {
921     if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
922       setNameFn(args[i], strAttr.getValue());
923   }
924 }
925 
926 //===----------------------------------------------------------------------===//
927 // Test removing op with inner ops.
928 //===----------------------------------------------------------------------===//
929 
930 namespace {
931 struct TestRemoveOpWithInnerOps
932     : public OpRewritePattern<TestOpWithRegionPattern> {
933   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
934 
935   void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
936 
937   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
938                                 PatternRewriter &rewriter) const override {
939     rewriter.eraseOp(op);
940     return success();
941   }
942 };
943 } // namespace
944 
945 void TestOpWithRegionPattern::getCanonicalizationPatterns(
946     RewritePatternSet &results, MLIRContext *context) {
947   results.add<TestRemoveOpWithInnerOps>(context);
948 }
949 
950 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
951   return getOperand();
952 }
953 
954 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
955   return getValue();
956 }
957 
958 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
959     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
960   for (Value input : this->getOperands()) {
961     results.push_back(input);
962   }
963   return success();
964 }
965 
966 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
967   assert(operands.size() == 1);
968   if (operands.front()) {
969     (*this)->setAttr("attr", operands.front());
970     return getResult();
971   }
972   return {};
973 }
974 
975 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
976   return getOperand();
977 }
978 
979 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
980     MLIRContext *, Optional<Location> location, ValueRange operands,
981     DictionaryAttr attributes, RegionRange regions,
982     SmallVectorImpl<Type> &inferredReturnTypes) {
983   if (operands[0].getType() != operands[1].getType()) {
984     return emitOptionalError(location, "operand type mismatch ",
985                              operands[0].getType(), " vs ",
986                              operands[1].getType());
987   }
988   inferredReturnTypes.assign({operands[0].getType()});
989   return success();
990 }
991 
992 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
993     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
994     DictionaryAttr attributes, RegionRange regions,
995     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
996   // Create return type consisting of the last element of the first operand.
997   auto operandType = operands.front().getType();
998   auto sval = operandType.dyn_cast<ShapedType>();
999   if (!sval) {
1000     return emitOptionalError(location, "only shaped type operands allowed");
1001   }
1002   int64_t dim =
1003       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
1004   auto type = IntegerType::get(context, 17);
1005   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
1006   return success();
1007 }
1008 
1009 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
1010     OpBuilder &builder, ValueRange operands,
1011     llvm::SmallVectorImpl<Value> &shapes) {
1012   shapes = SmallVector<Value, 1>{
1013       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
1014   return success();
1015 }
1016 
1017 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
1018     OpBuilder &builder, ValueRange operands,
1019     llvm::SmallVectorImpl<Value> &shapes) {
1020   Location loc = getLoc();
1021   shapes.reserve(operands.size());
1022   for (Value operand : llvm::reverse(operands)) {
1023     auto rank = operand.getType().cast<RankedTensorType>().getRank();
1024     auto currShape = llvm::to_vector<4>(
1025         llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
1026           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
1027         }));
1028     shapes.push_back(builder.create<tensor::FromElementsOp>(
1029         getLoc(), RankedTensorType::get({rank}, builder.getIndexType()),
1030         currShape));
1031   }
1032   return success();
1033 }
1034 
1035 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
1036     OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
1037   Location loc = getLoc();
1038   shapes.reserve(getNumOperands());
1039   for (Value operand : llvm::reverse(getOperands())) {
1040     auto currShape = llvm::to_vector<4>(llvm::map_range(
1041         llvm::seq<int64_t>(
1042             0, operand.getType().cast<RankedTensorType>().getRank()),
1043         [&](int64_t dim) -> Value {
1044           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
1045         }));
1046     shapes.emplace_back(std::move(currShape));
1047   }
1048   return success();
1049 }
1050 
1051 //===----------------------------------------------------------------------===//
1052 // Test SideEffect interfaces
1053 //===----------------------------------------------------------------------===//
1054 
1055 namespace {
1056 /// A test resource for side effects.
1057 struct TestResource : public SideEffects::Resource::Base<TestResource> {
1058   StringRef getName() final { return "<Test>"; }
1059 };
1060 } // namespace
1061 
1062 static void testSideEffectOpGetEffect(
1063     Operation *op,
1064     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
1065         &effects) {
1066   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
1067   if (!effectsAttr)
1068     return;
1069 
1070   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
1071 }
1072 
1073 void SideEffectOp::getEffects(
1074     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1075   // Check for an effects attribute on the op instance.
1076   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
1077   if (!effectsAttr)
1078     return;
1079 
1080   // If there is one, it is an array of dictionary attributes that hold
1081   // information on the effects of this operation.
1082   for (Attribute element : effectsAttr) {
1083     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
1084 
1085     // Get the specific memory effect.
1086     MemoryEffects::Effect *effect =
1087         StringSwitch<MemoryEffects::Effect *>(
1088             effectElement.get("effect").cast<StringAttr>().getValue())
1089             .Case("allocate", MemoryEffects::Allocate::get())
1090             .Case("free", MemoryEffects::Free::get())
1091             .Case("read", MemoryEffects::Read::get())
1092             .Case("write", MemoryEffects::Write::get());
1093 
1094     // Check for a non-default resource to use.
1095     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
1096     if (effectElement.get("test_resource"))
1097       resource = TestResource::get();
1098 
1099     // Check for a result to affect.
1100     if (effectElement.get("on_result"))
1101       effects.emplace_back(effect, getResult(), resource);
1102     else if (Attribute ref = effectElement.get("on_reference"))
1103       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
1104     else
1105       effects.emplace_back(effect, resource);
1106   }
1107 }
1108 
1109 void SideEffectOp::getEffects(
1110     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
1111   testSideEffectOpGetEffect(getOperation(), effects);
1112 }
1113 
1114 //===----------------------------------------------------------------------===//
1115 // StringAttrPrettyNameOp
1116 //===----------------------------------------------------------------------===//
1117 
1118 // This op has fancy handling of its SSA result name.
1119 ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
1120                                           OperationState &result) {
1121   // Add the result types.
1122   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
1123     result.addTypes(parser.getBuilder().getIntegerType(32));
1124 
1125   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1126     return failure();
1127 
1128   // If the attribute dictionary contains no 'names' attribute, infer it from
1129   // the SSA name (if specified).
1130   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
1131     return attr.getName() == "names";
1132   });
1133 
1134   // If there was no name specified, check to see if there was a useful name
1135   // specified in the asm file.
1136   if (hadNames || parser.getNumResults() == 0)
1137     return success();
1138 
1139   SmallVector<StringRef, 4> names;
1140   auto *context = result.getContext();
1141 
1142   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
1143     auto resultName = parser.getResultName(i);
1144     StringRef nameStr;
1145     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
1146       nameStr = resultName.first;
1147 
1148     names.push_back(nameStr);
1149   }
1150 
1151   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
1152   result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
1153   return success();
1154 }
1155 
1156 void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
1157   // Note that we only need to print the "name" attribute if the asmprinter
1158   // result name disagrees with it.  This can happen in strange cases, e.g.
1159   // when there are conflicts.
1160   bool namesDisagree = getNames().size() != getNumResults();
1161 
1162   SmallString<32> resultNameStr;
1163   for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) {
1164     resultNameStr.clear();
1165     llvm::raw_svector_ostream tmpStream(resultNameStr);
1166     p.printOperand(getResult(i), tmpStream);
1167 
1168     auto expectedName = getNames()[i].dyn_cast<StringAttr>();
1169     if (!expectedName ||
1170         tmpStream.str().drop_front() != expectedName.getValue()) {
1171       namesDisagree = true;
1172     }
1173   }
1174 
1175   if (namesDisagree)
1176     p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
1177   else
1178     p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
1179 }
1180 
// We set the SSA result names in the asm syntax to the contents of the 'names'
// attribute; empty or non-string entries leave the default numbered name.
1183 void StringAttrPrettyNameOp::getAsmResultNames(
1184     function_ref<void(Value, StringRef)> setNameFn) {
1185 
1186   auto value = getNames();
1187   for (size_t i = 0, e = value.size(); i != e; ++i)
1188     if (auto str = value[i].dyn_cast<StringAttr>())
1189       if (!str.getValue().empty())
1190         setNameFn(getResult(i), str.getValue());
1191 }
1192 
1193 //===----------------------------------------------------------------------===//
1194 // ResultTypeWithTraitOp
1195 //===----------------------------------------------------------------------===//
1196 
1197 LogicalResult ResultTypeWithTraitOp::verify() {
1198   if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
1199     return success();
1200   return emitError("result type should have trait 'TestTypeTrait'");
1201 }
1202 
1203 //===----------------------------------------------------------------------===//
1204 // AttrWithTraitOp
1205 //===----------------------------------------------------------------------===//
1206 
1207 LogicalResult AttrWithTraitOp::verify() {
1208   if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
1209     return success();
1210   return emitError("'attr' attribute should have trait 'TestAttrTrait'");
1211 }
1212 
1213 //===----------------------------------------------------------------------===//
1214 // RegionIfOp
1215 //===----------------------------------------------------------------------===//
1216 
1217 void RegionIfOp::print(OpAsmPrinter &p) {
1218   p << " ";
1219   p.printOperands(getOperands());
1220   p << ": " << getOperandTypes();
1221   p.printArrowTypeList(getResultTypes());
1222   p << " then ";
1223   p.printRegion(getThenRegion(),
1224                 /*printEntryBlockArgs=*/true,
1225                 /*printBlockTerminators=*/true);
1226   p << " else ";
1227   p.printRegion(getElseRegion(),
1228                 /*printEntryBlockArgs=*/true,
1229                 /*printBlockTerminators=*/true);
1230   p << " join ";
1231   p.printRegion(getJoinRegion(),
1232                 /*printEntryBlockArgs=*/true,
1233                 /*printBlockTerminators=*/true);
1234 }
1235 
1236 ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
1237   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
1238   SmallVector<Type, 2> operandTypes;
1239 
1240   result.regions.reserve(3);
1241   Region *thenRegion = result.addRegion();
1242   Region *elseRegion = result.addRegion();
1243   Region *joinRegion = result.addRegion();
1244 
1245   // Parse operand, type and arrow type lists.
1246   if (parser.parseOperandList(operandInfos) ||
1247       parser.parseColonTypeList(operandTypes) ||
1248       parser.parseArrowTypeList(result.types))
1249     return failure();
1250 
1251   // Parse all attached regions.
1252   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1253       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1254       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1255     return failure();
1256 
1257   return parser.resolveOperands(operandInfos, operandTypes,
1258                                 parser.getCurrentLocation(), result.operands);
1259 }
1260 
1261 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
1262   assert(index < 2 && "invalid region index");
1263   return getOperands();
1264 }
1265 
1266 void RegionIfOp::getSuccessorRegions(
1267     Optional<unsigned> index, ArrayRef<Attribute> operands,
1268     SmallVectorImpl<RegionSuccessor> &regions) {
1269   // We always branch to the join region.
1270   if (index.hasValue()) {
1271     if (index.getValue() < 2)
1272       regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
1273     else
1274       regions.push_back(RegionSuccessor(getResults()));
1275     return;
1276   }
1277 
1278   // The then and else regions are the entry regions of this op.
1279   regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
1280   regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
1281 }
1282 
1283 void RegionIfOp::getRegionInvocationBounds(
1284     ArrayRef<Attribute> operands,
1285     SmallVectorImpl<InvocationBounds> &invocationBounds) {
1286   // Each region is invoked at most once.
1287   invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1});
1288 }
1289 
1290 //===----------------------------------------------------------------------===//
1291 // AnyCondOp
1292 //===----------------------------------------------------------------------===//
1293 
1294 void AnyCondOp::getSuccessorRegions(Optional<unsigned> index,
1295                                     ArrayRef<Attribute> operands,
1296                                     SmallVectorImpl<RegionSuccessor> &regions) {
1297   // The parent op branches into the only region, and the region branches back
1298   // to the parent op.
1299   if (index)
1300     regions.emplace_back(&getRegion());
1301   else
1302     regions.emplace_back(getResults());
1303 }
1304 
1305 void AnyCondOp::getRegionInvocationBounds(
1306     ArrayRef<Attribute> operands,
1307     SmallVectorImpl<InvocationBounds> &invocationBounds) {
1308   invocationBounds.emplace_back(1, 1);
1309 }
1310 
1311 //===----------------------------------------------------------------------===//
1312 // SingleNoTerminatorCustomAsmOp
1313 //===----------------------------------------------------------------------===//
1314 
1315 ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
1316                                                  OperationState &state) {
1317   Region *body = state.addRegion();
1318   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1319     return failure();
1320   return success();
1321 }
1322 
1323 void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
1324   printer.printRegion(
1325       getRegion(), /*printEntryBlockArgs=*/false,
1326       // This op has a single block without terminators. But explicitly mark
1327       // as not printing block terminators for testing.
1328       /*printBlockTerminators=*/false);
1329 }
1330 
1331 #include "TestOpEnums.cpp.inc"
1332 #include "TestOpInterfaces.cpp.inc"
1333 #include "TestOpStructs.cpp.inc"
1334 #include "TestTypeInterfaces.cpp.inc"
1335 
1336 #define GET_OP_CLASSES
1337 #include "TestOps.cpp.inc"
1338