1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestTypes.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/IR/DialectImplementation.h" 13 #include "mlir/IR/Function.h" 14 #include "mlir/IR/Module.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/IR/TypeUtilities.h" 17 #include "mlir/Transforms/FoldUtils.h" 18 #include "mlir/Transforms/InliningUtils.h" 19 #include "llvm/ADT/SetVector.h" 20 #include "llvm/ADT/StringSwitch.h" 21 22 using namespace mlir; 23 24 //===----------------------------------------------------------------------===// 25 // TestDialect Interfaces 26 //===----------------------------------------------------------------------===// 27 28 namespace { 29 30 // Test support for interacting with the AsmPrinter. 31 struct TestOpAsmInterface : public OpAsmDialectInterface { 32 using OpAsmDialectInterface::OpAsmDialectInterface; 33 34 void getAsmResultNames(Operation *op, 35 OpAsmSetValueNameFn setNameFn) const final { 36 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 37 setNameFn(asmOp, "result"); 38 } 39 40 void getAsmBlockArgumentNames(Block *block, 41 OpAsmSetValueNameFn setNameFn) const final { 42 auto op = block->getParentOp(); 43 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 44 if (!arrayAttr) 45 return; 46 auto args = block->getArguments(); 47 auto e = std::min(arrayAttr.size(), args.size()); 48 for (unsigned i = 0; i < e; ++i) { 49 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 50 setNameFn(args[i], strAttr.getValue()); 51 } 52 } 53 }; 54 55 struct TestDialectFoldInterface : public DialectFoldInterface { 56 using DialectFoldInterface::DialectFoldInterface; 57 58 /// Registered hook to check if the given region, which is attached to an 59 /// operation that is *not* isolated from above, should be used when 60 /// materializing constants. 61 bool shouldMaterializeInto(Region *region) const final { 62 // If this is a one region operation, then insert into it. 63 return isa<OneRegionOp>(region->getParentOp()); 64 } 65 }; 66 67 /// This class defines the interface for handling inlining with standard 68 /// operations. 69 struct TestInlinerInterface : public DialectInlinerInterface { 70 using DialectInlinerInterface::DialectInlinerInterface; 71 72 //===--------------------------------------------------------------------===// 73 // Analysis Hooks 74 //===--------------------------------------------------------------------===// 75 76 bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final { 77 // Inlining into test dialect regions is legal. 78 return true; 79 } 80 bool isLegalToInline(Operation *, Region *, 81 BlockAndValueMapping &) const final { 82 return true; 83 } 84 85 bool shouldAnalyzeRecursively(Operation *op) const final { 86 // Analyze recursively if this is not a functional region operation, it 87 // froms a separate functional scope. 88 return !isa<FunctionalRegionOp>(op); 89 } 90 91 //===--------------------------------------------------------------------===// 92 // Transformation Hooks 93 //===--------------------------------------------------------------------===// 94 95 /// Handle the given inlined terminator by replacing it with a new operation 96 /// as necessary. 97 void handleTerminator(Operation *op, 98 ArrayRef<Value> valuesToRepl) const final { 99 // Only handle "test.return" here. 100 auto returnOp = dyn_cast<TestReturnOp>(op); 101 if (!returnOp) 102 return; 103 104 // Replace the values directly with the return operands. 105 assert(returnOp.getNumOperands() == valuesToRepl.size()); 106 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 107 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 108 } 109 110 /// Attempt to materialize a conversion for a type mismatch between a call 111 /// from this dialect, and a callable region. This method should generate an 112 /// operation that takes 'input' as the only operand, and produces a single 113 /// result of 'resultType'. If a conversion can not be generated, nullptr 114 /// should be returned. 115 Operation *materializeCallConversion(OpBuilder &builder, Value input, 116 Type resultType, 117 Location conversionLoc) const final { 118 // Only allow conversion for i16/i32 types. 119 if (!(resultType.isSignlessInteger(16) || 120 resultType.isSignlessInteger(32)) || 121 !(input.getType().isSignlessInteger(16) || 122 input.getType().isSignlessInteger(32))) 123 return nullptr; 124 return builder.create<TestCastOp>(conversionLoc, resultType, input); 125 } 126 }; 127 } // end anonymous namespace 128 129 //===----------------------------------------------------------------------===// 130 // TestDialect 131 //===----------------------------------------------------------------------===// 132 133 void TestDialect::initialize() { 134 addOperations< 135 #define GET_OP_LIST 136 #include "TestOps.cpp.inc" 137 >(); 138 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 139 TestInlinerInterface>(); 140 addTypes<TestType, TestRecursiveType>(); 141 allowUnknownOperations(); 142 } 143 144 static Type parseTestType(DialectAsmParser &parser, 145 llvm::SetVector<Type> &stack) { 146 StringRef typeTag; 147 if (failed(parser.parseKeyword(&typeTag))) 148 return Type(); 149 150 if (typeTag == "test_type") 151 return TestType::get(parser.getBuilder().getContext()); 152 153 if (typeTag != "test_rec") 154 return Type(); 155 156 StringRef name; 157 if (parser.parseLess() || parser.parseKeyword(&name)) 158 return Type(); 159 auto rec = TestRecursiveType::create(parser.getBuilder().getContext(), name); 160 161 // If this type already has been parsed above in the stack, expect just the 162 // name. 163 if (stack.contains(rec)) { 164 if (failed(parser.parseGreater())) 165 return Type(); 166 return rec; 167 } 168 169 // Otherwise, parse the body and update the type. 170 if (failed(parser.parseComma())) 171 return Type(); 172 stack.insert(rec); 173 Type subtype = parseTestType(parser, stack); 174 stack.pop_back(); 175 if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype))) 176 return Type(); 177 178 return rec; 179 } 180 181 Type TestDialect::parseType(DialectAsmParser &parser) const { 182 llvm::SetVector<Type> stack; 183 return parseTestType(parser, stack); 184 } 185 186 static void printTestType(Type type, DialectAsmPrinter &printer, 187 llvm::SetVector<Type> &stack) { 188 if (type.isa<TestType>()) { 189 printer << "test_type"; 190 return; 191 } 192 193 auto rec = type.cast<TestRecursiveType>(); 194 printer << "test_rec<" << rec.getName(); 195 if (!stack.contains(rec)) { 196 printer << ", "; 197 stack.insert(rec); 198 printTestType(rec.getBody(), printer, stack); 199 stack.pop_back(); 200 } 201 printer << ">"; 202 } 203 204 void TestDialect::printType(Type type, DialectAsmPrinter &printer) const { 205 llvm::SetVector<Type> stack; 206 printTestType(type, printer, stack); 207 } 208 209 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 210 NamedAttribute namedAttr) { 211 if (namedAttr.first == "test.invalid_attr") 212 return op->emitError() << "invalid to use 'test.invalid_attr'"; 213 return success(); 214 } 215 216 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 217 unsigned regionIndex, 218 unsigned argIndex, 219 NamedAttribute namedAttr) { 220 if (namedAttr.first == "test.invalid_attr") 221 return op->emitError() << "invalid to use 'test.invalid_attr'"; 222 return success(); 223 } 224 225 LogicalResult 226 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 227 unsigned resultIndex, 228 NamedAttribute namedAttr) { 229 if (namedAttr.first == "test.invalid_attr") 230 return op->emitError() << "invalid to use 'test.invalid_attr'"; 231 return success(); 232 } 233 234 //===----------------------------------------------------------------------===// 235 // TestBranchOp 236 //===----------------------------------------------------------------------===// 237 238 Optional<MutableOperandRange> 239 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 240 assert(index == 0 && "invalid successor index"); 241 return targetOperandsMutable(); 242 } 243 244 //===----------------------------------------------------------------------===// 245 // TestFoldToCallOp 246 //===----------------------------------------------------------------------===// 247 248 namespace { 249 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 250 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 251 252 LogicalResult matchAndRewrite(FoldToCallOp op, 253 PatternRewriter &rewriter) const override { 254 rewriter.replaceOpWithNewOp<CallOp>(op, ArrayRef<Type>(), op.calleeAttr(), 255 ValueRange()); 256 return success(); 257 } 258 }; 259 } // end anonymous namespace 260 261 void FoldToCallOp::getCanonicalizationPatterns( 262 OwningRewritePatternList &results, MLIRContext *context) { 263 results.insert<FoldToCallOpPattern>(context); 264 } 265 266 //===----------------------------------------------------------------------===// 267 // Test IsolatedRegionOp - parse passthrough region arguments. 268 //===----------------------------------------------------------------------===// 269 270 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 271 OperationState &result) { 272 OpAsmParser::OperandType argInfo; 273 Type argType = parser.getBuilder().getIndexType(); 274 275 // Parse the input operand. 276 if (parser.parseOperand(argInfo) || 277 parser.resolveOperand(argInfo, argType, result.operands)) 278 return failure(); 279 280 // Parse the body region, and reuse the operand info as the argument info. 281 Region *body = result.addRegion(); 282 return parser.parseRegion(*body, argInfo, argType, 283 /*enableNameShadowing=*/true); 284 } 285 286 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 287 p << "test.isolated_region "; 288 p.printOperand(op.getOperand()); 289 p.shadowRegionArgs(op.region(), op.getOperand()); 290 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 291 } 292 293 //===----------------------------------------------------------------------===// 294 // Test SSACFGRegionOp 295 //===----------------------------------------------------------------------===// 296 297 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 298 return RegionKind::SSACFG; 299 } 300 301 //===----------------------------------------------------------------------===// 302 // Test GraphRegionOp 303 //===----------------------------------------------------------------------===// 304 305 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 306 OperationState &result) { 307 // Parse the body region, and reuse the operand info as the argument info. 308 Region *body = result.addRegion(); 309 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 310 } 311 312 static void print(OpAsmPrinter &p, GraphRegionOp op) { 313 p << "test.graph_region "; 314 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 315 } 316 317 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 318 return RegionKind::Graph; 319 } 320 321 //===----------------------------------------------------------------------===// 322 // Test AffineScopeOp 323 //===----------------------------------------------------------------------===// 324 325 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 326 OperationState &result) { 327 // Parse the body region, and reuse the operand info as the argument info. 328 Region *body = result.addRegion(); 329 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 330 } 331 332 static void print(OpAsmPrinter &p, AffineScopeOp op) { 333 p << "test.affine_scope "; 334 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 335 } 336 337 //===----------------------------------------------------------------------===// 338 // Test parser. 339 //===----------------------------------------------------------------------===// 340 341 static ParseResult parseWrappedKeywordOp(OpAsmParser &parser, 342 OperationState &result) { 343 StringRef keyword; 344 if (parser.parseKeyword(&keyword)) 345 return failure(); 346 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 347 return success(); 348 } 349 350 static void print(OpAsmPrinter &p, WrappedKeywordOp op) { 351 p << WrappedKeywordOp::getOperationName() << " " << op.keyword(); 352 } 353 354 //===----------------------------------------------------------------------===// 355 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 356 357 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 358 OperationState &result) { 359 if (parser.parseKeyword("wraps")) 360 return failure(); 361 362 // Parse the wrapped op in a region 363 Region &body = *result.addRegion(); 364 body.push_back(new Block); 365 Block &block = body.back(); 366 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 367 if (!wrapped_op) 368 return failure(); 369 370 // Create a return terminator in the inner region, pass as operand to the 371 // terminator the returned values from the wrapped operation. 372 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 373 OpBuilder builder(parser.getBuilder().getContext()); 374 builder.setInsertionPointToEnd(&block); 375 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 376 377 // Get the results type for the wrapping op from the terminator operands. 378 Operation &return_op = body.back().back(); 379 result.types.append(return_op.operand_type_begin(), 380 return_op.operand_type_end()); 381 382 // Use the location of the wrapped op for the "test.wrapping_region" op. 383 result.location = wrapped_op->getLoc(); 384 385 return success(); 386 } 387 388 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 389 p << op.getOperationName() << " wraps "; 390 p.printGenericOp(&op.region().front().front()); 391 } 392 393 //===----------------------------------------------------------------------===// 394 // Test PolyForOp - parse list of region arguments. 395 //===----------------------------------------------------------------------===// 396 397 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 398 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 399 // Parse list of region arguments without a delimiter. 400 if (parser.parseRegionArgumentList(ivsInfo)) 401 return failure(); 402 403 // Parse the body region. 404 Region *body = result.addRegion(); 405 auto &builder = parser.getBuilder(); 406 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 407 return parser.parseRegion(*body, ivsInfo, argTypes); 408 } 409 410 //===----------------------------------------------------------------------===// 411 // Test removing op with inner ops. 412 //===----------------------------------------------------------------------===// 413 414 namespace { 415 struct TestRemoveOpWithInnerOps 416 : public OpRewritePattern<TestOpWithRegionPattern> { 417 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 418 419 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 420 PatternRewriter &rewriter) const override { 421 rewriter.eraseOp(op); 422 return success(); 423 } 424 }; 425 } // end anonymous namespace 426 427 void TestOpWithRegionPattern::getCanonicalizationPatterns( 428 OwningRewritePatternList &results, MLIRContext *context) { 429 results.insert<TestRemoveOpWithInnerOps>(context); 430 } 431 432 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 433 return operand(); 434 } 435 436 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 437 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 438 for (Value input : this->operands()) { 439 results.push_back(input); 440 } 441 return success(); 442 } 443 444 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 445 assert(operands.size() == 1); 446 if (operands.front()) { 447 setAttr("attr", operands.front()); 448 return getResult(); 449 } 450 return {}; 451 } 452 453 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 454 MLIRContext *, Optional<Location> location, ValueRange operands, 455 DictionaryAttr attributes, RegionRange regions, 456 SmallVectorImpl<Type> &inferredReturnTypes) { 457 if (operands[0].getType() != operands[1].getType()) { 458 return emitOptionalError(location, "operand type mismatch ", 459 operands[0].getType(), " vs ", 460 operands[1].getType()); 461 } 462 inferredReturnTypes.assign({operands[0].getType()}); 463 return success(); 464 } 465 466 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 467 MLIRContext *context, Optional<Location> location, ValueRange operands, 468 DictionaryAttr attributes, RegionRange regions, 469 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 470 // Create return type consisting of the last element of the first operand. 471 auto operandType = *operands.getTypes().begin(); 472 auto sval = operandType.dyn_cast<ShapedType>(); 473 if (!sval) { 474 return emitOptionalError(location, "only shaped type operands allowed"); 475 } 476 int64_t dim = 477 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 478 auto type = IntegerType::get(17, context); 479 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 480 return success(); 481 } 482 483 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 484 OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) { 485 shapes = SmallVector<Value, 1>{ 486 builder.createOrFold<DimOp>(getLoc(), getOperand(0), 0)}; 487 return success(); 488 } 489 490 //===----------------------------------------------------------------------===// 491 // Test SideEffect interfaces 492 //===----------------------------------------------------------------------===// 493 494 namespace { 495 /// A test resource for side effects. 496 struct TestResource : public SideEffects::Resource::Base<TestResource> { 497 StringRef getName() final { return "<Test>"; } 498 }; 499 } // end anonymous namespace 500 501 void SideEffectOp::getEffects( 502 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 503 // Check for an effects attribute on the op instance. 504 ArrayAttr effectsAttr = getAttrOfType<ArrayAttr>("effects"); 505 if (!effectsAttr) 506 return; 507 508 // If there is one, it is an array of dictionary attributes that hold 509 // information on the effects of this operation. 510 for (Attribute element : effectsAttr) { 511 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 512 513 // Get the specific memory effect. 514 MemoryEffects::Effect *effect = 515 llvm::StringSwitch<MemoryEffects::Effect *>( 516 effectElement.get("effect").cast<StringAttr>().getValue()) 517 .Case("allocate", MemoryEffects::Allocate::get()) 518 .Case("free", MemoryEffects::Free::get()) 519 .Case("read", MemoryEffects::Read::get()) 520 .Case("write", MemoryEffects::Write::get()); 521 522 // Check for a result to affect. 523 Value value; 524 if (effectElement.get("on_result")) 525 value = getResult(); 526 527 // Check for a non-default resource to use. 528 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 529 if (effectElement.get("test_resource")) 530 resource = TestResource::get(); 531 532 effects.emplace_back(effect, value, resource); 533 } 534 } 535 536 //===----------------------------------------------------------------------===// 537 // StringAttrPrettyNameOp 538 //===----------------------------------------------------------------------===// 539 540 // This op has fancy handling of its SSA result name. 541 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 542 OperationState &result) { 543 // Add the result types. 544 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 545 result.addTypes(parser.getBuilder().getIntegerType(32)); 546 547 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 548 return failure(); 549 550 // If the attribute dictionary contains no 'names' attribute, infer it from 551 // the SSA name (if specified). 552 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 553 return attr.first == "names"; 554 }); 555 556 // If there was no name specified, check to see if there was a useful name 557 // specified in the asm file. 558 if (hadNames || parser.getNumResults() == 0) 559 return success(); 560 561 SmallVector<StringRef, 4> names; 562 auto *context = result.getContext(); 563 564 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 565 auto resultName = parser.getResultName(i); 566 StringRef nameStr; 567 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 568 nameStr = resultName.first; 569 570 names.push_back(nameStr); 571 } 572 573 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 574 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 575 return success(); 576 } 577 578 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 579 p << "test.string_attr_pretty_name"; 580 581 // Note that we only need to print the "name" attribute if the asmprinter 582 // result name disagrees with it. This can happen in strange cases, e.g. 583 // when there are conflicts. 584 bool namesDisagree = op.names().size() != op.getNumResults(); 585 586 SmallString<32> resultNameStr; 587 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 588 resultNameStr.clear(); 589 llvm::raw_svector_ostream tmpStream(resultNameStr); 590 p.printOperand(op.getResult(i), tmpStream); 591 592 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 593 if (!expectedName || 594 tmpStream.str().drop_front() != expectedName.getValue()) { 595 namesDisagree = true; 596 } 597 } 598 599 if (namesDisagree) 600 p.printOptionalAttrDictWithKeyword(op.getAttrs()); 601 else 602 p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"}); 603 } 604 605 // We set the SSA name in the asm syntax to the contents of the name 606 // attribute. 607 void StringAttrPrettyNameOp::getAsmResultNames( 608 function_ref<void(Value, StringRef)> setNameFn) { 609 610 auto value = names(); 611 for (size_t i = 0, e = value.size(); i != e; ++i) 612 if (auto str = value[i].dyn_cast<StringAttr>()) 613 if (!str.getValue().empty()) 614 setNameFn(getResult(i), str.getValue()); 615 } 616 617 //===----------------------------------------------------------------------===// 618 // RegionIfOp 619 //===----------------------------------------------------------------------===// 620 621 static void print(OpAsmPrinter &p, RegionIfOp op) { 622 p << RegionIfOp::getOperationName() << " "; 623 p.printOperands(op.getOperands()); 624 p << ": " << op.getOperandTypes(); 625 p.printArrowTypeList(op.getResultTypes()); 626 p << " then"; 627 p.printRegion(op.thenRegion(), 628 /*printEntryBlockArgs=*/true, 629 /*printBlockTerminators=*/true); 630 p << " else"; 631 p.printRegion(op.elseRegion(), 632 /*printEntryBlockArgs=*/true, 633 /*printBlockTerminators=*/true); 634 p << " join"; 635 p.printRegion(op.joinRegion(), 636 /*printEntryBlockArgs=*/true, 637 /*printBlockTerminators=*/true); 638 } 639 640 static ParseResult parseRegionIfOp(OpAsmParser &parser, 641 OperationState &result) { 642 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 643 SmallVector<Type, 2> operandTypes; 644 645 result.regions.reserve(3); 646 Region *thenRegion = result.addRegion(); 647 Region *elseRegion = result.addRegion(); 648 Region *joinRegion = result.addRegion(); 649 650 // Parse operand, type and arrow type lists. 651 if (parser.parseOperandList(operandInfos) || 652 parser.parseColonTypeList(operandTypes) || 653 parser.parseArrowTypeList(result.types)) 654 return failure(); 655 656 // Parse all attached regions. 657 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 658 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 659 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 660 return failure(); 661 662 return parser.resolveOperands(operandInfos, operandTypes, 663 parser.getCurrentLocation(), result.operands); 664 } 665 666 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 667 assert(index < 2 && "invalid region index"); 668 return getOperands(); 669 } 670 671 void RegionIfOp::getSuccessorRegions( 672 Optional<unsigned> index, ArrayRef<Attribute> operands, 673 SmallVectorImpl<RegionSuccessor> ®ions) { 674 // We always branch to the join region. 675 if (index.hasValue()) { 676 if (index.getValue() < 2) 677 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 678 else 679 regions.push_back(RegionSuccessor(getResults())); 680 return; 681 } 682 683 // The then and else regions are the entry regions of this op. 684 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 685 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 686 } 687 688 //===----------------------------------------------------------------------===// 689 // Dialect Registration 690 //===----------------------------------------------------------------------===// 691 692 // Static initialization for Test dialect registration. 693 static DialectRegistration<TestDialect> testDialect; 694 695 #include "TestOpEnums.cpp.inc" 696 #include "TestOpStructs.cpp.inc" 697 #include "TestTypeInterfaces.cpp.inc" 698 699 #define GET_OP_CLASSES 700 #include "TestOps.cpp.inc" 701