1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestInterfaces.h" 12 #include "TestTypes.h" 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/DLTI/DLTI.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/DialectImplementation.h" 19 #include "mlir/IR/ExtensibleDialect.h" 20 #include "mlir/IR/PatternMatch.h" 21 #include "mlir/IR/TypeUtilities.h" 22 #include "mlir/Reducer/ReductionPatternInterface.h" 23 #include "mlir/Transforms/FoldUtils.h" 24 #include "mlir/Transforms/InliningUtils.h" 25 #include "llvm/ADT/StringSwitch.h" 26 27 // Include this before the using namespace lines below to 28 // test that we don't have namespace dependencies. 29 #include "TestOpsDialect.cpp.inc" 30 31 using namespace mlir; 32 using namespace test; 33 34 void test::registerTestDialect(DialectRegistry ®istry) { 35 registry.insert<TestDialect>(); 36 } 37 38 //===----------------------------------------------------------------------===// 39 // TestDialect Interfaces 40 //===----------------------------------------------------------------------===// 41 42 namespace { 43 44 /// Testing the correctness of some traits. 45 static_assert( 46 llvm::is_detected<OpTrait::has_implicit_terminator_t, 47 SingleBlockImplicitTerminatorOp>::value, 48 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); 49 static_assert(OpTrait::hasSingleBlockImplicitTerminator< 50 SingleBlockImplicitTerminatorOp>::value, 51 "hasSingleBlockImplicitTerminator does not match " 52 "SingleBlockImplicitTerminatorOp"); 53 54 // Test support for interacting with the AsmPrinter. 55 struct TestOpAsmInterface : public OpAsmDialectInterface { 56 using OpAsmDialectInterface::OpAsmDialectInterface; 57 58 AliasResult getAlias(Attribute attr, raw_ostream &os) const final { 59 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 60 if (!strAttr) 61 return AliasResult::NoAlias; 62 63 // Check the contents of the string attribute to see what the test alias 64 // should be named. 65 Optional<StringRef> aliasName = 66 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 67 .Case("alias_test:dot_in_name", StringRef("test.alias")) 68 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 69 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 70 .Case("alias_test:sanitize_conflict_a", 71 StringRef("test_alias_conflict0")) 72 .Case("alias_test:sanitize_conflict_b", 73 StringRef("test_alias_conflict0_")) 74 .Case("alias_test:tensor_encoding", StringRef("test_encoding")) 75 .Default(llvm::None); 76 if (!aliasName) 77 return AliasResult::NoAlias; 78 79 os << *aliasName; 80 return AliasResult::FinalAlias; 81 } 82 83 AliasResult getAlias(Type type, raw_ostream &os) const final { 84 if (auto tupleType = type.dyn_cast<TupleType>()) { 85 if (tupleType.size() > 0 && 86 llvm::all_of(tupleType.getTypes(), [](Type elemType) { 87 return elemType.isa<SimpleAType>(); 88 })) { 89 os << "test_tuple"; 90 return AliasResult::FinalAlias; 91 } 92 } 93 if (auto intType = type.dyn_cast<TestIntegerType>()) { 94 if (intType.getSignedness() == 95 TestIntegerType::SignednessSemantics::Unsigned && 96 intType.getWidth() == 8) { 97 os << "test_ui8"; 98 return AliasResult::FinalAlias; 99 } 100 } 101 return AliasResult::NoAlias; 102 } 103 }; 104 105 struct TestDialectFoldInterface : public DialectFoldInterface { 106 using DialectFoldInterface::DialectFoldInterface; 107 108 /// Registered hook to check if the given region, which is attached to an 109 /// operation that is *not* isolated from above, should be used when 110 /// materializing constants. 111 bool shouldMaterializeInto(Region *region) const final { 112 // If this is a one region operation, then insert into it. 113 return isa<OneRegionOp>(region->getParentOp()); 114 } 115 }; 116 117 /// This class defines the interface for handling inlining with standard 118 /// operations. 119 struct TestInlinerInterface : public DialectInlinerInterface { 120 using DialectInlinerInterface::DialectInlinerInterface; 121 122 //===--------------------------------------------------------------------===// 123 // Analysis Hooks 124 //===--------------------------------------------------------------------===// 125 126 bool isLegalToInline(Operation *call, Operation *callable, 127 bool wouldBeCloned) const final { 128 // Don't allow inlining calls that are marked `noinline`. 129 return !call->hasAttr("noinline"); 130 } 131 bool isLegalToInline(Region *, Region *, bool, 132 BlockAndValueMapping &) const final { 133 // Inlining into test dialect regions is legal. 134 return true; 135 } 136 bool isLegalToInline(Operation *, Region *, bool, 137 BlockAndValueMapping &) const final { 138 return true; 139 } 140 141 bool shouldAnalyzeRecursively(Operation *op) const final { 142 // Analyze recursively if this is not a functional region operation, it 143 // froms a separate functional scope. 144 return !isa<FunctionalRegionOp>(op); 145 } 146 147 //===--------------------------------------------------------------------===// 148 // Transformation Hooks 149 //===--------------------------------------------------------------------===// 150 151 /// Handle the given inlined terminator by replacing it with a new operation 152 /// as necessary. 153 void handleTerminator(Operation *op, 154 ArrayRef<Value> valuesToRepl) const final { 155 // Only handle "test.return" here. 156 auto returnOp = dyn_cast<TestReturnOp>(op); 157 if (!returnOp) 158 return; 159 160 // Replace the values directly with the return operands. 161 assert(returnOp.getNumOperands() == valuesToRepl.size()); 162 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 163 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 164 } 165 166 /// Attempt to materialize a conversion for a type mismatch between a call 167 /// from this dialect, and a callable region. This method should generate an 168 /// operation that takes 'input' as the only operand, and produces a single 169 /// result of 'resultType'. If a conversion can not be generated, nullptr 170 /// should be returned. 171 Operation *materializeCallConversion(OpBuilder &builder, Value input, 172 Type resultType, 173 Location conversionLoc) const final { 174 // Only allow conversion for i16/i32 types. 175 if (!(resultType.isSignlessInteger(16) || 176 resultType.isSignlessInteger(32)) || 177 !(input.getType().isSignlessInteger(16) || 178 input.getType().isSignlessInteger(32))) 179 return nullptr; 180 return builder.create<TestCastOp>(conversionLoc, resultType, input); 181 } 182 183 void processInlinedCallBlocks( 184 Operation *call, 185 iterator_range<Region::iterator> inlinedBlocks) const final { 186 if (!isa<ConversionCallOp>(call)) 187 return; 188 189 // Set attributed on all ops in the inlined blocks. 190 for (Block &block : inlinedBlocks) { 191 block.walk([&](Operation *op) { 192 op->setAttr("inlined_conversion", UnitAttr::get(call->getContext())); 193 }); 194 } 195 } 196 }; 197 198 struct TestReductionPatternInterface : public DialectReductionPatternInterface { 199 public: 200 TestReductionPatternInterface(Dialect *dialect) 201 : DialectReductionPatternInterface(dialect) {} 202 203 void populateReductionPatterns(RewritePatternSet &patterns) const final { 204 populateTestReductionPatterns(patterns); 205 } 206 }; 207 208 } // namespace 209 210 //===----------------------------------------------------------------------===// 211 // Dynamic operations 212 //===----------------------------------------------------------------------===// 213 214 std::unique_ptr<DynamicOpDefinition> getGenericDynamicOp(TestDialect *dialect) { 215 return DynamicOpDefinition::get( 216 "generic_dynamic_op", dialect, [](Operation *op) { return success(); }, 217 [](Operation *op) { return success(); }); 218 } 219 220 std::unique_ptr<DynamicOpDefinition> 221 getOneOperandTwoResultsDynamicOp(TestDialect *dialect) { 222 return DynamicOpDefinition::get( 223 "one_operand_two_results", dialect, 224 [](Operation *op) { 225 if (op->getNumOperands() != 1) { 226 op->emitOpError() 227 << "expected 1 operand, but had " << op->getNumOperands(); 228 return failure(); 229 } 230 if (op->getNumResults() != 2) { 231 op->emitOpError() 232 << "expected 2 results, but had " << op->getNumResults(); 233 return failure(); 234 } 235 return success(); 236 }, 237 [](Operation *op) { return success(); }); 238 } 239 240 std::unique_ptr<DynamicOpDefinition> 241 getCustomParserPrinterDynamicOp(TestDialect *dialect) { 242 auto verifier = [](Operation *op) { 243 if (op->getNumOperands() == 0 && op->getNumResults() == 0) 244 return success(); 245 op->emitError() << "operation should have no operands and no results"; 246 return failure(); 247 }; 248 auto regionVerifier = [](Operation *op) { return success(); }; 249 250 auto parser = [](OpAsmParser &parser, OperationState &state) { 251 return parser.parseKeyword("custom_keyword"); 252 }; 253 254 auto printer = [](Operation *op, OpAsmPrinter &printer, llvm::StringRef) { 255 printer << op->getName() << " custom_keyword"; 256 }; 257 258 return DynamicOpDefinition::get("custom_parser_printer_dynamic_op", dialect, 259 verifier, regionVerifier, parser, printer); 260 } 261 262 //===----------------------------------------------------------------------===// 263 // TestDialect 264 //===----------------------------------------------------------------------===// 265 266 static void testSideEffectOpGetEffect( 267 Operation *op, 268 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects); 269 270 // This is the implementation of a dialect fallback for `TestEffectOpInterface`. 271 struct TestOpEffectInterfaceFallback 272 : public TestEffectOpInterface::FallbackModel< 273 TestOpEffectInterfaceFallback> { 274 static bool classof(Operation *op) { 275 bool isSupportedOp = 276 op->getName().getStringRef() == "test.unregistered_side_effect_op"; 277 assert(isSupportedOp && "Unexpected dispatch"); 278 return isSupportedOp; 279 } 280 281 void 282 getEffects(Operation *op, 283 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 284 &effects) const { 285 testSideEffectOpGetEffect(op, effects); 286 } 287 }; 288 289 void TestDialect::initialize() { 290 registerAttributes(); 291 registerTypes(); 292 addOperations< 293 #define GET_OP_LIST 294 #include "TestOps.cpp.inc" 295 >(); 296 registerDynamicOp(getGenericDynamicOp(this)); 297 registerDynamicOp(getOneOperandTwoResultsDynamicOp(this)); 298 registerDynamicOp(getCustomParserPrinterDynamicOp(this)); 299 300 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 301 TestInlinerInterface, TestReductionPatternInterface>(); 302 allowUnknownOperations(); 303 304 // Instantiate our fallback op interface that we'll use on specific 305 // unregistered op. 306 fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback; 307 } 308 TestDialect::~TestDialect() { 309 delete static_cast<TestOpEffectInterfaceFallback *>( 310 fallbackEffectOpInterfaces); 311 } 312 313 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 314 Type type, Location loc) { 315 return builder.create<TestOpConstant>(loc, type, value); 316 } 317 318 ::mlir::LogicalResult FormatInferType2Op::inferReturnTypes( 319 ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, 320 ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, 321 ::mlir::RegionRange regions, 322 ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { 323 inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)}); 324 return ::mlir::success(); 325 } 326 327 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, 328 OperationName opName) { 329 if (opName.getIdentifier() == "test.unregistered_side_effect_op" && 330 typeID == TypeID::get<TestEffectOpInterface>()) 331 return fallbackEffectOpInterfaces; 332 return nullptr; 333 } 334 335 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 336 NamedAttribute namedAttr) { 337 if (namedAttr.getName() == "test.invalid_attr") 338 return op->emitError() << "invalid to use 'test.invalid_attr'"; 339 return success(); 340 } 341 342 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 343 unsigned regionIndex, 344 unsigned argIndex, 345 NamedAttribute namedAttr) { 346 if (namedAttr.getName() == "test.invalid_attr") 347 return op->emitError() << "invalid to use 'test.invalid_attr'"; 348 return success(); 349 } 350 351 LogicalResult 352 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 353 unsigned resultIndex, 354 NamedAttribute namedAttr) { 355 if (namedAttr.getName() == "test.invalid_attr") 356 return op->emitError() << "invalid to use 'test.invalid_attr'"; 357 return success(); 358 } 359 360 Optional<Dialect::ParseOpHook> 361 TestDialect::getParseOperationHook(StringRef opName) const { 362 if (opName == "test.dialect_custom_printer") { 363 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 364 return parser.parseKeyword("custom_format"); 365 }}; 366 } 367 if (opName == "test.dialect_custom_format_fallback") { 368 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 369 return parser.parseKeyword("custom_format_fallback"); 370 }}; 371 } 372 return None; 373 } 374 375 llvm::unique_function<void(Operation *, OpAsmPrinter &)> 376 TestDialect::getOperationPrinter(Operation *op) const { 377 StringRef opName = op->getName().getStringRef(); 378 if (opName == "test.dialect_custom_printer") { 379 return [](Operation *op, OpAsmPrinter &printer) { 380 printer.getStream() << " custom_format"; 381 }; 382 } 383 if (opName == "test.dialect_custom_format_fallback") { 384 return [](Operation *op, OpAsmPrinter &printer) { 385 printer.getStream() << " custom_format_fallback"; 386 }; 387 } 388 return {}; 389 } 390 391 //===----------------------------------------------------------------------===// 392 // TestBranchOp 393 //===----------------------------------------------------------------------===// 394 395 Optional<MutableOperandRange> 396 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 397 assert(index == 0 && "invalid successor index"); 398 return getTargetOperandsMutable(); 399 } 400 401 //===----------------------------------------------------------------------===// 402 // TestDialectCanonicalizerOp 403 //===----------------------------------------------------------------------===// 404 405 static LogicalResult 406 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, 407 PatternRewriter &rewriter) { 408 rewriter.replaceOpWithNewOp<arith::ConstantOp>( 409 op, rewriter.getI32IntegerAttr(42)); 410 return success(); 411 } 412 413 void TestDialect::getCanonicalizationPatterns( 414 RewritePatternSet &results) const { 415 results.add(&dialectCanonicalizationPattern); 416 } 417 418 //===----------------------------------------------------------------------===// 419 // TestFoldToCallOp 420 //===----------------------------------------------------------------------===// 421 422 namespace { 423 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 424 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 425 426 LogicalResult matchAndRewrite(FoldToCallOp op, 427 PatternRewriter &rewriter) const override { 428 rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), 429 op.getCalleeAttr(), ValueRange()); 430 return success(); 431 } 432 }; 433 } // namespace 434 435 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 436 MLIRContext *context) { 437 results.add<FoldToCallOpPattern>(context); 438 } 439 440 //===----------------------------------------------------------------------===// 441 // Test Format* operations 442 //===----------------------------------------------------------------------===// 443 444 //===----------------------------------------------------------------------===// 445 // Parsing 446 447 static ParseResult 448 parseCustomOptionalOperand(OpAsmParser &parser, 449 Optional<OpAsmParser::OperandType> &optOperand) { 450 if (succeeded(parser.parseOptionalLParen())) { 451 optOperand.emplace(); 452 if (parser.parseOperand(*optOperand) || parser.parseRParen()) 453 return failure(); 454 } 455 return success(); 456 } 457 458 static ParseResult parseCustomDirectiveOperands( 459 OpAsmParser &parser, OpAsmParser::OperandType &operand, 460 Optional<OpAsmParser::OperandType> &optOperand, 461 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 462 if (parser.parseOperand(operand)) 463 return failure(); 464 if (succeeded(parser.parseOptionalComma())) { 465 optOperand.emplace(); 466 if (parser.parseOperand(*optOperand)) 467 return failure(); 468 } 469 if (parser.parseArrow() || parser.parseLParen() || 470 parser.parseOperandList(varOperands) || parser.parseRParen()) 471 return failure(); 472 return success(); 473 } 474 static ParseResult 475 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 476 Type &optOperandType, 477 SmallVectorImpl<Type> &varOperandTypes) { 478 if (parser.parseColon()) 479 return failure(); 480 481 if (parser.parseType(operandType)) 482 return failure(); 483 if (succeeded(parser.parseOptionalComma())) { 484 if (parser.parseType(optOperandType)) 485 return failure(); 486 } 487 if (parser.parseArrow() || parser.parseLParen() || 488 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 489 return failure(); 490 return success(); 491 } 492 static ParseResult 493 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 494 Type optOperandType, 495 const SmallVectorImpl<Type> &varOperandTypes) { 496 if (parser.parseKeyword("type_refs_capture")) 497 return failure(); 498 499 Type operandType2, optOperandType2; 500 SmallVector<Type, 1> varOperandTypes2; 501 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 502 varOperandTypes2)) 503 return failure(); 504 505 if (operandType != operandType2 || optOperandType != optOperandType2 || 506 varOperandTypes != varOperandTypes2) 507 return failure(); 508 509 return success(); 510 } 511 static ParseResult parseCustomDirectiveOperandsAndTypes( 512 OpAsmParser &parser, OpAsmParser::OperandType &operand, 513 Optional<OpAsmParser::OperandType> &optOperand, 514 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 515 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 516 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 517 parseCustomDirectiveResults(parser, operandType, optOperandType, 518 varOperandTypes)) 519 return failure(); 520 return success(); 521 } 522 static ParseResult parseCustomDirectiveRegions( 523 OpAsmParser &parser, Region ®ion, 524 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 525 if (parser.parseRegion(region)) 526 return failure(); 527 if (failed(parser.parseOptionalComma())) 528 return success(); 529 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 530 if (parser.parseRegion(*varRegion)) 531 return failure(); 532 varRegions.emplace_back(std::move(varRegion)); 533 return success(); 534 } 535 static ParseResult 536 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 537 SmallVectorImpl<Block *> &varSuccessors) { 538 if (parser.parseSuccessor(successor)) 539 return failure(); 540 if (failed(parser.parseOptionalComma())) 541 return success(); 542 Block *varSuccessor; 543 if (parser.parseSuccessor(varSuccessor)) 544 return failure(); 545 varSuccessors.append(2, varSuccessor); 546 return success(); 547 } 548 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 549 IntegerAttr &attr, 550 IntegerAttr &optAttr) { 551 if (parser.parseAttribute(attr)) 552 return failure(); 553 if (succeeded(parser.parseOptionalComma())) { 554 if (parser.parseAttribute(optAttr)) 555 return failure(); 556 } 557 return success(); 558 } 559 560 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 561 NamedAttrList &attrs) { 562 return parser.parseOptionalAttrDict(attrs); 563 } 564 static ParseResult parseCustomDirectiveOptionalOperandRef( 565 OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) { 566 int64_t operandCount = 0; 567 if (parser.parseInteger(operandCount)) 568 return failure(); 569 bool expectedOptionalOperand = operandCount == 0; 570 return success(expectedOptionalOperand != optOperand.hasValue()); 571 } 572 573 //===----------------------------------------------------------------------===// 574 // Printing 575 576 static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *, 577 Value optOperand) { 578 if (optOperand) 579 printer << "(" << optOperand << ") "; 580 } 581 582 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 583 Value operand, Value optOperand, 584 OperandRange varOperands) { 585 printer << operand; 586 if (optOperand) 587 printer << ", " << optOperand; 588 printer << " -> (" << varOperands << ")"; 589 } 590 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 591 Type operandType, Type optOperandType, 592 TypeRange varOperandTypes) { 593 printer << " : " << operandType; 594 if (optOperandType) 595 printer << ", " << optOperandType; 596 printer << " -> (" << varOperandTypes << ")"; 597 } 598 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 599 Operation *op, Type operandType, 600 Type optOperandType, 601 TypeRange varOperandTypes) { 602 printer << " type_refs_capture "; 603 printCustomDirectiveResults(printer, op, operandType, optOperandType, 604 varOperandTypes); 605 } 606 static void printCustomDirectiveOperandsAndTypes( 607 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 608 OperandRange varOperands, Type operandType, Type optOperandType, 609 TypeRange varOperandTypes) { 610 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 611 printCustomDirectiveResults(printer, op, operandType, optOperandType, 612 varOperandTypes); 613 } 614 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 615 Region ®ion, 616 MutableArrayRef<Region> varRegions) { 617 printer.printRegion(region); 618 if (!varRegions.empty()) { 619 printer << ", "; 620 for (Region ®ion : varRegions) 621 printer.printRegion(region); 622 } 623 } 624 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 625 Block *successor, 626 SuccessorRange varSuccessors) { 627 printer << successor; 628 if (!varSuccessors.empty()) 629 printer << ", " << varSuccessors.front(); 630 } 631 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 632 Attribute attribute, 633 Attribute optAttribute) { 634 printer << attribute; 635 if (optAttribute) 636 printer << ", " << optAttribute; 637 } 638 639 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 640 DictionaryAttr attrs) { 641 printer.printOptionalAttrDict(attrs.getValue()); 642 } 643 644 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 645 Operation *op, 646 Value optOperand) { 647 printer << (optOperand ? "1" : "0"); 648 } 649 650 //===----------------------------------------------------------------------===// 651 // Test IsolatedRegionOp - parse passthrough region arguments. 652 //===----------------------------------------------------------------------===// 653 654 ParseResult IsolatedRegionOp::parse(OpAsmParser &parser, 655 OperationState &result) { 656 OpAsmParser::OperandType argInfo; 657 Type argType = parser.getBuilder().getIndexType(); 658 659 // Parse the input operand. 660 if (parser.parseOperand(argInfo) || 661 parser.resolveOperand(argInfo, argType, result.operands)) 662 return failure(); 663 664 // Parse the body region, and reuse the operand info as the argument info. 665 Region *body = result.addRegion(); 666 return parser.parseRegion(*body, argInfo, argType, /*argLocations=*/{}, 667 /*enableNameShadowing=*/true); 668 } 669 670 void IsolatedRegionOp::print(OpAsmPrinter &p) { 671 p << "test.isolated_region "; 672 p.printOperand(getOperand()); 673 p.shadowRegionArgs(getRegion(), getOperand()); 674 p << ' '; 675 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 676 } 677 678 //===----------------------------------------------------------------------===// 679 // Test SSACFGRegionOp 680 //===----------------------------------------------------------------------===// 681 682 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 683 return RegionKind::SSACFG; 684 } 685 686 //===----------------------------------------------------------------------===// 687 // Test GraphRegionOp 688 //===----------------------------------------------------------------------===// 689 690 ParseResult GraphRegionOp::parse(OpAsmParser &parser, OperationState &result) { 691 // Parse the body region, and reuse the operand info as the argument info. 692 Region *body = result.addRegion(); 693 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 694 } 695 696 void GraphRegionOp::print(OpAsmPrinter &p) { 697 p << "test.graph_region "; 698 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 699 } 700 701 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 702 return RegionKind::Graph; 703 } 704 705 //===----------------------------------------------------------------------===// 706 // Test AffineScopeOp 707 //===----------------------------------------------------------------------===// 708 709 ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) { 710 // Parse the body region, and reuse the operand info as the argument info. 711 Region *body = result.addRegion(); 712 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 713 } 714 715 void AffineScopeOp::print(OpAsmPrinter &p) { 716 p << "test.affine_scope "; 717 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 718 } 719 720 //===----------------------------------------------------------------------===// 721 // Test parser. 722 //===----------------------------------------------------------------------===// 723 724 ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser, 725 OperationState &result) { 726 if (parser.parseOptionalColon()) 727 return success(); 728 uint64_t numResults; 729 if (parser.parseInteger(numResults)) 730 return failure(); 731 732 IndexType type = parser.getBuilder().getIndexType(); 733 for (unsigned i = 0; i < numResults; ++i) 734 result.addTypes(type); 735 return success(); 736 } 737 738 void ParseIntegerLiteralOp::print(OpAsmPrinter &p) { 739 if (unsigned numResults = getNumResults()) 740 p << " : " << numResults; 741 } 742 743 ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser, 744 OperationState &result) { 745 StringRef keyword; 746 if (parser.parseKeyword(&keyword)) 747 return failure(); 748 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 749 return success(); 750 } 751 752 void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); } 753 754 //===----------------------------------------------------------------------===// 755 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 756 757 ParseResult WrappingRegionOp::parse(OpAsmParser &parser, 758 OperationState &result) { 759 if (parser.parseKeyword("wraps")) 760 return failure(); 761 762 // Parse the wrapped op in a region 763 Region &body = *result.addRegion(); 764 body.push_back(new Block); 765 Block &block = body.back(); 766 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin()); 767 if (!wrappedOp) 768 return failure(); 769 770 // Create a return terminator in the inner region, pass as operand to the 771 // terminator the returned values from the wrapped operation. 772 SmallVector<Value, 8> returnOperands(wrappedOp->getResults()); 773 OpBuilder builder(parser.getContext()); 774 builder.setInsertionPointToEnd(&block); 775 builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands); 776 777 // Get the results type for the wrapping op from the terminator operands. 778 Operation &returnOp = body.back().back(); 779 result.types.append(returnOp.operand_type_begin(), 780 returnOp.operand_type_end()); 781 782 // Use the location of the wrapped op for the "test.wrapping_region" op. 783 result.location = wrappedOp->getLoc(); 784 785 return success(); 786 } 787 788 void WrappingRegionOp::print(OpAsmPrinter &p) { 789 p << " wraps "; 790 p.printGenericOp(&getRegion().front().front()); 791 } 792 793 //===----------------------------------------------------------------------===// 794 // Test PrettyPrintedRegionOp - exercising the following parser APIs 795 // parseGenericOperationAfterOpName 796 // parseCustomOperationName 797 //===----------------------------------------------------------------------===// 798 799 ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser, 800 OperationState &result) { 801 802 SMLoc loc = parser.getCurrentLocation(); 803 Location currLocation = parser.getEncodedSourceLoc(loc); 804 805 // Parse the operands. 806 SmallVector<OpAsmParser::OperandType, 2> operands; 807 if (parser.parseOperandList(operands)) 808 return failure(); 809 810 // Check if we are parsing the pretty-printed version 811 // test.pretty_printed_region start <inner-op> end : <functional-type> 812 // Else fallback to parsing the "non pretty-printed" version. 813 if (!succeeded(parser.parseOptionalKeyword("start"))) 814 return parser.parseGenericOperationAfterOpName( 815 result, llvm::makeArrayRef(operands)); 816 817 FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName(); 818 if (failed(parseOpNameInfo)) 819 return failure(); 820 821 StringRef innerOpName = parseOpNameInfo->getStringRef(); 822 823 FunctionType opFntype; 824 Optional<Location> explicitLoc; 825 if (parser.parseKeyword("end") || parser.parseColon() || 826 parser.parseType(opFntype) || 827 parser.parseOptionalLocationSpecifier(explicitLoc)) 828 return failure(); 829 830 // If location of the op is explicitly provided, then use it; Else use 831 // the parser's current location. 832 Location opLoc = explicitLoc.getValueOr(currLocation); 833 834 // Derive the SSA-values for op's operands. 835 if (parser.resolveOperands(operands, opFntype.getInputs(), loc, 836 result.operands)) 837 return failure(); 838 839 // Add a region for op. 840 Region ®ion = *result.addRegion(); 841 842 // Create a basic-block inside op's region. 843 Block &block = region.emplaceBlock(); 844 845 // Create and insert an "inner-op" operation in the block. 846 // Just for testing purposes, we can assume that inner op is a binary op with 847 // result and operand types all same as the test-op's first operand. 848 Type innerOpType = opFntype.getInput(0); 849 Value lhs = block.addArgument(innerOpType, opLoc); 850 Value rhs = block.addArgument(innerOpType, opLoc); 851 852 OpBuilder builder(parser.getBuilder().getContext()); 853 builder.setInsertionPointToStart(&block); 854 855 OperationState innerOpState(opLoc, innerOpName); 856 innerOpState.operands.push_back(lhs); 857 innerOpState.operands.push_back(rhs); 858 innerOpState.addTypes(innerOpType); 859 860 Operation *innerOp = builder.createOperation(innerOpState); 861 862 // Insert a return statement in the block returning the inner-op's result. 863 builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults()); 864 865 // Populate the op operation-state with result-type and location. 866 result.addTypes(opFntype.getResults()); 867 result.location = innerOp->getLoc(); 868 869 return success(); 870 } 871 872 void PrettyPrintedRegionOp::print(OpAsmPrinter &p) { 873 p << ' '; 874 p.printOperands(getOperands()); 875 876 Operation &innerOp = getRegion().front().front(); 877 // Assuming that region has a single non-terminator inner-op, if the inner-op 878 // meets some criteria (which in this case is a simple one based on the name 879 // of inner-op), then we can print the entire region in a succinct way. 880 // Here we assume that the prototype of "special.op" can be trivially derived 881 // while parsing it back. 882 if (innerOp.getName().getStringRef().equals("special.op")) { 883 p << " start special.op end"; 884 } else { 885 p << " ("; 886 p.printRegion(getRegion()); 887 p << ")"; 888 } 889 890 p << " : "; 891 p.printFunctionalType(*this); 892 } 893 894 //===----------------------------------------------------------------------===// 895 // Test PolyForOp - parse list of region arguments. 896 //===----------------------------------------------------------------------===// 897 898 ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) { 899 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 900 // Parse list of region arguments without a delimiter. 901 if (parser.parseRegionArgumentList(ivsInfo)) 902 return failure(); 903 904 // Parse the body region. 905 Region *body = result.addRegion(); 906 auto &builder = parser.getBuilder(); 907 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 908 return parser.parseRegion(*body, ivsInfo, argTypes); 909 } 910 911 void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); } 912 913 void PolyForOp::getAsmBlockArgumentNames(Region ®ion, 914 OpAsmSetValueNameFn setNameFn) { 915 auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names"); 916 if (!arrayAttr) 917 return; 918 auto args = getRegion().front().getArguments(); 919 auto e = std::min(arrayAttr.size(), args.size()); 920 for (unsigned i = 0; i < e; ++i) { 921 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 922 setNameFn(args[i], strAttr.getValue()); 923 } 924 } 925 926 //===----------------------------------------------------------------------===// 927 // Test removing op with inner ops. 928 //===----------------------------------------------------------------------===// 929 930 namespace { 931 struct TestRemoveOpWithInnerOps 932 : public OpRewritePattern<TestOpWithRegionPattern> { 933 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 934 935 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } 936 937 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 938 PatternRewriter &rewriter) const override { 939 rewriter.eraseOp(op); 940 return success(); 941 } 942 }; 943 } // namespace 944 945 void TestOpWithRegionPattern::getCanonicalizationPatterns( 946 RewritePatternSet &results, MLIRContext *context) { 947 results.add<TestRemoveOpWithInnerOps>(context); 948 } 949 950 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 951 return getOperand(); 952 } 953 954 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 955 return getValue(); 956 } 957 958 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 959 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 960 for (Value input : this->getOperands()) { 961 results.push_back(input); 962 } 963 return success(); 964 } 965 966 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 967 assert(operands.size() == 1); 968 if (operands.front()) { 969 (*this)->setAttr("attr", operands.front()); 970 return getResult(); 971 } 972 return {}; 973 } 974 975 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 976 return getOperand(); 977 } 978 979 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 980 MLIRContext *, Optional<Location> location, ValueRange operands, 981 DictionaryAttr attributes, RegionRange regions, 982 SmallVectorImpl<Type> &inferredReturnTypes) { 983 if (operands[0].getType() != operands[1].getType()) { 984 return emitOptionalError(location, "operand type mismatch ", 985 operands[0].getType(), " vs ", 986 operands[1].getType()); 987 } 988 inferredReturnTypes.assign({operands[0].getType()}); 989 return success(); 990 } 991 992 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 993 MLIRContext *context, Optional<Location> location, ValueShapeRange operands, 994 DictionaryAttr attributes, RegionRange regions, 995 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 996 // Create return type consisting of the last element of the first operand. 997 auto operandType = operands.front().getType(); 998 auto sval = operandType.dyn_cast<ShapedType>(); 999 if (!sval) { 1000 return emitOptionalError(location, "only shaped type operands allowed"); 1001 } 1002 int64_t dim = 1003 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 1004 auto type = IntegerType::get(context, 17); 1005 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 1006 return success(); 1007 } 1008 1009 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 1010 OpBuilder &builder, ValueRange operands, 1011 llvm::SmallVectorImpl<Value> &shapes) { 1012 shapes = SmallVector<Value, 1>{ 1013 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; 1014 return success(); 1015 } 1016 1017 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( 1018 OpBuilder &builder, ValueRange operands, 1019 llvm::SmallVectorImpl<Value> &shapes) { 1020 Location loc = getLoc(); 1021 shapes.reserve(operands.size()); 1022 for (Value operand : llvm::reverse(operands)) { 1023 auto rank = operand.getType().cast<RankedTensorType>().getRank(); 1024 auto currShape = llvm::to_vector<4>( 1025 llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value { 1026 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 1027 })); 1028 shapes.push_back(builder.create<tensor::FromElementsOp>( 1029 getLoc(), RankedTensorType::get({rank}, builder.getIndexType()), 1030 currShape)); 1031 } 1032 return success(); 1033 } 1034 1035 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( 1036 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { 1037 Location loc = getLoc(); 1038 shapes.reserve(getNumOperands()); 1039 for (Value operand : llvm::reverse(getOperands())) { 1040 auto currShape = llvm::to_vector<4>(llvm::map_range( 1041 llvm::seq<int64_t>( 1042 0, operand.getType().cast<RankedTensorType>().getRank()), 1043 [&](int64_t dim) -> Value { 1044 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 1045 })); 1046 shapes.emplace_back(std::move(currShape)); 1047 } 1048 return success(); 1049 } 1050 1051 //===----------------------------------------------------------------------===// 1052 // Test SideEffect interfaces 1053 //===----------------------------------------------------------------------===// 1054 1055 namespace { 1056 /// A test resource for side effects. 1057 struct TestResource : public SideEffects::Resource::Base<TestResource> { 1058 StringRef getName() final { return "<Test>"; } 1059 }; 1060 } // namespace 1061 1062 static void testSideEffectOpGetEffect( 1063 Operation *op, 1064 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 1065 &effects) { 1066 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); 1067 if (!effectsAttr) 1068 return; 1069 1070 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 1071 } 1072 1073 void SideEffectOp::getEffects( 1074 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1075 // Check for an effects attribute on the op instance. 1076 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 1077 if (!effectsAttr) 1078 return; 1079 1080 // If there is one, it is an array of dictionary attributes that hold 1081 // information on the effects of this operation. 1082 for (Attribute element : effectsAttr) { 1083 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 1084 1085 // Get the specific memory effect. 1086 MemoryEffects::Effect *effect = 1087 StringSwitch<MemoryEffects::Effect *>( 1088 effectElement.get("effect").cast<StringAttr>().getValue()) 1089 .Case("allocate", MemoryEffects::Allocate::get()) 1090 .Case("free", MemoryEffects::Free::get()) 1091 .Case("read", MemoryEffects::Read::get()) 1092 .Case("write", MemoryEffects::Write::get()); 1093 1094 // Check for a non-default resource to use. 1095 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 1096 if (effectElement.get("test_resource")) 1097 resource = TestResource::get(); 1098 1099 // Check for a result to affect. 1100 if (effectElement.get("on_result")) 1101 effects.emplace_back(effect, getResult(), resource); 1102 else if (Attribute ref = effectElement.get("on_reference")) 1103 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 1104 else 1105 effects.emplace_back(effect, resource); 1106 } 1107 } 1108 1109 void SideEffectOp::getEffects( 1110 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 1111 testSideEffectOpGetEffect(getOperation(), effects); 1112 } 1113 1114 //===----------------------------------------------------------------------===// 1115 // StringAttrPrettyNameOp 1116 //===----------------------------------------------------------------------===// 1117 1118 // This op has fancy handling of its SSA result name. 1119 ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser, 1120 OperationState &result) { 1121 // Add the result types. 1122 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 1123 result.addTypes(parser.getBuilder().getIntegerType(32)); 1124 1125 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1126 return failure(); 1127 1128 // If the attribute dictionary contains no 'names' attribute, infer it from 1129 // the SSA name (if specified). 1130 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 1131 return attr.getName() == "names"; 1132 }); 1133 1134 // If there was no name specified, check to see if there was a useful name 1135 // specified in the asm file. 1136 if (hadNames || parser.getNumResults() == 0) 1137 return success(); 1138 1139 SmallVector<StringRef, 4> names; 1140 auto *context = result.getContext(); 1141 1142 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 1143 auto resultName = parser.getResultName(i); 1144 StringRef nameStr; 1145 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 1146 nameStr = resultName.first; 1147 1148 names.push_back(nameStr); 1149 } 1150 1151 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 1152 result.attributes.push_back({StringAttr::get(context, "names"), namesAttr}); 1153 return success(); 1154 } 1155 1156 void StringAttrPrettyNameOp::print(OpAsmPrinter &p) { 1157 // Note that we only need to print the "name" attribute if the asmprinter 1158 // result name disagrees with it. This can happen in strange cases, e.g. 1159 // when there are conflicts. 1160 bool namesDisagree = getNames().size() != getNumResults(); 1161 1162 SmallString<32> resultNameStr; 1163 for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) { 1164 resultNameStr.clear(); 1165 llvm::raw_svector_ostream tmpStream(resultNameStr); 1166 p.printOperand(getResult(i), tmpStream); 1167 1168 auto expectedName = getNames()[i].dyn_cast<StringAttr>(); 1169 if (!expectedName || 1170 tmpStream.str().drop_front() != expectedName.getValue()) { 1171 namesDisagree = true; 1172 } 1173 } 1174 1175 if (namesDisagree) 1176 p.printOptionalAttrDictWithKeyword((*this)->getAttrs()); 1177 else 1178 p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"}); 1179 } 1180 1181 // We set the SSA name in the asm syntax to the contents of the name 1182 // attribute. 1183 void StringAttrPrettyNameOp::getAsmResultNames( 1184 function_ref<void(Value, StringRef)> setNameFn) { 1185 1186 auto value = getNames(); 1187 for (size_t i = 0, e = value.size(); i != e; ++i) 1188 if (auto str = value[i].dyn_cast<StringAttr>()) 1189 if (!str.getValue().empty()) 1190 setNameFn(getResult(i), str.getValue()); 1191 } 1192 1193 //===----------------------------------------------------------------------===// 1194 // ResultTypeWithTraitOp 1195 //===----------------------------------------------------------------------===// 1196 1197 LogicalResult ResultTypeWithTraitOp::verify() { 1198 if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>()) 1199 return success(); 1200 return emitError("result type should have trait 'TestTypeTrait'"); 1201 } 1202 1203 //===----------------------------------------------------------------------===// 1204 // AttrWithTraitOp 1205 //===----------------------------------------------------------------------===// 1206 1207 LogicalResult AttrWithTraitOp::verify() { 1208 if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>()) 1209 return success(); 1210 return emitError("'attr' attribute should have trait 'TestAttrTrait'"); 1211 } 1212 1213 //===----------------------------------------------------------------------===// 1214 // RegionIfOp 1215 //===----------------------------------------------------------------------===// 1216 1217 void RegionIfOp::print(OpAsmPrinter &p) { 1218 p << " "; 1219 p.printOperands(getOperands()); 1220 p << ": " << getOperandTypes(); 1221 p.printArrowTypeList(getResultTypes()); 1222 p << " then "; 1223 p.printRegion(getThenRegion(), 1224 /*printEntryBlockArgs=*/true, 1225 /*printBlockTerminators=*/true); 1226 p << " else "; 1227 p.printRegion(getElseRegion(), 1228 /*printEntryBlockArgs=*/true, 1229 /*printBlockTerminators=*/true); 1230 p << " join "; 1231 p.printRegion(getJoinRegion(), 1232 /*printEntryBlockArgs=*/true, 1233 /*printBlockTerminators=*/true); 1234 } 1235 1236 ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) { 1237 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 1238 SmallVector<Type, 2> operandTypes; 1239 1240 result.regions.reserve(3); 1241 Region *thenRegion = result.addRegion(); 1242 Region *elseRegion = result.addRegion(); 1243 Region *joinRegion = result.addRegion(); 1244 1245 // Parse operand, type and arrow type lists. 1246 if (parser.parseOperandList(operandInfos) || 1247 parser.parseColonTypeList(operandTypes) || 1248 parser.parseArrowTypeList(result.types)) 1249 return failure(); 1250 1251 // Parse all attached regions. 1252 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 1253 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 1254 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 1255 return failure(); 1256 1257 return parser.resolveOperands(operandInfos, operandTypes, 1258 parser.getCurrentLocation(), result.operands); 1259 } 1260 1261 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 1262 assert(index < 2 && "invalid region index"); 1263 return getOperands(); 1264 } 1265 1266 void RegionIfOp::getSuccessorRegions( 1267 Optional<unsigned> index, ArrayRef<Attribute> operands, 1268 SmallVectorImpl<RegionSuccessor> ®ions) { 1269 // We always branch to the join region. 1270 if (index.hasValue()) { 1271 if (index.getValue() < 2) 1272 regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); 1273 else 1274 regions.push_back(RegionSuccessor(getResults())); 1275 return; 1276 } 1277 1278 // The then and else regions are the entry regions of this op. 1279 regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs())); 1280 regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs())); 1281 } 1282 1283 void RegionIfOp::getRegionInvocationBounds( 1284 ArrayRef<Attribute> operands, 1285 SmallVectorImpl<InvocationBounds> &invocationBounds) { 1286 // Each region is invoked at most once. 1287 invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1}); 1288 } 1289 1290 //===----------------------------------------------------------------------===// 1291 // AnyCondOp 1292 //===----------------------------------------------------------------------===// 1293 1294 void AnyCondOp::getSuccessorRegions(Optional<unsigned> index, 1295 ArrayRef<Attribute> operands, 1296 SmallVectorImpl<RegionSuccessor> ®ions) { 1297 // The parent op branches into the only region, and the region branches back 1298 // to the parent op. 1299 if (index) 1300 regions.emplace_back(&getRegion()); 1301 else 1302 regions.emplace_back(getResults()); 1303 } 1304 1305 void AnyCondOp::getRegionInvocationBounds( 1306 ArrayRef<Attribute> operands, 1307 SmallVectorImpl<InvocationBounds> &invocationBounds) { 1308 invocationBounds.emplace_back(1, 1); 1309 } 1310 1311 //===----------------------------------------------------------------------===// 1312 // SingleNoTerminatorCustomAsmOp 1313 //===----------------------------------------------------------------------===// 1314 1315 ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser, 1316 OperationState &state) { 1317 Region *body = state.addRegion(); 1318 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) 1319 return failure(); 1320 return success(); 1321 } 1322 1323 void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) { 1324 printer.printRegion( 1325 getRegion(), /*printEntryBlockArgs=*/false, 1326 // This op has a single block without terminators. But explicitly mark 1327 // as not printing block terminators for testing. 1328 /*printBlockTerminators=*/false); 1329 } 1330 1331 #include "TestOpEnums.cpp.inc" 1332 #include "TestOpInterfaces.cpp.inc" 1333 #include "TestOpStructs.cpp.inc" 1334 #include "TestTypeInterfaces.cpp.inc" 1335 1336 #define GET_OP_CLASSES 1337 #include "TestOps.cpp.inc" 1338