1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestTypes.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/Function.h"
14 #include "mlir/IR/Module.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/TypeUtilities.h"
17 #include "mlir/Transforms/FoldUtils.h"
18 #include "mlir/Transforms/InliningUtils.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/StringSwitch.h"
21 
22 using namespace mlir;
23 
24 //===----------------------------------------------------------------------===//
25 // TestDialect Interfaces
26 //===----------------------------------------------------------------------===//
27 
28 namespace {
29 
30 // Test support for interacting with the AsmPrinter.
31 struct TestOpAsmInterface : public OpAsmDialectInterface {
32   using OpAsmDialectInterface::OpAsmDialectInterface;
33 
34   void getAsmResultNames(Operation *op,
35                          OpAsmSetValueNameFn setNameFn) const final {
36     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
37       setNameFn(asmOp, "result");
38   }
39 
40   void getAsmBlockArgumentNames(Block *block,
41                                 OpAsmSetValueNameFn setNameFn) const final {
42     auto op = block->getParentOp();
43     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
44     if (!arrayAttr)
45       return;
46     auto args = block->getArguments();
47     auto e = std::min(arrayAttr.size(), args.size());
48     for (unsigned i = 0; i < e; ++i) {
49       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
50         setNameFn(args[i], strAttr.getValue());
51     }
52   }
53 };
54 
55 struct TestDialectFoldInterface : public DialectFoldInterface {
56   using DialectFoldInterface::DialectFoldInterface;
57 
58   /// Registered hook to check if the given region, which is attached to an
59   /// operation that is *not* isolated from above, should be used when
60   /// materializing constants.
61   bool shouldMaterializeInto(Region *region) const final {
62     // If this is a one region operation, then insert into it.
63     return isa<OneRegionOp>(region->getParentOp());
64   }
65 };
66 
67 /// This class defines the interface for handling inlining with standard
68 /// operations.
69 struct TestInlinerInterface : public DialectInlinerInterface {
70   using DialectInlinerInterface::DialectInlinerInterface;
71 
72   //===--------------------------------------------------------------------===//
73   // Analysis Hooks
74   //===--------------------------------------------------------------------===//
75 
76   bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final {
77     // Inlining into test dialect regions is legal.
78     return true;
79   }
80   bool isLegalToInline(Operation *, Region *,
81                        BlockAndValueMapping &) const final {
82     return true;
83   }
84 
85   bool shouldAnalyzeRecursively(Operation *op) const final {
86     // Analyze recursively if this is not a functional region operation, it
87     // froms a separate functional scope.
88     return !isa<FunctionalRegionOp>(op);
89   }
90 
91   //===--------------------------------------------------------------------===//
92   // Transformation Hooks
93   //===--------------------------------------------------------------------===//
94 
95   /// Handle the given inlined terminator by replacing it with a new operation
96   /// as necessary.
97   void handleTerminator(Operation *op,
98                         ArrayRef<Value> valuesToRepl) const final {
99     // Only handle "test.return" here.
100     auto returnOp = dyn_cast<TestReturnOp>(op);
101     if (!returnOp)
102       return;
103 
104     // Replace the values directly with the return operands.
105     assert(returnOp.getNumOperands() == valuesToRepl.size());
106     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
107       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
108   }
109 
110   /// Attempt to materialize a conversion for a type mismatch between a call
111   /// from this dialect, and a callable region. This method should generate an
112   /// operation that takes 'input' as the only operand, and produces a single
113   /// result of 'resultType'. If a conversion can not be generated, nullptr
114   /// should be returned.
115   Operation *materializeCallConversion(OpBuilder &builder, Value input,
116                                        Type resultType,
117                                        Location conversionLoc) const final {
118     // Only allow conversion for i16/i32 types.
119     if (!(resultType.isSignlessInteger(16) ||
120           resultType.isSignlessInteger(32)) ||
121         !(input.getType().isSignlessInteger(16) ||
122           input.getType().isSignlessInteger(32)))
123       return nullptr;
124     return builder.create<TestCastOp>(conversionLoc, resultType, input);
125   }
126 };
127 } // end anonymous namespace
128 
129 //===----------------------------------------------------------------------===//
130 // TestDialect
131 //===----------------------------------------------------------------------===//
132 
133 void TestDialect::initialize() {
134   addOperations<
135 #define GET_OP_LIST
136 #include "TestOps.cpp.inc"
137       >();
138   addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
139                 TestInlinerInterface>();
140   addTypes<TestType, TestRecursiveType>();
141   allowUnknownOperations();
142 }
143 
144 static Type parseTestType(DialectAsmParser &parser,
145                           llvm::SetVector<Type> &stack) {
146   StringRef typeTag;
147   if (failed(parser.parseKeyword(&typeTag)))
148     return Type();
149 
150   if (typeTag == "test_type")
151     return TestType::get(parser.getBuilder().getContext());
152 
153   if (typeTag != "test_rec")
154     return Type();
155 
156   StringRef name;
157   if (parser.parseLess() || parser.parseKeyword(&name))
158     return Type();
159   auto rec = TestRecursiveType::create(parser.getBuilder().getContext(), name);
160 
161   // If this type already has been parsed above in the stack, expect just the
162   // name.
163   if (stack.contains(rec)) {
164     if (failed(parser.parseGreater()))
165       return Type();
166     return rec;
167   }
168 
169   // Otherwise, parse the body and update the type.
170   if (failed(parser.parseComma()))
171     return Type();
172   stack.insert(rec);
173   Type subtype = parseTestType(parser, stack);
174   stack.pop_back();
175   if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
176     return Type();
177 
178   return rec;
179 }
180 
181 Type TestDialect::parseType(DialectAsmParser &parser) const {
182   llvm::SetVector<Type> stack;
183   return parseTestType(parser, stack);
184 }
185 
186 static void printTestType(Type type, DialectAsmPrinter &printer,
187                           llvm::SetVector<Type> &stack) {
188   if (type.isa<TestType>()) {
189     printer << "test_type";
190     return;
191   }
192 
193   auto rec = type.cast<TestRecursiveType>();
194   printer << "test_rec<" << rec.getName();
195   if (!stack.contains(rec)) {
196     printer << ", ";
197     stack.insert(rec);
198     printTestType(rec.getBody(), printer, stack);
199     stack.pop_back();
200   }
201   printer << ">";
202 }
203 
204 void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
205   llvm::SetVector<Type> stack;
206   printTestType(type, printer, stack);
207 }
208 
209 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
210                                                     NamedAttribute namedAttr) {
211   if (namedAttr.first == "test.invalid_attr")
212     return op->emitError() << "invalid to use 'test.invalid_attr'";
213   return success();
214 }
215 
216 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
217                                                     unsigned regionIndex,
218                                                     unsigned argIndex,
219                                                     NamedAttribute namedAttr) {
220   if (namedAttr.first == "test.invalid_attr")
221     return op->emitError() << "invalid to use 'test.invalid_attr'";
222   return success();
223 }
224 
225 LogicalResult
226 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
227                                          unsigned resultIndex,
228                                          NamedAttribute namedAttr) {
229   if (namedAttr.first == "test.invalid_attr")
230     return op->emitError() << "invalid to use 'test.invalid_attr'";
231   return success();
232 }
233 
234 //===----------------------------------------------------------------------===//
235 // TestBranchOp
236 //===----------------------------------------------------------------------===//
237 
238 Optional<MutableOperandRange>
239 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
240   assert(index == 0 && "invalid successor index");
241   return targetOperandsMutable();
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // TestFoldToCallOp
246 //===----------------------------------------------------------------------===//
247 
248 namespace {
249 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
250   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
251 
252   LogicalResult matchAndRewrite(FoldToCallOp op,
253                                 PatternRewriter &rewriter) const override {
254     rewriter.replaceOpWithNewOp<CallOp>(op, ArrayRef<Type>(), op.calleeAttr(),
255                                         ValueRange());
256     return success();
257   }
258 };
259 } // end anonymous namespace
260 
261 void FoldToCallOp::getCanonicalizationPatterns(
262     OwningRewritePatternList &results, MLIRContext *context) {
263   results.insert<FoldToCallOpPattern>(context);
264 }
265 
266 //===----------------------------------------------------------------------===//
267 // Test IsolatedRegionOp - parse passthrough region arguments.
268 //===----------------------------------------------------------------------===//
269 
270 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
271                                          OperationState &result) {
272   OpAsmParser::OperandType argInfo;
273   Type argType = parser.getBuilder().getIndexType();
274 
275   // Parse the input operand.
276   if (parser.parseOperand(argInfo) ||
277       parser.resolveOperand(argInfo, argType, result.operands))
278     return failure();
279 
280   // Parse the body region, and reuse the operand info as the argument info.
281   Region *body = result.addRegion();
282   return parser.parseRegion(*body, argInfo, argType,
283                             /*enableNameShadowing=*/true);
284 }
285 
286 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
287   p << "test.isolated_region ";
288   p.printOperand(op.getOperand());
289   p.shadowRegionArgs(op.region(), op.getOperand());
290   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
291 }
292 
293 //===----------------------------------------------------------------------===//
294 // Test SSACFGRegionOp
295 //===----------------------------------------------------------------------===//
296 
297 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
298   return RegionKind::SSACFG;
299 }
300 
301 //===----------------------------------------------------------------------===//
302 // Test GraphRegionOp
303 //===----------------------------------------------------------------------===//
304 
305 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
306                                       OperationState &result) {
307   // Parse the body region, and reuse the operand info as the argument info.
308   Region *body = result.addRegion();
309   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
310 }
311 
312 static void print(OpAsmPrinter &p, GraphRegionOp op) {
313   p << "test.graph_region ";
314   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
315 }
316 
317 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
318   return RegionKind::Graph;
319 }
320 
321 //===----------------------------------------------------------------------===//
322 // Test AffineScopeOp
323 //===----------------------------------------------------------------------===//
324 
325 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
326                                       OperationState &result) {
327   // Parse the body region, and reuse the operand info as the argument info.
328   Region *body = result.addRegion();
329   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
330 }
331 
332 static void print(OpAsmPrinter &p, AffineScopeOp op) {
333   p << "test.affine_scope ";
334   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
335 }
336 
337 //===----------------------------------------------------------------------===//
338 // Test parser.
339 //===----------------------------------------------------------------------===//
340 
341 static ParseResult parseWrappedKeywordOp(OpAsmParser &parser,
342                                          OperationState &result) {
343   StringRef keyword;
344   if (parser.parseKeyword(&keyword))
345     return failure();
346   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
347   return success();
348 }
349 
350 static void print(OpAsmPrinter &p, WrappedKeywordOp op) {
351   p << WrappedKeywordOp::getOperationName() << " " << op.keyword();
352 }
353 
//===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
//===----------------------------------------------------------------------===//
356 
357 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
358                                          OperationState &result) {
359   if (parser.parseKeyword("wraps"))
360     return failure();
361 
362   // Parse the wrapped op in a region
363   Region &body = *result.addRegion();
364   body.push_back(new Block);
365   Block &block = body.back();
366   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
367   if (!wrapped_op)
368     return failure();
369 
370   // Create a return terminator in the inner region, pass as operand to the
371   // terminator the returned values from the wrapped operation.
372   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
373   OpBuilder builder(parser.getBuilder().getContext());
374   builder.setInsertionPointToEnd(&block);
375   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
376 
377   // Get the results type for the wrapping op from the terminator operands.
378   Operation &return_op = body.back().back();
379   result.types.append(return_op.operand_type_begin(),
380                       return_op.operand_type_end());
381 
382   // Use the location of the wrapped op for the "test.wrapping_region" op.
383   result.location = wrapped_op->getLoc();
384 
385   return success();
386 }
387 
388 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
389   p << op.getOperationName() << " wraps ";
390   p.printGenericOp(&op.region().front().front());
391 }
392 
393 //===----------------------------------------------------------------------===//
394 // Test PolyForOp - parse list of region arguments.
395 //===----------------------------------------------------------------------===//
396 
397 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
398   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
399   // Parse list of region arguments without a delimiter.
400   if (parser.parseRegionArgumentList(ivsInfo))
401     return failure();
402 
403   // Parse the body region.
404   Region *body = result.addRegion();
405   auto &builder = parser.getBuilder();
406   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
407   return parser.parseRegion(*body, ivsInfo, argTypes);
408 }
409 
410 //===----------------------------------------------------------------------===//
411 // Test removing op with inner ops.
412 //===----------------------------------------------------------------------===//
413 
414 namespace {
415 struct TestRemoveOpWithInnerOps
416     : public OpRewritePattern<TestOpWithRegionPattern> {
417   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
418 
419   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
420                                 PatternRewriter &rewriter) const override {
421     rewriter.eraseOp(op);
422     return success();
423   }
424 };
425 } // end anonymous namespace
426 
427 void TestOpWithRegionPattern::getCanonicalizationPatterns(
428     OwningRewritePatternList &results, MLIRContext *context) {
429   results.insert<TestRemoveOpWithInnerOps>(context);
430 }
431 
432 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
433   return operand();
434 }
435 
436 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
437     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
438   for (Value input : this->operands()) {
439     results.push_back(input);
440   }
441   return success();
442 }
443 
444 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
445   assert(operands.size() == 1);
446   if (operands.front()) {
447     setAttr("attr", operands.front());
448     return getResult();
449   }
450   return {};
451 }
452 
453 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
454     MLIRContext *, Optional<Location> location, ValueRange operands,
455     DictionaryAttr attributes, RegionRange regions,
456     SmallVectorImpl<Type> &inferredReturnTypes) {
457   if (operands[0].getType() != operands[1].getType()) {
458     return emitOptionalError(location, "operand type mismatch ",
459                              operands[0].getType(), " vs ",
460                              operands[1].getType());
461   }
462   inferredReturnTypes.assign({operands[0].getType()});
463   return success();
464 }
465 
466 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
467     MLIRContext *context, Optional<Location> location, ValueRange operands,
468     DictionaryAttr attributes, RegionRange regions,
469     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
470   // Create return type consisting of the last element of the first operand.
471   auto operandType = *operands.getTypes().begin();
472   auto sval = operandType.dyn_cast<ShapedType>();
473   if (!sval) {
474     return emitOptionalError(location, "only shaped type operands allowed");
475   }
476   int64_t dim =
477       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
478   auto type = IntegerType::get(17, context);
479   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
480   return success();
481 }
482 
483 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
484     OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
485   shapes = SmallVector<Value, 1>{
486       builder.createOrFold<DimOp>(getLoc(), getOperand(0), 0)};
487   return success();
488 }
489 
490 //===----------------------------------------------------------------------===//
491 // Test SideEffect interfaces
492 //===----------------------------------------------------------------------===//
493 
494 namespace {
495 /// A test resource for side effects.
496 struct TestResource : public SideEffects::Resource::Base<TestResource> {
497   StringRef getName() final { return "<Test>"; }
498 };
499 } // end anonymous namespace
500 
501 void SideEffectOp::getEffects(
502     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
503   // Check for an effects attribute on the op instance.
504   ArrayAttr effectsAttr = getAttrOfType<ArrayAttr>("effects");
505   if (!effectsAttr)
506     return;
507 
508   // If there is one, it is an array of dictionary attributes that hold
509   // information on the effects of this operation.
510   for (Attribute element : effectsAttr) {
511     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
512 
513     // Get the specific memory effect.
514     MemoryEffects::Effect *effect =
515         llvm::StringSwitch<MemoryEffects::Effect *>(
516             effectElement.get("effect").cast<StringAttr>().getValue())
517             .Case("allocate", MemoryEffects::Allocate::get())
518             .Case("free", MemoryEffects::Free::get())
519             .Case("read", MemoryEffects::Read::get())
520             .Case("write", MemoryEffects::Write::get());
521 
522     // Check for a result to affect.
523     Value value;
524     if (effectElement.get("on_result"))
525       value = getResult();
526 
527     // Check for a non-default resource to use.
528     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
529     if (effectElement.get("test_resource"))
530       resource = TestResource::get();
531 
532     effects.emplace_back(effect, value, resource);
533   }
534 }
535 
536 //===----------------------------------------------------------------------===//
537 // StringAttrPrettyNameOp
538 //===----------------------------------------------------------------------===//
539 
540 // This op has fancy handling of its SSA result name.
541 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
542                                                OperationState &result) {
543   // Add the result types.
544   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
545     result.addTypes(parser.getBuilder().getIntegerType(32));
546 
547   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
548     return failure();
549 
550   // If the attribute dictionary contains no 'names' attribute, infer it from
551   // the SSA name (if specified).
552   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
553     return attr.first == "names";
554   });
555 
556   // If there was no name specified, check to see if there was a useful name
557   // specified in the asm file.
558   if (hadNames || parser.getNumResults() == 0)
559     return success();
560 
561   SmallVector<StringRef, 4> names;
562   auto *context = result.getContext();
563 
564   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
565     auto resultName = parser.getResultName(i);
566     StringRef nameStr;
567     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
568       nameStr = resultName.first;
569 
570     names.push_back(nameStr);
571   }
572 
573   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
574   result.attributes.push_back({Identifier::get("names", context), namesAttr});
575   return success();
576 }
577 
578 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
579   p << "test.string_attr_pretty_name";
580 
581   // Note that we only need to print the "name" attribute if the asmprinter
582   // result name disagrees with it.  This can happen in strange cases, e.g.
583   // when there are conflicts.
584   bool namesDisagree = op.names().size() != op.getNumResults();
585 
586   SmallString<32> resultNameStr;
587   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
588     resultNameStr.clear();
589     llvm::raw_svector_ostream tmpStream(resultNameStr);
590     p.printOperand(op.getResult(i), tmpStream);
591 
592     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
593     if (!expectedName ||
594         tmpStream.str().drop_front() != expectedName.getValue()) {
595       namesDisagree = true;
596     }
597   }
598 
599   if (namesDisagree)
600     p.printOptionalAttrDictWithKeyword(op.getAttrs());
601   else
602     p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"});
603 }
604 
605 // We set the SSA name in the asm syntax to the contents of the name
606 // attribute.
607 void StringAttrPrettyNameOp::getAsmResultNames(
608     function_ref<void(Value, StringRef)> setNameFn) {
609 
610   auto value = names();
611   for (size_t i = 0, e = value.size(); i != e; ++i)
612     if (auto str = value[i].dyn_cast<StringAttr>())
613       if (!str.getValue().empty())
614         setNameFn(getResult(i), str.getValue());
615 }
616 
617 //===----------------------------------------------------------------------===//
618 // RegionIfOp
619 //===----------------------------------------------------------------------===//
620 
621 static void print(OpAsmPrinter &p, RegionIfOp op) {
622   p << RegionIfOp::getOperationName() << " ";
623   p.printOperands(op.getOperands());
624   p << ": " << op.getOperandTypes();
625   p.printArrowTypeList(op.getResultTypes());
626   p << " then";
627   p.printRegion(op.thenRegion(),
628                 /*printEntryBlockArgs=*/true,
629                 /*printBlockTerminators=*/true);
630   p << " else";
631   p.printRegion(op.elseRegion(),
632                 /*printEntryBlockArgs=*/true,
633                 /*printBlockTerminators=*/true);
634   p << " join";
635   p.printRegion(op.joinRegion(),
636                 /*printEntryBlockArgs=*/true,
637                 /*printBlockTerminators=*/true);
638 }
639 
640 static ParseResult parseRegionIfOp(OpAsmParser &parser,
641                                    OperationState &result) {
642   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
643   SmallVector<Type, 2> operandTypes;
644 
645   result.regions.reserve(3);
646   Region *thenRegion = result.addRegion();
647   Region *elseRegion = result.addRegion();
648   Region *joinRegion = result.addRegion();
649 
650   // Parse operand, type and arrow type lists.
651   if (parser.parseOperandList(operandInfos) ||
652       parser.parseColonTypeList(operandTypes) ||
653       parser.parseArrowTypeList(result.types))
654     return failure();
655 
656   // Parse all attached regions.
657   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
658       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
659       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
660     return failure();
661 
662   return parser.resolveOperands(operandInfos, operandTypes,
663                                 parser.getCurrentLocation(), result.operands);
664 }
665 
666 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
667   assert(index < 2 && "invalid region index");
668   return getOperands();
669 }
670 
671 void RegionIfOp::getSuccessorRegions(
672     Optional<unsigned> index, ArrayRef<Attribute> operands,
673     SmallVectorImpl<RegionSuccessor> &regions) {
674   // We always branch to the join region.
675   if (index.hasValue()) {
676     if (index.getValue() < 2)
677       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
678     else
679       regions.push_back(RegionSuccessor(getResults()));
680     return;
681   }
682 
683   // The then and else regions are the entry regions of this op.
684   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
685   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
686 }
687 
688 //===----------------------------------------------------------------------===//
689 // Dialect Registration
690 //===----------------------------------------------------------------------===//
691 
692 // Static initialization for Test dialect registration.
693 static DialectRegistration<TestDialect> testDialect;
694 
695 #include "TestOpEnums.cpp.inc"
696 #include "TestOpStructs.cpp.inc"
697 #include "TestTypeInterfaces.cpp.inc"
698 
699 #define GET_OP_CLASSES
700 #include "TestOps.cpp.inc"
701