1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestTypes.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/IR/DialectImplementation.h" 13 #include "mlir/IR/Function.h" 14 #include "mlir/IR/Module.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/IR/TypeUtilities.h" 17 #include "mlir/Transforms/FoldUtils.h" 18 #include "mlir/Transforms/InliningUtils.h" 19 #include "llvm/ADT/SetVector.h" 20 #include "llvm/ADT/StringSwitch.h" 21 22 using namespace mlir; 23 24 //===----------------------------------------------------------------------===// 25 // TestDialect Interfaces 26 //===----------------------------------------------------------------------===// 27 28 namespace { 29 30 // Test support for interacting with the AsmPrinter. 31 struct TestOpAsmInterface : public OpAsmDialectInterface { 32 using OpAsmDialectInterface::OpAsmDialectInterface; 33 34 void getAsmResultNames(Operation *op, 35 OpAsmSetValueNameFn setNameFn) const final { 36 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 37 setNameFn(asmOp, "result"); 38 } 39 40 void getAsmBlockArgumentNames(Block *block, 41 OpAsmSetValueNameFn setNameFn) const final { 42 auto op = block->getParentOp(); 43 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 44 if (!arrayAttr) 45 return; 46 auto args = block->getArguments(); 47 auto e = std::min(arrayAttr.size(), args.size()); 48 for (unsigned i = 0; i < e; ++i) { 49 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 50 setNameFn(args[i], strAttr.getValue()); 51 } 52 } 53 }; 54 55 struct TestOpFolderDialectInterface : public OpFolderDialectInterface { 56 using OpFolderDialectInterface::OpFolderDialectInterface; 57 58 /// Registered hook to check if the given region, which is attached to an 59 /// operation that is *not* isolated from above, should be used when 60 /// materializing constants. 61 bool shouldMaterializeInto(Region *region) const final { 62 // If this is a one region operation, then insert into it. 63 return isa<OneRegionOp>(region->getParentOp()); 64 } 65 }; 66 67 /// This class defines the interface for handling inlining with standard 68 /// operations. 69 struct TestInlinerInterface : public DialectInlinerInterface { 70 using DialectInlinerInterface::DialectInlinerInterface; 71 72 //===--------------------------------------------------------------------===// 73 // Analysis Hooks 74 //===--------------------------------------------------------------------===// 75 76 bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final { 77 // Inlining into test dialect regions is legal. 78 return true; 79 } 80 bool isLegalToInline(Operation *, Region *, 81 BlockAndValueMapping &) const final { 82 return true; 83 } 84 85 bool shouldAnalyzeRecursively(Operation *op) const final { 86 // Analyze recursively if this is not a functional region operation, it 87 // froms a separate functional scope. 88 return !isa<FunctionalRegionOp>(op); 89 } 90 91 //===--------------------------------------------------------------------===// 92 // Transformation Hooks 93 //===--------------------------------------------------------------------===// 94 95 /// Handle the given inlined terminator by replacing it with a new operation 96 /// as necessary. 97 void handleTerminator(Operation *op, 98 ArrayRef<Value> valuesToRepl) const final { 99 // Only handle "test.return" here. 100 auto returnOp = dyn_cast<TestReturnOp>(op); 101 if (!returnOp) 102 return; 103 104 // Replace the values directly with the return operands. 105 assert(returnOp.getNumOperands() == valuesToRepl.size()); 106 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 107 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 108 } 109 110 /// Attempt to materialize a conversion for a type mismatch between a call 111 /// from this dialect, and a callable region. This method should generate an 112 /// operation that takes 'input' as the only operand, and produces a single 113 /// result of 'resultType'. If a conversion can not be generated, nullptr 114 /// should be returned. 115 Operation *materializeCallConversion(OpBuilder &builder, Value input, 116 Type resultType, 117 Location conversionLoc) const final { 118 // Only allow conversion for i16/i32 types. 119 if (!(resultType.isSignlessInteger(16) || 120 resultType.isSignlessInteger(32)) || 121 !(input.getType().isSignlessInteger(16) || 122 input.getType().isSignlessInteger(32))) 123 return nullptr; 124 return builder.create<TestCastOp>(conversionLoc, resultType, input); 125 } 126 }; 127 } // end anonymous namespace 128 129 //===----------------------------------------------------------------------===// 130 // TestDialect 131 //===----------------------------------------------------------------------===// 132 133 TestDialect::TestDialect(MLIRContext *context) 134 : Dialect(getDialectNamespace(), context) { 135 addOperations< 136 #define GET_OP_LIST 137 #include "TestOps.cpp.inc" 138 >(); 139 addInterfaces<TestOpAsmInterface, TestOpFolderDialectInterface, 140 TestInlinerInterface>(); 141 addTypes<TestType, TestRecursiveType>(); 142 allowUnknownOperations(); 143 } 144 145 static Type parseTestType(DialectAsmParser &parser, 146 llvm::SetVector<Type> &stack) { 147 StringRef typeTag; 148 if (failed(parser.parseKeyword(&typeTag))) 149 return Type(); 150 151 if (typeTag == "test_type") 152 return TestType::get(parser.getBuilder().getContext()); 153 154 if (typeTag != "test_rec") 155 return Type(); 156 157 StringRef name; 158 if (parser.parseLess() || parser.parseKeyword(&name)) 159 return Type(); 160 auto rec = TestRecursiveType::create(parser.getBuilder().getContext(), name); 161 162 // If this type already has been parsed above in the stack, expect just the 163 // name. 164 if (stack.contains(rec)) { 165 if (failed(parser.parseGreater())) 166 return Type(); 167 return rec; 168 } 169 170 // Otherwise, parse the body and update the type. 171 if (failed(parser.parseComma())) 172 return Type(); 173 stack.insert(rec); 174 Type subtype = parseTestType(parser, stack); 175 stack.pop_back(); 176 if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype))) 177 return Type(); 178 179 return rec; 180 } 181 182 Type TestDialect::parseType(DialectAsmParser &parser) const { 183 llvm::SetVector<Type> stack; 184 return parseTestType(parser, stack); 185 } 186 187 static void printTestType(Type type, DialectAsmPrinter &printer, 188 llvm::SetVector<Type> &stack) { 189 if (type.isa<TestType>()) { 190 printer << "test_type"; 191 return; 192 } 193 194 auto rec = type.cast<TestRecursiveType>(); 195 printer << "test_rec<" << rec.getName(); 196 if (!stack.contains(rec)) { 197 printer << ", "; 198 stack.insert(rec); 199 printTestType(rec.getBody(), printer, stack); 200 stack.pop_back(); 201 } 202 printer << ">"; 203 } 204 205 void TestDialect::printType(Type type, DialectAsmPrinter &printer) const { 206 llvm::SetVector<Type> stack; 207 printTestType(type, printer, stack); 208 } 209 210 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 211 NamedAttribute namedAttr) { 212 if (namedAttr.first == "test.invalid_attr") 213 return op->emitError() << "invalid to use 'test.invalid_attr'"; 214 return success(); 215 } 216 217 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 218 unsigned regionIndex, 219 unsigned argIndex, 220 NamedAttribute namedAttr) { 221 if (namedAttr.first == "test.invalid_attr") 222 return op->emitError() << "invalid to use 'test.invalid_attr'"; 223 return success(); 224 } 225 226 LogicalResult 227 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 228 unsigned resultIndex, 229 NamedAttribute namedAttr) { 230 if (namedAttr.first == "test.invalid_attr") 231 return op->emitError() << "invalid to use 'test.invalid_attr'"; 232 return success(); 233 } 234 235 //===----------------------------------------------------------------------===// 236 // TestBranchOp 237 //===----------------------------------------------------------------------===// 238 239 Optional<MutableOperandRange> 240 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 241 assert(index == 0 && "invalid successor index"); 242 return targetOperandsMutable(); 243 } 244 245 //===----------------------------------------------------------------------===// 246 // TestFoldToCallOp 247 //===----------------------------------------------------------------------===// 248 249 namespace { 250 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 251 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 252 253 LogicalResult matchAndRewrite(FoldToCallOp op, 254 PatternRewriter &rewriter) const override { 255 rewriter.replaceOpWithNewOp<CallOp>(op, ArrayRef<Type>(), op.calleeAttr(), 256 ValueRange()); 257 return success(); 258 } 259 }; 260 } // end anonymous namespace 261 262 void FoldToCallOp::getCanonicalizationPatterns( 263 OwningRewritePatternList &results, MLIRContext *context) { 264 results.insert<FoldToCallOpPattern>(context); 265 } 266 267 //===----------------------------------------------------------------------===// 268 // Test IsolatedRegionOp - parse passthrough region arguments. 269 //===----------------------------------------------------------------------===// 270 271 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 272 OperationState &result) { 273 OpAsmParser::OperandType argInfo; 274 Type argType = parser.getBuilder().getIndexType(); 275 276 // Parse the input operand. 277 if (parser.parseOperand(argInfo) || 278 parser.resolveOperand(argInfo, argType, result.operands)) 279 return failure(); 280 281 // Parse the body region, and reuse the operand info as the argument info. 282 Region *body = result.addRegion(); 283 return parser.parseRegion(*body, argInfo, argType, 284 /*enableNameShadowing=*/true); 285 } 286 287 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 288 p << "test.isolated_region "; 289 p.printOperand(op.getOperand()); 290 p.shadowRegionArgs(op.region(), op.getOperand()); 291 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 292 } 293 294 //===----------------------------------------------------------------------===// 295 // Test SSACFGRegionOp 296 //===----------------------------------------------------------------------===// 297 298 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 299 return RegionKind::SSACFG; 300 } 301 302 //===----------------------------------------------------------------------===// 303 // Test GraphRegionOp 304 //===----------------------------------------------------------------------===// 305 306 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 307 OperationState &result) { 308 // Parse the body region, and reuse the operand info as the argument info. 309 Region *body = result.addRegion(); 310 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 311 } 312 313 static void print(OpAsmPrinter &p, GraphRegionOp op) { 314 p << "test.graph_region "; 315 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 316 } 317 318 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 319 return RegionKind::Graph; 320 } 321 322 //===----------------------------------------------------------------------===// 323 // Test AffineScopeOp 324 //===----------------------------------------------------------------------===// 325 326 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 327 OperationState &result) { 328 // Parse the body region, and reuse the operand info as the argument info. 329 Region *body = result.addRegion(); 330 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 331 } 332 333 static void print(OpAsmPrinter &p, AffineScopeOp op) { 334 p << "test.affine_scope "; 335 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 336 } 337 338 //===----------------------------------------------------------------------===// 339 // Test parser. 340 //===----------------------------------------------------------------------===// 341 342 static ParseResult parseWrappedKeywordOp(OpAsmParser &parser, 343 OperationState &result) { 344 StringRef keyword; 345 if (parser.parseKeyword(&keyword)) 346 return failure(); 347 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 348 return success(); 349 } 350 351 static void print(OpAsmPrinter &p, WrappedKeywordOp op) { 352 p << WrappedKeywordOp::getOperationName() << " " << op.keyword(); 353 } 354 355 //===----------------------------------------------------------------------===// 356 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 357 358 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 359 OperationState &result) { 360 if (parser.parseKeyword("wraps")) 361 return failure(); 362 363 // Parse the wrapped op in a region 364 Region &body = *result.addRegion(); 365 body.push_back(new Block); 366 Block &block = body.back(); 367 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 368 if (!wrapped_op) 369 return failure(); 370 371 // Create a return terminator in the inner region, pass as operand to the 372 // terminator the returned values from the wrapped operation. 373 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 374 OpBuilder builder(parser.getBuilder().getContext()); 375 builder.setInsertionPointToEnd(&block); 376 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 377 378 // Get the results type for the wrapping op from the terminator operands. 379 Operation &return_op = body.back().back(); 380 result.types.append(return_op.operand_type_begin(), 381 return_op.operand_type_end()); 382 383 // Use the location of the wrapped op for the "test.wrapping_region" op. 384 result.location = wrapped_op->getLoc(); 385 386 return success(); 387 } 388 389 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 390 p << op.getOperationName() << " wraps "; 391 p.printGenericOp(&op.region().front().front()); 392 } 393 394 //===----------------------------------------------------------------------===// 395 // Test PolyForOp - parse list of region arguments. 396 //===----------------------------------------------------------------------===// 397 398 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 399 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 400 // Parse list of region arguments without a delimiter. 401 if (parser.parseRegionArgumentList(ivsInfo)) 402 return failure(); 403 404 // Parse the body region. 405 Region *body = result.addRegion(); 406 auto &builder = parser.getBuilder(); 407 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 408 return parser.parseRegion(*body, ivsInfo, argTypes); 409 } 410 411 //===----------------------------------------------------------------------===// 412 // Test removing op with inner ops. 413 //===----------------------------------------------------------------------===// 414 415 namespace { 416 struct TestRemoveOpWithInnerOps 417 : public OpRewritePattern<TestOpWithRegionPattern> { 418 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 419 420 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 421 PatternRewriter &rewriter) const override { 422 rewriter.eraseOp(op); 423 return success(); 424 } 425 }; 426 } // end anonymous namespace 427 428 void TestOpWithRegionPattern::getCanonicalizationPatterns( 429 OwningRewritePatternList &results, MLIRContext *context) { 430 results.insert<TestRemoveOpWithInnerOps>(context); 431 } 432 433 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 434 return operand(); 435 } 436 437 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 438 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 439 for (Value input : this->operands()) { 440 results.push_back(input); 441 } 442 return success(); 443 } 444 445 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 446 assert(operands.size() == 1); 447 if (operands.front()) { 448 setAttr("attr", operands.front()); 449 return getResult(); 450 } 451 return {}; 452 } 453 454 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 455 MLIRContext *, Optional<Location> location, ValueRange operands, 456 DictionaryAttr attributes, RegionRange regions, 457 SmallVectorImpl<Type> &inferredReturnTypes) { 458 if (operands[0].getType() != operands[1].getType()) { 459 return emitOptionalError(location, "operand type mismatch ", 460 operands[0].getType(), " vs ", 461 operands[1].getType()); 462 } 463 inferredReturnTypes.assign({operands[0].getType()}); 464 return success(); 465 } 466 467 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 468 MLIRContext *context, Optional<Location> location, ValueRange operands, 469 DictionaryAttr attributes, RegionRange regions, 470 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 471 // Create return type consisting of the last element of the first operand. 472 auto operandType = *operands.getTypes().begin(); 473 auto sval = operandType.dyn_cast<ShapedType>(); 474 if (!sval) { 475 return emitOptionalError(location, "only shaped type operands allowed"); 476 } 477 int64_t dim = 478 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 479 auto type = IntegerType::get(17, context); 480 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 481 return success(); 482 } 483 484 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 485 OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) { 486 shapes = SmallVector<Value, 1>{ 487 builder.createOrFold<DimOp>(getLoc(), getOperand(0), 0)}; 488 return success(); 489 } 490 491 //===----------------------------------------------------------------------===// 492 // Test SideEffect interfaces 493 //===----------------------------------------------------------------------===// 494 495 namespace { 496 /// A test resource for side effects. 497 struct TestResource : public SideEffects::Resource::Base<TestResource> { 498 StringRef getName() final { return "<Test>"; } 499 }; 500 } // end anonymous namespace 501 502 void SideEffectOp::getEffects( 503 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 504 // Check for an effects attribute on the op instance. 505 ArrayAttr effectsAttr = getAttrOfType<ArrayAttr>("effects"); 506 if (!effectsAttr) 507 return; 508 509 // If there is one, it is an array of dictionary attributes that hold 510 // information on the effects of this operation. 511 for (Attribute element : effectsAttr) { 512 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 513 514 // Get the specific memory effect. 515 MemoryEffects::Effect *effect = 516 llvm::StringSwitch<MemoryEffects::Effect *>( 517 effectElement.get("effect").cast<StringAttr>().getValue()) 518 .Case("allocate", MemoryEffects::Allocate::get()) 519 .Case("free", MemoryEffects::Free::get()) 520 .Case("read", MemoryEffects::Read::get()) 521 .Case("write", MemoryEffects::Write::get()); 522 523 // Check for a result to affect. 524 Value value; 525 if (effectElement.get("on_result")) 526 value = getResult(); 527 528 // Check for a non-default resource to use. 529 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 530 if (effectElement.get("test_resource")) 531 resource = TestResource::get(); 532 533 effects.emplace_back(effect, value, resource); 534 } 535 } 536 537 //===----------------------------------------------------------------------===// 538 // StringAttrPrettyNameOp 539 //===----------------------------------------------------------------------===// 540 541 // This op has fancy handling of its SSA result name. 542 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 543 OperationState &result) { 544 // Add the result types. 545 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 546 result.addTypes(parser.getBuilder().getIntegerType(32)); 547 548 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 549 return failure(); 550 551 // If the attribute dictionary contains no 'names' attribute, infer it from 552 // the SSA name (if specified). 553 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 554 return attr.first == "names"; 555 }); 556 557 // If there was no name specified, check to see if there was a useful name 558 // specified in the asm file. 559 if (hadNames || parser.getNumResults() == 0) 560 return success(); 561 562 SmallVector<StringRef, 4> names; 563 auto *context = result.getContext(); 564 565 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 566 auto resultName = parser.getResultName(i); 567 StringRef nameStr; 568 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 569 nameStr = resultName.first; 570 571 names.push_back(nameStr); 572 } 573 574 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 575 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 576 return success(); 577 } 578 579 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 580 p << "test.string_attr_pretty_name"; 581 582 // Note that we only need to print the "name" attribute if the asmprinter 583 // result name disagrees with it. This can happen in strange cases, e.g. 584 // when there are conflicts. 585 bool namesDisagree = op.names().size() != op.getNumResults(); 586 587 SmallString<32> resultNameStr; 588 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 589 resultNameStr.clear(); 590 llvm::raw_svector_ostream tmpStream(resultNameStr); 591 p.printOperand(op.getResult(i), tmpStream); 592 593 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 594 if (!expectedName || 595 tmpStream.str().drop_front() != expectedName.getValue()) { 596 namesDisagree = true; 597 } 598 } 599 600 if (namesDisagree) 601 p.printOptionalAttrDictWithKeyword(op.getAttrs()); 602 else 603 p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"}); 604 } 605 606 // We set the SSA name in the asm syntax to the contents of the name 607 // attribute. 608 void StringAttrPrettyNameOp::getAsmResultNames( 609 function_ref<void(Value, StringRef)> setNameFn) { 610 611 auto value = names(); 612 for (size_t i = 0, e = value.size(); i != e; ++i) 613 if (auto str = value[i].dyn_cast<StringAttr>()) 614 if (!str.getValue().empty()) 615 setNameFn(getResult(i), str.getValue()); 616 } 617 618 //===----------------------------------------------------------------------===// 619 // RegionIfOp 620 //===----------------------------------------------------------------------===// 621 622 static void print(OpAsmPrinter &p, RegionIfOp op) { 623 p << RegionIfOp::getOperationName() << " "; 624 p.printOperands(op.getOperands()); 625 p << ": " << op.getOperandTypes(); 626 p.printArrowTypeList(op.getResultTypes()); 627 p << " then"; 628 p.printRegion(op.thenRegion(), 629 /*printEntryBlockArgs=*/true, 630 /*printBlockTerminators=*/true); 631 p << " else"; 632 p.printRegion(op.elseRegion(), 633 /*printEntryBlockArgs=*/true, 634 /*printBlockTerminators=*/true); 635 p << " join"; 636 p.printRegion(op.joinRegion(), 637 /*printEntryBlockArgs=*/true, 638 /*printBlockTerminators=*/true); 639 } 640 641 static ParseResult parseRegionIfOp(OpAsmParser &parser, 642 OperationState &result) { 643 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 644 SmallVector<Type, 2> operandTypes; 645 646 result.regions.reserve(3); 647 Region *thenRegion = result.addRegion(); 648 Region *elseRegion = result.addRegion(); 649 Region *joinRegion = result.addRegion(); 650 651 // Parse operand, type and arrow type lists. 652 if (parser.parseOperandList(operandInfos) || 653 parser.parseColonTypeList(operandTypes) || 654 parser.parseArrowTypeList(result.types)) 655 return failure(); 656 657 // Parse all attached regions. 658 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 659 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 660 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 661 return failure(); 662 663 return parser.resolveOperands(operandInfos, operandTypes, 664 parser.getCurrentLocation(), result.operands); 665 } 666 667 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 668 assert(index < 2 && "invalid region index"); 669 return getOperands(); 670 } 671 672 void RegionIfOp::getSuccessorRegions( 673 Optional<unsigned> index, ArrayRef<Attribute> operands, 674 SmallVectorImpl<RegionSuccessor> ®ions) { 675 // We always branch to the join region. 676 if (index.hasValue()) { 677 if (index.getValue() < 2) 678 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 679 else 680 regions.push_back(RegionSuccessor(getResults())); 681 return; 682 } 683 684 // The then and else regions are the entry regions of this op. 685 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 686 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 687 } 688 689 //===----------------------------------------------------------------------===// 690 // Dialect Registration 691 //===----------------------------------------------------------------------===// 692 693 // Static initialization for Test dialect registration. 694 static DialectRegistration<TestDialect> testDialect; 695 696 #include "TestOpEnums.cpp.inc" 697 #include "TestOpStructs.cpp.inc" 698 #include "TestTypeInterfaces.cpp.inc" 699 700 #define GET_OP_CLASSES 701 #include "TestOps.cpp.inc" 702