1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestInterfaces.h" 12 #include "TestTypes.h" 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/DLTI/DLTI.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/DialectImplementation.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/IR/TypeUtilities.h" 21 #include "mlir/Reducer/ReductionPatternInterface.h" 22 #include "mlir/Transforms/FoldUtils.h" 23 #include "mlir/Transforms/InliningUtils.h" 24 #include "llvm/ADT/StringSwitch.h" 25 26 // Include this before the using namespace lines below to 27 // test that we don't have namespace dependencies. 28 #include "TestOpsDialect.cpp.inc" 29 30 using namespace mlir; 31 using namespace test; 32 33 void test::registerTestDialect(DialectRegistry ®istry) { 34 registry.insert<TestDialect>(); 35 } 36 37 //===----------------------------------------------------------------------===// 38 // TestDialect Interfaces 39 //===----------------------------------------------------------------------===// 40 41 namespace { 42 43 /// Testing the correctness of some traits. 44 static_assert( 45 llvm::is_detected<OpTrait::has_implicit_terminator_t, 46 SingleBlockImplicitTerminatorOp>::value, 47 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); 48 static_assert(OpTrait::hasSingleBlockImplicitTerminator< 49 SingleBlockImplicitTerminatorOp>::value, 50 "hasSingleBlockImplicitTerminator does not match " 51 "SingleBlockImplicitTerminatorOp"); 52 53 // Test support for interacting with the AsmPrinter. 54 struct TestOpAsmInterface : public OpAsmDialectInterface { 55 using OpAsmDialectInterface::OpAsmDialectInterface; 56 57 AliasResult getAlias(Attribute attr, raw_ostream &os) const final { 58 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 59 if (!strAttr) 60 return AliasResult::NoAlias; 61 62 // Check the contents of the string attribute to see what the test alias 63 // should be named. 64 Optional<StringRef> aliasName = 65 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 66 .Case("alias_test:dot_in_name", StringRef("test.alias")) 67 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 68 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 69 .Case("alias_test:sanitize_conflict_a", 70 StringRef("test_alias_conflict0")) 71 .Case("alias_test:sanitize_conflict_b", 72 StringRef("test_alias_conflict0_")) 73 .Case("alias_test:tensor_encoding", StringRef("test_encoding")) 74 .Default(llvm::None); 75 if (!aliasName) 76 return AliasResult::NoAlias; 77 78 os << *aliasName; 79 return AliasResult::FinalAlias; 80 } 81 82 AliasResult getAlias(Type type, raw_ostream &os) const final { 83 if (auto tupleType = type.dyn_cast<TupleType>()) { 84 if (tupleType.size() > 0 && 85 llvm::all_of(tupleType.getTypes(), [](Type elemType) { 86 return elemType.isa<SimpleAType>(); 87 })) { 88 os << "test_tuple"; 89 return AliasResult::FinalAlias; 90 } 91 } 92 if (auto intType = type.dyn_cast<TestIntegerType>()) { 93 if (intType.getSignedness() == 94 TestIntegerType::SignednessSemantics::Unsigned && 95 intType.getWidth() == 8) { 96 os << "test_ui8"; 97 return AliasResult::FinalAlias; 98 } 99 } 100 return AliasResult::NoAlias; 101 } 102 }; 103 104 struct TestDialectFoldInterface : public DialectFoldInterface { 105 using DialectFoldInterface::DialectFoldInterface; 106 107 /// Registered hook to check if the given region, which is attached to an 108 /// operation that is *not* isolated from above, should be used when 109 /// materializing constants. 110 bool shouldMaterializeInto(Region *region) const final { 111 // If this is a one region operation, then insert into it. 112 return isa<OneRegionOp>(region->getParentOp()); 113 } 114 }; 115 116 /// This class defines the interface for handling inlining with standard 117 /// operations. 118 struct TestInlinerInterface : public DialectInlinerInterface { 119 using DialectInlinerInterface::DialectInlinerInterface; 120 121 //===--------------------------------------------------------------------===// 122 // Analysis Hooks 123 //===--------------------------------------------------------------------===// 124 125 bool isLegalToInline(Operation *call, Operation *callable, 126 bool wouldBeCloned) const final { 127 // Don't allow inlining calls that are marked `noinline`. 128 return !call->hasAttr("noinline"); 129 } 130 bool isLegalToInline(Region *, Region *, bool, 131 BlockAndValueMapping &) const final { 132 // Inlining into test dialect regions is legal. 133 return true; 134 } 135 bool isLegalToInline(Operation *, Region *, bool, 136 BlockAndValueMapping &) const final { 137 return true; 138 } 139 140 bool shouldAnalyzeRecursively(Operation *op) const final { 141 // Analyze recursively if this is not a functional region operation, it 142 // froms a separate functional scope. 143 return !isa<FunctionalRegionOp>(op); 144 } 145 146 //===--------------------------------------------------------------------===// 147 // Transformation Hooks 148 //===--------------------------------------------------------------------===// 149 150 /// Handle the given inlined terminator by replacing it with a new operation 151 /// as necessary. 152 void handleTerminator(Operation *op, 153 ArrayRef<Value> valuesToRepl) const final { 154 // Only handle "test.return" here. 155 auto returnOp = dyn_cast<TestReturnOp>(op); 156 if (!returnOp) 157 return; 158 159 // Replace the values directly with the return operands. 160 assert(returnOp.getNumOperands() == valuesToRepl.size()); 161 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 162 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 163 } 164 165 /// Attempt to materialize a conversion for a type mismatch between a call 166 /// from this dialect, and a callable region. This method should generate an 167 /// operation that takes 'input' as the only operand, and produces a single 168 /// result of 'resultType'. If a conversion can not be generated, nullptr 169 /// should be returned. 170 Operation *materializeCallConversion(OpBuilder &builder, Value input, 171 Type resultType, 172 Location conversionLoc) const final { 173 // Only allow conversion for i16/i32 types. 174 if (!(resultType.isSignlessInteger(16) || 175 resultType.isSignlessInteger(32)) || 176 !(input.getType().isSignlessInteger(16) || 177 input.getType().isSignlessInteger(32))) 178 return nullptr; 179 return builder.create<TestCastOp>(conversionLoc, resultType, input); 180 } 181 182 void processInlinedCallBlocks( 183 Operation *call, 184 iterator_range<Region::iterator> inlinedBlocks) const final { 185 if (!isa<ConversionCallOp>(call)) 186 return; 187 188 // Set attributed on all ops in the inlined blocks. 189 for (Block &block : inlinedBlocks) { 190 block.walk([&](Operation *op) { 191 op->setAttr("inlined_conversion", UnitAttr::get(call->getContext())); 192 }); 193 } 194 } 195 }; 196 197 struct TestReductionPatternInterface : public DialectReductionPatternInterface { 198 public: 199 TestReductionPatternInterface(Dialect *dialect) 200 : DialectReductionPatternInterface(dialect) {} 201 202 void populateReductionPatterns(RewritePatternSet &patterns) const final { 203 populateTestReductionPatterns(patterns); 204 } 205 }; 206 207 } // namespace 208 209 //===----------------------------------------------------------------------===// 210 // TestDialect 211 //===----------------------------------------------------------------------===// 212 213 static void testSideEffectOpGetEffect( 214 Operation *op, 215 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects); 216 217 // This is the implementation of a dialect fallback for `TestEffectOpInterface`. 218 struct TestOpEffectInterfaceFallback 219 : public TestEffectOpInterface::FallbackModel< 220 TestOpEffectInterfaceFallback> { 221 static bool classof(Operation *op) { 222 bool isSupportedOp = 223 op->getName().getStringRef() == "test.unregistered_side_effect_op"; 224 assert(isSupportedOp && "Unexpected dispatch"); 225 return isSupportedOp; 226 } 227 228 void 229 getEffects(Operation *op, 230 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 231 &effects) const { 232 testSideEffectOpGetEffect(op, effects); 233 } 234 }; 235 236 void TestDialect::initialize() { 237 registerAttributes(); 238 registerTypes(); 239 addOperations< 240 #define GET_OP_LIST 241 #include "TestOps.cpp.inc" 242 >(); 243 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 244 TestInlinerInterface, TestReductionPatternInterface>(); 245 allowUnknownOperations(); 246 247 // Instantiate our fallback op interface that we'll use on specific 248 // unregistered op. 249 fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback; 250 } 251 TestDialect::~TestDialect() { 252 delete static_cast<TestOpEffectInterfaceFallback *>( 253 fallbackEffectOpInterfaces); 254 } 255 256 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 257 Type type, Location loc) { 258 return builder.create<TestOpConstant>(loc, type, value); 259 } 260 261 ::mlir::LogicalResult FormatInferType2Op::inferReturnTypes( 262 ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, 263 ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, 264 ::mlir::RegionRange regions, 265 ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { 266 inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)}); 267 return ::mlir::success(); 268 } 269 270 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, 271 OperationName opName) { 272 if (opName.getIdentifier() == "test.unregistered_side_effect_op" && 273 typeID == TypeID::get<TestEffectOpInterface>()) 274 return fallbackEffectOpInterfaces; 275 return nullptr; 276 } 277 278 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 279 NamedAttribute namedAttr) { 280 if (namedAttr.getName() == "test.invalid_attr") 281 return op->emitError() << "invalid to use 'test.invalid_attr'"; 282 return success(); 283 } 284 285 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 286 unsigned regionIndex, 287 unsigned argIndex, 288 NamedAttribute namedAttr) { 289 if (namedAttr.getName() == "test.invalid_attr") 290 return op->emitError() << "invalid to use 'test.invalid_attr'"; 291 return success(); 292 } 293 294 LogicalResult 295 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 296 unsigned resultIndex, 297 NamedAttribute namedAttr) { 298 if (namedAttr.getName() == "test.invalid_attr") 299 return op->emitError() << "invalid to use 'test.invalid_attr'"; 300 return success(); 301 } 302 303 Optional<Dialect::ParseOpHook> 304 TestDialect::getParseOperationHook(StringRef opName) const { 305 if (opName == "test.dialect_custom_printer") { 306 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 307 return parser.parseKeyword("custom_format"); 308 }}; 309 } 310 if (opName == "test.dialect_custom_format_fallback") { 311 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 312 return parser.parseKeyword("custom_format_fallback"); 313 }}; 314 } 315 return None; 316 } 317 318 llvm::unique_function<void(Operation *, OpAsmPrinter &)> 319 TestDialect::getOperationPrinter(Operation *op) const { 320 StringRef opName = op->getName().getStringRef(); 321 if (opName == "test.dialect_custom_printer") { 322 return [](Operation *op, OpAsmPrinter &printer) { 323 printer.getStream() << " custom_format"; 324 }; 325 } 326 if (opName == "test.dialect_custom_format_fallback") { 327 return [](Operation *op, OpAsmPrinter &printer) { 328 printer.getStream() << " custom_format_fallback"; 329 }; 330 } 331 return {}; 332 } 333 334 //===----------------------------------------------------------------------===// 335 // TestBranchOp 336 //===----------------------------------------------------------------------===// 337 338 SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) { 339 assert(index == 0 && "invalid successor index"); 340 return SuccessorOperands(getTargetOperandsMutable()); 341 } 342 343 //===----------------------------------------------------------------------===// 344 // TestProducingBranchOp 345 //===----------------------------------------------------------------------===// 346 347 SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) { 348 assert(index <= 1 && "invalid successor index"); 349 if (index == 1) 350 return SuccessorOperands(getFirstOperandsMutable()); 351 return SuccessorOperands(getSecondOperandsMutable()); 352 } 353 354 //===----------------------------------------------------------------------===// 355 // TestProducingBranchOp 356 //===----------------------------------------------------------------------===// 357 358 SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) { 359 assert(index <= 1 && "invalid successor index"); 360 if (index == 0) 361 return SuccessorOperands(0, getSuccessOperandsMutable()); 362 return SuccessorOperands(1, getErrorOperandsMutable()); 363 } 364 365 //===----------------------------------------------------------------------===// 366 // TestDialectCanonicalizerOp 367 //===----------------------------------------------------------------------===// 368 369 static LogicalResult 370 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, 371 PatternRewriter &rewriter) { 372 rewriter.replaceOpWithNewOp<arith::ConstantOp>( 373 op, rewriter.getI32IntegerAttr(42)); 374 return success(); 375 } 376 377 void TestDialect::getCanonicalizationPatterns( 378 RewritePatternSet &results) const { 379 results.add(&dialectCanonicalizationPattern); 380 } 381 382 //===----------------------------------------------------------------------===// 383 // TestCallOp 384 //===----------------------------------------------------------------------===// 385 386 LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 387 // Check that the callee attribute was specified. 388 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee"); 389 if (!fnAttr) 390 return emitOpError("requires a 'callee' symbol reference attribute"); 391 if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr)) 392 return emitOpError() << "'" << fnAttr.getValue() 393 << "' does not reference a valid function"; 394 return success(); 395 } 396 397 //===----------------------------------------------------------------------===// 398 // TestFoldToCallOp 399 //===----------------------------------------------------------------------===// 400 401 namespace { 402 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 403 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 404 405 LogicalResult matchAndRewrite(FoldToCallOp op, 406 PatternRewriter &rewriter) const override { 407 rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), 408 op.getCalleeAttr(), ValueRange()); 409 return success(); 410 } 411 }; 412 } // namespace 413 414 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 415 MLIRContext *context) { 416 results.add<FoldToCallOpPattern>(context); 417 } 418 419 //===----------------------------------------------------------------------===// 420 // Test Format* operations 421 //===----------------------------------------------------------------------===// 422 423 //===----------------------------------------------------------------------===// 424 // Parsing 425 426 static ParseResult parseCustomOptionalOperand( 427 OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) { 428 if (succeeded(parser.parseOptionalLParen())) { 429 optOperand.emplace(); 430 if (parser.parseOperand(*optOperand) || parser.parseRParen()) 431 return failure(); 432 } 433 return success(); 434 } 435 436 static ParseResult parseCustomDirectiveOperands( 437 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, 438 Optional<OpAsmParser::UnresolvedOperand> &optOperand, 439 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) { 440 if (parser.parseOperand(operand)) 441 return failure(); 442 if (succeeded(parser.parseOptionalComma())) { 443 optOperand.emplace(); 444 if (parser.parseOperand(*optOperand)) 445 return failure(); 446 } 447 if (parser.parseArrow() || parser.parseLParen() || 448 parser.parseOperandList(varOperands) || parser.parseRParen()) 449 return failure(); 450 return success(); 451 } 452 static ParseResult 453 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 454 Type &optOperandType, 455 SmallVectorImpl<Type> &varOperandTypes) { 456 if (parser.parseColon()) 457 return failure(); 458 459 if (parser.parseType(operandType)) 460 return failure(); 461 if (succeeded(parser.parseOptionalComma())) { 462 if (parser.parseType(optOperandType)) 463 return failure(); 464 } 465 if (parser.parseArrow() || parser.parseLParen() || 466 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 467 return failure(); 468 return success(); 469 } 470 static ParseResult 471 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 472 Type optOperandType, 473 const SmallVectorImpl<Type> &varOperandTypes) { 474 if (parser.parseKeyword("type_refs_capture")) 475 return failure(); 476 477 Type operandType2, optOperandType2; 478 SmallVector<Type, 1> varOperandTypes2; 479 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 480 varOperandTypes2)) 481 return failure(); 482 483 if (operandType != operandType2 || optOperandType != optOperandType2 || 484 varOperandTypes != varOperandTypes2) 485 return failure(); 486 487 return success(); 488 } 489 static ParseResult parseCustomDirectiveOperandsAndTypes( 490 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, 491 Optional<OpAsmParser::UnresolvedOperand> &optOperand, 492 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands, 493 Type &operandType, Type &optOperandType, 494 SmallVectorImpl<Type> &varOperandTypes) { 495 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 496 parseCustomDirectiveResults(parser, operandType, optOperandType, 497 varOperandTypes)) 498 return failure(); 499 return success(); 500 } 501 static ParseResult parseCustomDirectiveRegions( 502 OpAsmParser &parser, Region ®ion, 503 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 504 if (parser.parseRegion(region)) 505 return failure(); 506 if (failed(parser.parseOptionalComma())) 507 return success(); 508 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 509 if (parser.parseRegion(*varRegion)) 510 return failure(); 511 varRegions.emplace_back(std::move(varRegion)); 512 return success(); 513 } 514 static ParseResult 515 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 516 SmallVectorImpl<Block *> &varSuccessors) { 517 if (parser.parseSuccessor(successor)) 518 return failure(); 519 if (failed(parser.parseOptionalComma())) 520 return success(); 521 Block *varSuccessor; 522 if (parser.parseSuccessor(varSuccessor)) 523 return failure(); 524 varSuccessors.append(2, varSuccessor); 525 return success(); 526 } 527 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 528 IntegerAttr &attr, 529 IntegerAttr &optAttr) { 530 if (parser.parseAttribute(attr)) 531 return failure(); 532 if (succeeded(parser.parseOptionalComma())) { 533 if (parser.parseAttribute(optAttr)) 534 return failure(); 535 } 536 return success(); 537 } 538 539 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 540 NamedAttrList &attrs) { 541 return parser.parseOptionalAttrDict(attrs); 542 } 543 static ParseResult parseCustomDirectiveOptionalOperandRef( 544 OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) { 545 int64_t operandCount = 0; 546 if (parser.parseInteger(operandCount)) 547 return failure(); 548 bool expectedOptionalOperand = operandCount == 0; 549 return success(expectedOptionalOperand != optOperand.hasValue()); 550 } 551 552 //===----------------------------------------------------------------------===// 553 // Printing 554 555 static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *, 556 Value optOperand) { 557 if (optOperand) 558 printer << "(" << optOperand << ") "; 559 } 560 561 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 562 Value operand, Value optOperand, 563 OperandRange varOperands) { 564 printer << operand; 565 if (optOperand) 566 printer << ", " << optOperand; 567 printer << " -> (" << varOperands << ")"; 568 } 569 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 570 Type operandType, Type optOperandType, 571 TypeRange varOperandTypes) { 572 printer << " : " << operandType; 573 if (optOperandType) 574 printer << ", " << optOperandType; 575 printer << " -> (" << varOperandTypes << ")"; 576 } 577 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 578 Operation *op, Type operandType, 579 Type optOperandType, 580 TypeRange varOperandTypes) { 581 printer << " type_refs_capture "; 582 printCustomDirectiveResults(printer, op, operandType, optOperandType, 583 varOperandTypes); 584 } 585 static void printCustomDirectiveOperandsAndTypes( 586 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 587 OperandRange varOperands, Type operandType, Type optOperandType, 588 TypeRange varOperandTypes) { 589 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 590 printCustomDirectiveResults(printer, op, operandType, optOperandType, 591 varOperandTypes); 592 } 593 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 594 Region ®ion, 595 MutableArrayRef<Region> varRegions) { 596 printer.printRegion(region); 597 if (!varRegions.empty()) { 598 printer << ", "; 599 for (Region ®ion : varRegions) 600 printer.printRegion(region); 601 } 602 } 603 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 604 Block *successor, 605 SuccessorRange varSuccessors) { 606 printer << successor; 607 if (!varSuccessors.empty()) 608 printer << ", " << varSuccessors.front(); 609 } 610 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 611 Attribute attribute, 612 Attribute optAttribute) { 613 printer << attribute; 614 if (optAttribute) 615 printer << ", " << optAttribute; 616 } 617 618 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 619 DictionaryAttr attrs) { 620 printer.printOptionalAttrDict(attrs.getValue()); 621 } 622 623 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 624 Operation *op, 625 Value optOperand) { 626 printer << (optOperand ? "1" : "0"); 627 } 628 629 //===----------------------------------------------------------------------===// 630 // Test IsolatedRegionOp - parse passthrough region arguments. 631 //===----------------------------------------------------------------------===// 632 633 ParseResult IsolatedRegionOp::parse(OpAsmParser &parser, 634 OperationState &result) { 635 OpAsmParser::UnresolvedOperand argInfo; 636 Type argType = parser.getBuilder().getIndexType(); 637 638 // Parse the input operand. 639 if (parser.parseOperand(argInfo) || 640 parser.resolveOperand(argInfo, argType, result.operands)) 641 return failure(); 642 643 // Parse the body region, and reuse the operand info as the argument info. 644 Region *body = result.addRegion(); 645 return parser.parseRegion(*body, argInfo, argType, /*argLocations=*/{}, 646 /*enableNameShadowing=*/true); 647 } 648 649 void IsolatedRegionOp::print(OpAsmPrinter &p) { 650 p << "test.isolated_region "; 651 p.printOperand(getOperand()); 652 p.shadowRegionArgs(getRegion(), getOperand()); 653 p << ' '; 654 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 655 } 656 657 //===----------------------------------------------------------------------===// 658 // Test SSACFGRegionOp 659 //===----------------------------------------------------------------------===// 660 661 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 662 return RegionKind::SSACFG; 663 } 664 665 //===----------------------------------------------------------------------===// 666 // Test GraphRegionOp 667 //===----------------------------------------------------------------------===// 668 669 ParseResult GraphRegionOp::parse(OpAsmParser &parser, OperationState &result) { 670 // Parse the body region, and reuse the operand info as the argument info. 671 Region *body = result.addRegion(); 672 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 673 } 674 675 void GraphRegionOp::print(OpAsmPrinter &p) { 676 p << "test.graph_region "; 677 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 678 } 679 680 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 681 return RegionKind::Graph; 682 } 683 684 //===----------------------------------------------------------------------===// 685 // Test AffineScopeOp 686 //===----------------------------------------------------------------------===// 687 688 ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) { 689 // Parse the body region, and reuse the operand info as the argument info. 690 Region *body = result.addRegion(); 691 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 692 } 693 694 void AffineScopeOp::print(OpAsmPrinter &p) { 695 p << "test.affine_scope "; 696 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 697 } 698 699 //===----------------------------------------------------------------------===// 700 // Test parser. 701 //===----------------------------------------------------------------------===// 702 703 ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser, 704 OperationState &result) { 705 if (parser.parseOptionalColon()) 706 return success(); 707 uint64_t numResults; 708 if (parser.parseInteger(numResults)) 709 return failure(); 710 711 IndexType type = parser.getBuilder().getIndexType(); 712 for (unsigned i = 0; i < numResults; ++i) 713 result.addTypes(type); 714 return success(); 715 } 716 717 void ParseIntegerLiteralOp::print(OpAsmPrinter &p) { 718 if (unsigned numResults = getNumResults()) 719 p << " : " << numResults; 720 } 721 722 ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser, 723 OperationState &result) { 724 StringRef keyword; 725 if (parser.parseKeyword(&keyword)) 726 return failure(); 727 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 728 return success(); 729 } 730 731 void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); } 732 733 //===----------------------------------------------------------------------===// 734 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 735 736 ParseResult WrappingRegionOp::parse(OpAsmParser &parser, 737 OperationState &result) { 738 if (parser.parseKeyword("wraps")) 739 return failure(); 740 741 // Parse the wrapped op in a region 742 Region &body = *result.addRegion(); 743 body.push_back(new Block); 744 Block &block = body.back(); 745 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin()); 746 if (!wrappedOp) 747 return failure(); 748 749 // Create a return terminator in the inner region, pass as operand to the 750 // terminator the returned values from the wrapped operation. 751 SmallVector<Value, 8> returnOperands(wrappedOp->getResults()); 752 OpBuilder builder(parser.getContext()); 753 builder.setInsertionPointToEnd(&block); 754 builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands); 755 756 // Get the results type for the wrapping op from the terminator operands. 757 Operation &returnOp = body.back().back(); 758 result.types.append(returnOp.operand_type_begin(), 759 returnOp.operand_type_end()); 760 761 // Use the location of the wrapped op for the "test.wrapping_region" op. 762 result.location = wrappedOp->getLoc(); 763 764 return success(); 765 } 766 767 void WrappingRegionOp::print(OpAsmPrinter &p) { 768 p << " wraps "; 769 p.printGenericOp(&getRegion().front().front()); 770 } 771 772 //===----------------------------------------------------------------------===// 773 // Test PrettyPrintedRegionOp - exercising the following parser APIs 774 // parseGenericOperationAfterOpName 775 // parseCustomOperationName 776 //===----------------------------------------------------------------------===// 777 778 ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser, 779 OperationState &result) { 780 781 SMLoc loc = parser.getCurrentLocation(); 782 Location currLocation = parser.getEncodedSourceLoc(loc); 783 784 // Parse the operands. 785 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands; 786 if (parser.parseOperandList(operands)) 787 return failure(); 788 789 // Check if we are parsing the pretty-printed version 790 // test.pretty_printed_region start <inner-op> end : <functional-type> 791 // Else fallback to parsing the "non pretty-printed" version. 792 if (!succeeded(parser.parseOptionalKeyword("start"))) 793 return parser.parseGenericOperationAfterOpName( 794 result, llvm::makeArrayRef(operands)); 795 796 FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName(); 797 if (failed(parseOpNameInfo)) 798 return failure(); 799 800 StringAttr innerOpName = parseOpNameInfo->getIdentifier(); 801 802 FunctionType opFntype; 803 Optional<Location> explicitLoc; 804 if (parser.parseKeyword("end") || parser.parseColon() || 805 parser.parseType(opFntype) || 806 parser.parseOptionalLocationSpecifier(explicitLoc)) 807 return failure(); 808 809 // If location of the op is explicitly provided, then use it; Else use 810 // the parser's current location. 811 Location opLoc = explicitLoc.getValueOr(currLocation); 812 813 // Derive the SSA-values for op's operands. 814 if (parser.resolveOperands(operands, opFntype.getInputs(), loc, 815 result.operands)) 816 return failure(); 817 818 // Add a region for op. 819 Region ®ion = *result.addRegion(); 820 821 // Create a basic-block inside op's region. 822 Block &block = region.emplaceBlock(); 823 824 // Create and insert an "inner-op" operation in the block. 825 // Just for testing purposes, we can assume that inner op is a binary op with 826 // result and operand types all same as the test-op's first operand. 827 Type innerOpType = opFntype.getInput(0); 828 Value lhs = block.addArgument(innerOpType, opLoc); 829 Value rhs = block.addArgument(innerOpType, opLoc); 830 831 OpBuilder builder(parser.getBuilder().getContext()); 832 builder.setInsertionPointToStart(&block); 833 834 Operation *innerOp = 835 builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType); 836 837 // Insert a return statement in the block returning the inner-op's result. 838 builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults()); 839 840 // Populate the op operation-state with result-type and location. 841 result.addTypes(opFntype.getResults()); 842 result.location = innerOp->getLoc(); 843 844 return success(); 845 } 846 847 void PrettyPrintedRegionOp::print(OpAsmPrinter &p) { 848 p << ' '; 849 p.printOperands(getOperands()); 850 851 Operation &innerOp = getRegion().front().front(); 852 // Assuming that region has a single non-terminator inner-op, if the inner-op 853 // meets some criteria (which in this case is a simple one based on the name 854 // of inner-op), then we can print the entire region in a succinct way. 855 // Here we assume that the prototype of "special.op" can be trivially derived 856 // while parsing it back. 857 if (innerOp.getName().getStringRef().equals("special.op")) { 858 p << " start special.op end"; 859 } else { 860 p << " ("; 861 p.printRegion(getRegion()); 862 p << ")"; 863 } 864 865 p << " : "; 866 p.printFunctionalType(*this); 867 } 868 869 //===----------------------------------------------------------------------===// 870 // Test PolyForOp - parse list of region arguments. 871 //===----------------------------------------------------------------------===// 872 873 ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) { 874 SmallVector<OpAsmParser::UnresolvedOperand, 4> ivsInfo; 875 // Parse list of region arguments without a delimiter. 876 if (parser.parseRegionArgumentList(ivsInfo)) 877 return failure(); 878 879 // Parse the body region. 880 Region *body = result.addRegion(); 881 auto &builder = parser.getBuilder(); 882 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 883 return parser.parseRegion(*body, ivsInfo, argTypes); 884 } 885 886 void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); } 887 888 void PolyForOp::getAsmBlockArgumentNames(Region ®ion, 889 OpAsmSetValueNameFn setNameFn) { 890 auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names"); 891 if (!arrayAttr) 892 return; 893 auto args = getRegion().front().getArguments(); 894 auto e = std::min(arrayAttr.size(), args.size()); 895 for (unsigned i = 0; i < e; ++i) { 896 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 897 setNameFn(args[i], strAttr.getValue()); 898 } 899 } 900 901 //===----------------------------------------------------------------------===// 902 // Test removing op with inner ops. 903 //===----------------------------------------------------------------------===// 904 905 namespace { 906 struct TestRemoveOpWithInnerOps 907 : public OpRewritePattern<TestOpWithRegionPattern> { 908 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 909 910 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } 911 912 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 913 PatternRewriter &rewriter) const override { 914 rewriter.eraseOp(op); 915 return success(); 916 } 917 }; 918 } // namespace 919 920 void TestOpWithRegionPattern::getCanonicalizationPatterns( 921 RewritePatternSet &results, MLIRContext *context) { 922 results.add<TestRemoveOpWithInnerOps>(context); 923 } 924 925 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 926 return getOperand(); 927 } 928 929 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 930 return getValue(); 931 } 932 933 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 934 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 935 for (Value input : this->getOperands()) { 936 results.push_back(input); 937 } 938 return success(); 939 } 940 941 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 942 assert(operands.size() == 1); 943 if (operands.front()) { 944 (*this)->setAttr("attr", operands.front()); 945 return getResult(); 946 } 947 return {}; 948 } 949 950 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 951 return getOperand(); 952 } 953 954 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 955 MLIRContext *, Optional<Location> location, ValueRange operands, 956 DictionaryAttr attributes, RegionRange regions, 957 SmallVectorImpl<Type> &inferredReturnTypes) { 958 if (operands[0].getType() != operands[1].getType()) { 959 return emitOptionalError(location, "operand type mismatch ", 960 operands[0].getType(), " vs ", 961 operands[1].getType()); 962 } 963 inferredReturnTypes.assign({operands[0].getType()}); 964 return success(); 965 } 966 967 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 968 MLIRContext *context, Optional<Location> location, ValueShapeRange operands, 969 DictionaryAttr attributes, RegionRange regions, 970 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 971 // Create return type consisting of the last element of the first operand. 972 auto operandType = operands.front().getType(); 973 auto sval = operandType.dyn_cast<ShapedType>(); 974 if (!sval) { 975 return emitOptionalError(location, "only shaped type operands allowed"); 976 } 977 int64_t dim = 978 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 979 auto type = IntegerType::get(context, 17); 980 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 981 return success(); 982 } 983 984 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 985 OpBuilder &builder, ValueRange operands, 986 llvm::SmallVectorImpl<Value> &shapes) { 987 shapes = SmallVector<Value, 1>{ 988 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; 989 return success(); 990 } 991 992 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( 993 OpBuilder &builder, ValueRange operands, 994 llvm::SmallVectorImpl<Value> &shapes) { 995 Location loc = getLoc(); 996 shapes.reserve(operands.size()); 997 for (Value operand : llvm::reverse(operands)) { 998 auto rank = operand.getType().cast<RankedTensorType>().getRank(); 999 auto currShape = llvm::to_vector<4>( 1000 llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value { 1001 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 1002 })); 1003 shapes.push_back(builder.create<tensor::FromElementsOp>( 1004 getLoc(), RankedTensorType::get({rank}, builder.getIndexType()), 1005 currShape)); 1006 } 1007 return success(); 1008 } 1009 1010 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( 1011 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { 1012 Location loc = getLoc(); 1013 shapes.reserve(getNumOperands()); 1014 for (Value operand : llvm::reverse(getOperands())) { 1015 auto currShape = llvm::to_vector<4>(llvm::map_range( 1016 llvm::seq<int64_t>( 1017 0, operand.getType().cast<RankedTensorType>().getRank()), 1018 [&](int64_t dim) -> Value { 1019 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 1020 })); 1021 shapes.emplace_back(std::move(currShape)); 1022 } 1023 return success(); 1024 } 1025 1026 //===----------------------------------------------------------------------===// 1027 // Test SideEffect interfaces 1028 //===----------------------------------------------------------------------===// 1029 1030 namespace { 1031 /// A test resource for side effects. 1032 struct TestResource : public SideEffects::Resource::Base<TestResource> { 1033 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource) 1034 1035 StringRef getName() final { return "<Test>"; } 1036 }; 1037 } // namespace 1038 1039 static void testSideEffectOpGetEffect( 1040 Operation *op, 1041 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 1042 &effects) { 1043 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); 1044 if (!effectsAttr) 1045 return; 1046 1047 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 1048 } 1049 1050 void SideEffectOp::getEffects( 1051 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1052 // Check for an effects attribute on the op instance. 1053 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 1054 if (!effectsAttr) 1055 return; 1056 1057 // If there is one, it is an array of dictionary attributes that hold 1058 // information on the effects of this operation. 1059 for (Attribute element : effectsAttr) { 1060 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 1061 1062 // Get the specific memory effect. 1063 MemoryEffects::Effect *effect = 1064 StringSwitch<MemoryEffects::Effect *>( 1065 effectElement.get("effect").cast<StringAttr>().getValue()) 1066 .Case("allocate", MemoryEffects::Allocate::get()) 1067 .Case("free", MemoryEffects::Free::get()) 1068 .Case("read", MemoryEffects::Read::get()) 1069 .Case("write", MemoryEffects::Write::get()); 1070 1071 // Check for a non-default resource to use. 1072 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 1073 if (effectElement.get("test_resource")) 1074 resource = TestResource::get(); 1075 1076 // Check for a result to affect. 1077 if (effectElement.get("on_result")) 1078 effects.emplace_back(effect, getResult(), resource); 1079 else if (Attribute ref = effectElement.get("on_reference")) 1080 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 1081 else 1082 effects.emplace_back(effect, resource); 1083 } 1084 } 1085 1086 void SideEffectOp::getEffects( 1087 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 1088 testSideEffectOpGetEffect(getOperation(), effects); 1089 } 1090 1091 //===----------------------------------------------------------------------===// 1092 // StringAttrPrettyNameOp 1093 //===----------------------------------------------------------------------===// 1094 1095 // This op has fancy handling of its SSA result name. 1096 ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser, 1097 OperationState &result) { 1098 // Add the result types. 1099 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 1100 result.addTypes(parser.getBuilder().getIntegerType(32)); 1101 1102 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1103 return failure(); 1104 1105 // If the attribute dictionary contains no 'names' attribute, infer it from 1106 // the SSA name (if specified). 1107 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 1108 return attr.getName() == "names"; 1109 }); 1110 1111 // If there was no name specified, check to see if there was a useful name 1112 // specified in the asm file. 1113 if (hadNames || parser.getNumResults() == 0) 1114 return success(); 1115 1116 SmallVector<StringRef, 4> names; 1117 auto *context = result.getContext(); 1118 1119 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 1120 auto resultName = parser.getResultName(i); 1121 StringRef nameStr; 1122 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 1123 nameStr = resultName.first; 1124 1125 names.push_back(nameStr); 1126 } 1127 1128 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 1129 result.attributes.push_back({StringAttr::get(context, "names"), namesAttr}); 1130 return success(); 1131 } 1132 1133 void StringAttrPrettyNameOp::print(OpAsmPrinter &p) { 1134 // Note that we only need to print the "name" attribute if the asmprinter 1135 // result name disagrees with it. This can happen in strange cases, e.g. 1136 // when there are conflicts. 1137 bool namesDisagree = getNames().size() != getNumResults(); 1138 1139 SmallString<32> resultNameStr; 1140 for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) { 1141 resultNameStr.clear(); 1142 llvm::raw_svector_ostream tmpStream(resultNameStr); 1143 p.printOperand(getResult(i), tmpStream); 1144 1145 auto expectedName = getNames()[i].dyn_cast<StringAttr>(); 1146 if (!expectedName || 1147 tmpStream.str().drop_front() != expectedName.getValue()) { 1148 namesDisagree = true; 1149 } 1150 } 1151 1152 if (namesDisagree) 1153 p.printOptionalAttrDictWithKeyword((*this)->getAttrs()); 1154 else 1155 p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"}); 1156 } 1157 1158 // We set the SSA name in the asm syntax to the contents of the name 1159 // attribute. 1160 void StringAttrPrettyNameOp::getAsmResultNames( 1161 function_ref<void(Value, StringRef)> setNameFn) { 1162 1163 auto value = getNames(); 1164 for (size_t i = 0, e = value.size(); i != e; ++i) 1165 if (auto str = value[i].dyn_cast<StringAttr>()) 1166 if (!str.getValue().empty()) 1167 setNameFn(getResult(i), str.getValue()); 1168 } 1169 1170 //===----------------------------------------------------------------------===// 1171 // ResultTypeWithTraitOp 1172 //===----------------------------------------------------------------------===// 1173 1174 LogicalResult ResultTypeWithTraitOp::verify() { 1175 if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>()) 1176 return success(); 1177 return emitError("result type should have trait 'TestTypeTrait'"); 1178 } 1179 1180 //===----------------------------------------------------------------------===// 1181 // AttrWithTraitOp 1182 //===----------------------------------------------------------------------===// 1183 1184 LogicalResult AttrWithTraitOp::verify() { 1185 if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>()) 1186 return success(); 1187 return emitError("'attr' attribute should have trait 'TestAttrTrait'"); 1188 } 1189 1190 //===----------------------------------------------------------------------===// 1191 // RegionIfOp 1192 //===----------------------------------------------------------------------===// 1193 1194 void RegionIfOp::print(OpAsmPrinter &p) { 1195 p << " "; 1196 p.printOperands(getOperands()); 1197 p << ": " << getOperandTypes(); 1198 p.printArrowTypeList(getResultTypes()); 1199 p << " then "; 1200 p.printRegion(getThenRegion(), 1201 /*printEntryBlockArgs=*/true, 1202 /*printBlockTerminators=*/true); 1203 p << " else "; 1204 p.printRegion(getElseRegion(), 1205 /*printEntryBlockArgs=*/true, 1206 /*printBlockTerminators=*/true); 1207 p << " join "; 1208 p.printRegion(getJoinRegion(), 1209 /*printEntryBlockArgs=*/true, 1210 /*printBlockTerminators=*/true); 1211 } 1212 1213 ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) { 1214 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos; 1215 SmallVector<Type, 2> operandTypes; 1216 1217 result.regions.reserve(3); 1218 Region *thenRegion = result.addRegion(); 1219 Region *elseRegion = result.addRegion(); 1220 Region *joinRegion = result.addRegion(); 1221 1222 // Parse operand, type and arrow type lists. 1223 if (parser.parseOperandList(operandInfos) || 1224 parser.parseColonTypeList(operandTypes) || 1225 parser.parseArrowTypeList(result.types)) 1226 return failure(); 1227 1228 // Parse all attached regions. 1229 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 1230 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 1231 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 1232 return failure(); 1233 1234 return parser.resolveOperands(operandInfos, operandTypes, 1235 parser.getCurrentLocation(), result.operands); 1236 } 1237 1238 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 1239 assert(index < 2 && "invalid region index"); 1240 return getOperands(); 1241 } 1242 1243 void RegionIfOp::getSuccessorRegions( 1244 Optional<unsigned> index, ArrayRef<Attribute> operands, 1245 SmallVectorImpl<RegionSuccessor> ®ions) { 1246 // We always branch to the join region. 1247 if (index.hasValue()) { 1248 if (index.getValue() < 2) 1249 regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); 1250 else 1251 regions.push_back(RegionSuccessor(getResults())); 1252 return; 1253 } 1254 1255 // The then and else regions are the entry regions of this op. 1256 regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs())); 1257 regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs())); 1258 } 1259 1260 void RegionIfOp::getRegionInvocationBounds( 1261 ArrayRef<Attribute> operands, 1262 SmallVectorImpl<InvocationBounds> &invocationBounds) { 1263 // Each region is invoked at most once. 1264 invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1}); 1265 } 1266 1267 //===----------------------------------------------------------------------===// 1268 // AnyCondOp 1269 //===----------------------------------------------------------------------===// 1270 1271 void AnyCondOp::getSuccessorRegions(Optional<unsigned> index, 1272 ArrayRef<Attribute> operands, 1273 SmallVectorImpl<RegionSuccessor> ®ions) { 1274 // The parent op branches into the only region, and the region branches back 1275 // to the parent op. 1276 if (index) 1277 regions.emplace_back(&getRegion()); 1278 else 1279 regions.emplace_back(getResults()); 1280 } 1281 1282 void AnyCondOp::getRegionInvocationBounds( 1283 ArrayRef<Attribute> operands, 1284 SmallVectorImpl<InvocationBounds> &invocationBounds) { 1285 invocationBounds.emplace_back(1, 1); 1286 } 1287 1288 //===----------------------------------------------------------------------===// 1289 // SingleNoTerminatorCustomAsmOp 1290 //===----------------------------------------------------------------------===// 1291 1292 ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser, 1293 OperationState &state) { 1294 Region *body = state.addRegion(); 1295 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) 1296 return failure(); 1297 return success(); 1298 } 1299 1300 void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) { 1301 printer.printRegion( 1302 getRegion(), /*printEntryBlockArgs=*/false, 1303 // This op has a single block without terminators. But explicitly mark 1304 // as not printing block terminators for testing. 1305 /*printBlockTerminators=*/false); 1306 } 1307 1308 #include "TestOpEnums.cpp.inc" 1309 #include "TestOpInterfaces.cpp.inc" 1310 #include "TestOpStructs.cpp.inc" 1311 #include "TestTypeInterfaces.cpp.inc" 1312 1313 #define GET_OP_CLASSES 1314 #include "TestOps.cpp.inc" 1315