//===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "TestDialect.h"
#include "TestAttributes.h"
#include "TestInterfaces.h"
#include "TestTypes.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/StringSwitch.h"

// Include this before the using namespace lines below to
// test that we don't have namespace dependencies.
#include "TestOpsDialect.cpp.inc"

using namespace mlir;
using namespace test;

/// Registers the test dialect with the given registry.
void test::registerTestDialect(DialectRegistry &registry) {
  registry.insert<TestDialect>();
}

//===----------------------------------------------------------------------===//
// TestDialect Interfaces
//===----------------------------------------------------------------------===//

namespace {

/// Testing the correctness of some traits.
static_assert(
    llvm::is_detected<OpTrait::has_implicit_terminator_t,
                      SingleBlockImplicitTerminatorOp>::value,
    "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
static_assert(OpTrait::hasSingleBlockImplicitTerminator<
                  SingleBlockImplicitTerminatorOp>::value,
              "hasSingleBlockImplicitTerminator does not match "
              "SingleBlockImplicitTerminatorOp");

// Test support for interacting with the AsmPrinter.
struct TestOpAsmInterface : public OpAsmDialectInterface {
  using OpAsmDialectInterface::OpAsmDialectInterface;

  /// Returns an alias for a StringAttr whose value is one of the
  /// `alias_test:*` keys below (writing the alias name to `os`); every other
  /// attribute gets no alias.
  AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
    StringAttr strAttr = attr.dyn_cast<StringAttr>();
    if (!strAttr)
      return AliasResult::NoAlias;

    // Check the contents of the string attribute to see what the test alias
    // should be named.
    Optional<StringRef> aliasName =
        StringSwitch<Optional<StringRef>>(strAttr.getValue())
            .Case("alias_test:dot_in_name", StringRef("test.alias"))
            .Case("alias_test:trailing_digit", StringRef("test_alias0"))
            .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
            .Case("alias_test:sanitize_conflict_a",
                  StringRef("test_alias_conflict0"))
            .Case("alias_test:sanitize_conflict_b",
                  StringRef("test_alias_conflict0_"))
            .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
            .Default(llvm::None);
    if (!aliasName)
      return AliasResult::NoAlias;

    os << *aliasName;
    return AliasResult::FinalAlias;
  }

  /// Aliases a non-empty tuple whose elements are all SimpleAType as
  /// `test_tuple`, and an unsigned 8-bit TestIntegerType as `test_ui8`; all
  /// other types get no alias.
  AliasResult getAlias(Type type, raw_ostream &os) const final {
    if (auto tupleType = type.dyn_cast<TupleType>()) {
      if (tupleType.size() > 0 &&
          llvm::all_of(tupleType.getTypes(), [](Type elemType) {
            return elemType.isa<SimpleAType>();
          })) {
        os << "test_tuple";
        return AliasResult::FinalAlias;
      }
    }
    if (auto intType = type.dyn_cast<TestIntegerType>()) {
      if (intType.getSignedness() ==
              TestIntegerType::SignednessSemantics::Unsigned &&
          intType.getWidth() == 8) {
        os << "test_ui8";
        return AliasResult::FinalAlias;
      }
    }
    return AliasResult::NoAlias;
  }
};

/// Folding hooks for the test dialect.
struct TestDialectFoldInterface : public DialectFoldInterface {
  using DialectFoldInterface::DialectFoldInterface;

  /// Registered hook to check if the given region, which is attached to an
  /// operation that is *not* isolated from above, should be used when
  /// materializing constants.
  bool shouldMaterializeInto(Region *region) const final {
    // If this is a one region operation, then insert into it.
    return isa<OneRegionOp>(region->getParentOp());
  }
};

/// Inliner hooks for operations of the test dialect.
struct TestInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  bool isLegalToInline(Operation *call, Operation *callable,
                       bool wouldBeCloned) const final {
    // Don't allow inlining calls that are marked `noinline`.
    return !call->hasAttr("noinline");
  }
  bool isLegalToInline(Region *, Region *, bool,
                       BlockAndValueMapping &) const final {
    // Inlining into test dialect regions is legal.
    return true;
  }
  bool isLegalToInline(Operation *, Region *, bool,
                       BlockAndValueMapping &) const final {
    // Any test dialect operation may be inlined.
    return true;
  }

  bool shouldAnalyzeRecursively(Operation *op) const final {
    // Analyze recursively unless this is a functional region operation, which
    // forms a separate functional scope.
    return !isa<FunctionalRegionOp>(op);
  }

  //===--------------------------------------------------------------------===//
  // Transformation Hooks
  //===--------------------------------------------------------------------===//

  /// Handle the given inlined terminator by replacing it with a new operation
  /// as necessary.
  void handleTerminator(Operation *op,
                        ArrayRef<Value> valuesToRepl) const final {
    // Only handle "test.return" here.
    auto returnOp = dyn_cast<TestReturnOp>(op);
    if (!returnOp)
      return;

    // Replace the values directly with the return operands.
    assert(returnOp.getNumOperands() == valuesToRepl.size());
    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
      valuesToRepl[it.index()].replaceAllUsesWith(it.value());
  }

  /// Attempt to materialize a conversion for a type mismatch between a call
  /// from this dialect, and a callable region. This method should generate an
  /// operation that takes 'input' as the only operand, and produces a single
  /// result of 'resultType'. If a conversion can not be generated, nullptr
  /// should be returned.
  Operation *materializeCallConversion(OpBuilder &builder, Value input,
                                       Type resultType,
                                       Location conversionLoc) const final {
    // Only allow conversion for i16/i32 types.
    if (!(resultType.isSignlessInteger(16) ||
          resultType.isSignlessInteger(32)) ||
        !(input.getType().isSignlessInteger(16) ||
          input.getType().isSignlessInteger(32)))
      return nullptr;
    return builder.create<TestCastOp>(conversionLoc, resultType, input);
  }

  /// For calls that are ConversionCallOps, tags every operation in the
  /// inlined blocks with a unit `inlined_conversion` attribute; other calls
  /// are left untouched.
  void processInlinedCallBlocks(
      Operation *call,
      iterator_range<Region::iterator> inlinedBlocks) const final {
    if (!isa<ConversionCallOp>(call))
      return;

    // Set the `inlined_conversion` attribute on all ops in the inlined blocks,
    // including nested ones (Block::walk visits nested regions).
    for (Block &block : inlinedBlocks) {
      block.walk([&](Operation *op) {
        op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
      });
    }
  }
};

/// Exposes the test dialect's reduction patterns to the reducer.
struct TestReductionPatternInterface : public DialectReductionPatternInterface {
public:
  TestReductionPatternInterface(Dialect *dialect)
      : DialectReductionPatternInterface(dialect) {}

  void populateReductionPatterns(RewritePatternSet &patterns) const final {
    populateTestReductionPatterns(patterns);
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//

// Forward declaration. Being `static`, it must be defined in this translation
// unit (the definition is outside the visible portion).
static void testSideEffectOpGetEffect(
    Operation *op,
    SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);

// This is the implementation of a dialect fallback for `TestEffectOpInterface`.
// It is only expected to be dispatched for "test.unregistered_side_effect_op";
// classof asserts this in debug builds.
struct TestOpEffectInterfaceFallback
    : public TestEffectOpInterface::FallbackModel<
          TestOpEffectInterfaceFallback> {
  static bool classof(Operation *op) {
    bool isSupportedOp =
        op->getName().getStringRef() == "test.unregistered_side_effect_op";
    assert(isSupportedOp && "Unexpected dispatch");
    return isSupportedOp;
  }

  void
  getEffects(Operation *op,
             SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
                 &effects) const {
    testSideEffectOpGetEffect(op, effects);
  }
};

/// Registers attributes, types, all ODS-generated operations, and the dialect
/// interfaces above; also allows unknown operations in this dialect.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface, TestReductionPatternInterface>();
  allowUnknownOperations();

  // Instantiate our fallback op interface that we'll use on specific
  // unregistered op. The dialect owns this allocation; it is released in the
  // destructor below.
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
// NOTE(review): the static_cast suggests `fallbackEffectOpInterfaces` is
// stored type-erased (presumably `void *` in TestDialect.h) -- confirm there.
TestDialect::~TestDialect() {
  delete static_cast<TestOpEffectInterfaceFallback *>(
      fallbackEffectOpInterfaces);
}

/// Materializes a constant of the given type and value as a TestOpConstant.
Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  return builder.create<TestOpConstant>(loc, type, value);
}

/// Always infers a single i16 result, independent of the operands.
::mlir::LogicalResult FormatInferType2Op::inferReturnTypes(
    ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location,
    ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
    ::mlir::RegionRange regions,
    ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
  inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
  return ::mlir::success();
}

/// Returns the fallback TestEffectOpInterface model for
/// "test.unregistered_side_effect_op"; nullptr for any other op/interface.
void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
                                               OperationName opName) {
  if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
      typeID == TypeID::get<TestEffectOpInterface>())
    return fallbackEffectOpInterfaces;
  return nullptr;
}

/// Rejects the dialect attribute `test.invalid_attr` on any operation.
LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute namedAttr) {
  if (namedAttr.getName() == "test.invalid_attr")
    return op->emitError() << "invalid to use 'test.invalid_attr'";
  return success();
}

/// Rejects `test.invalid_attr` when attached to a region argument.
LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute namedAttr) {
  if (namedAttr.getName() == "test.invalid_attr")
    return op->emitError() << "invalid to use 'test.invalid_attr'";
  return success();
}

/// Rejects `test.invalid_attr` when attached to a region result.
LogicalResult
TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
                                         unsigned resultIndex,
                                         NamedAttribute namedAttr) {
  if (namedAttr.getName() == "test.invalid_attr")
    return op->emitError() << "invalid to use 'test.invalid_attr'";
  return success();
}

Optional<Dialect::ParseOpHook> 305 TestDialect::getParseOperationHook(StringRef opName) const { 306 if (opName == "test.dialect_custom_printer") { 307 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 308 return parser.parseKeyword("custom_format"); 309 }}; 310 } 311 if (opName == "test.dialect_custom_format_fallback") { 312 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 313 return parser.parseKeyword("custom_format_fallback"); 314 }}; 315 } 316 return None; 317 } 318 319 llvm::unique_function<void(Operation *, OpAsmPrinter &)> 320 TestDialect::getOperationPrinter(Operation *op) const { 321 StringRef opName = op->getName().getStringRef(); 322 if (opName == "test.dialect_custom_printer") { 323 return [](Operation *op, OpAsmPrinter &printer) { 324 printer.getStream() << " custom_format"; 325 }; 326 } 327 if (opName == "test.dialect_custom_format_fallback") { 328 return [](Operation *op, OpAsmPrinter &printer) { 329 printer.getStream() << " custom_format_fallback"; 330 }; 331 } 332 return {}; 333 } 334 335 //===----------------------------------------------------------------------===// 336 // TestBranchOp 337 //===----------------------------------------------------------------------===// 338 339 SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) { 340 assert(index == 0 && "invalid successor index"); 341 return SuccessorOperands(getTargetOperandsMutable()); 342 } 343 344 //===----------------------------------------------------------------------===// 345 // TestProducingBranchOp 346 //===----------------------------------------------------------------------===// 347 348 SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) { 349 assert(index <= 1 && "invalid successor index"); 350 if (index == 1) 351 return SuccessorOperands(getFirstOperandsMutable()); 352 return SuccessorOperands(getSecondOperandsMutable()); 353 } 354 355 
//===----------------------------------------------------------------------===// 356 // TestProducingBranchOp 357 //===----------------------------------------------------------------------===// 358 359 SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) { 360 assert(index <= 1 && "invalid successor index"); 361 if (index == 0) 362 return SuccessorOperands(0, getSuccessOperandsMutable()); 363 return SuccessorOperands(1, getErrorOperandsMutable()); 364 } 365 366 //===----------------------------------------------------------------------===// 367 // TestDialectCanonicalizerOp 368 //===----------------------------------------------------------------------===// 369 370 static LogicalResult 371 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, 372 PatternRewriter &rewriter) { 373 rewriter.replaceOpWithNewOp<arith::ConstantOp>( 374 op, rewriter.getI32IntegerAttr(42)); 375 return success(); 376 } 377 378 void TestDialect::getCanonicalizationPatterns( 379 RewritePatternSet &results) const { 380 results.add(&dialectCanonicalizationPattern); 381 } 382 383 //===----------------------------------------------------------------------===// 384 // TestCallOp 385 //===----------------------------------------------------------------------===// 386 387 LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 388 // Check that the callee attribute was specified. 
389 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee"); 390 if (!fnAttr) 391 return emitOpError("requires a 'callee' symbol reference attribute"); 392 if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr)) 393 return emitOpError() << "'" << fnAttr.getValue() 394 << "' does not reference a valid function"; 395 return success(); 396 } 397 398 //===----------------------------------------------------------------------===// 399 // TestFoldToCallOp 400 //===----------------------------------------------------------------------===// 401 402 namespace { 403 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 404 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 405 406 LogicalResult matchAndRewrite(FoldToCallOp op, 407 PatternRewriter &rewriter) const override { 408 rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), 409 op.getCalleeAttr(), ValueRange()); 410 return success(); 411 } 412 }; 413 } // namespace 414 415 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 416 MLIRContext *context) { 417 results.add<FoldToCallOpPattern>(context); 418 } 419 420 //===----------------------------------------------------------------------===// 421 // Test Format* operations 422 //===----------------------------------------------------------------------===// 423 424 //===----------------------------------------------------------------------===// 425 // Parsing 426 427 static ParseResult parseCustomOptionalOperand( 428 OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) { 429 if (succeeded(parser.parseOptionalLParen())) { 430 optOperand.emplace(); 431 if (parser.parseOperand(*optOperand) || parser.parseRParen()) 432 return failure(); 433 } 434 return success(); 435 } 436 437 static ParseResult parseCustomDirectiveOperands( 438 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, 439 Optional<OpAsmParser::UnresolvedOperand> &optOperand, 440 
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) { 441 if (parser.parseOperand(operand)) 442 return failure(); 443 if (succeeded(parser.parseOptionalComma())) { 444 optOperand.emplace(); 445 if (parser.parseOperand(*optOperand)) 446 return failure(); 447 } 448 if (parser.parseArrow() || parser.parseLParen() || 449 parser.parseOperandList(varOperands) || parser.parseRParen()) 450 return failure(); 451 return success(); 452 } 453 static ParseResult 454 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 455 Type &optOperandType, 456 SmallVectorImpl<Type> &varOperandTypes) { 457 if (parser.parseColon()) 458 return failure(); 459 460 if (parser.parseType(operandType)) 461 return failure(); 462 if (succeeded(parser.parseOptionalComma())) { 463 if (parser.parseType(optOperandType)) 464 return failure(); 465 } 466 if (parser.parseArrow() || parser.parseLParen() || 467 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 468 return failure(); 469 return success(); 470 } 471 static ParseResult 472 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 473 Type optOperandType, 474 const SmallVectorImpl<Type> &varOperandTypes) { 475 if (parser.parseKeyword("type_refs_capture")) 476 return failure(); 477 478 Type operandType2, optOperandType2; 479 SmallVector<Type, 1> varOperandTypes2; 480 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 481 varOperandTypes2)) 482 return failure(); 483 484 if (operandType != operandType2 || optOperandType != optOperandType2 || 485 varOperandTypes != varOperandTypes2) 486 return failure(); 487 488 return success(); 489 } 490 static ParseResult parseCustomDirectiveOperandsAndTypes( 491 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, 492 Optional<OpAsmParser::UnresolvedOperand> &optOperand, 493 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands, 494 Type &operandType, Type &optOperandType, 495 SmallVectorImpl<Type> &varOperandTypes) { 496 if 
(parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 497 parseCustomDirectiveResults(parser, operandType, optOperandType, 498 varOperandTypes)) 499 return failure(); 500 return success(); 501 } 502 static ParseResult parseCustomDirectiveRegions( 503 OpAsmParser &parser, Region ®ion, 504 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 505 if (parser.parseRegion(region)) 506 return failure(); 507 if (failed(parser.parseOptionalComma())) 508 return success(); 509 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 510 if (parser.parseRegion(*varRegion)) 511 return failure(); 512 varRegions.emplace_back(std::move(varRegion)); 513 return success(); 514 } 515 static ParseResult 516 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 517 SmallVectorImpl<Block *> &varSuccessors) { 518 if (parser.parseSuccessor(successor)) 519 return failure(); 520 if (failed(parser.parseOptionalComma())) 521 return success(); 522 Block *varSuccessor; 523 if (parser.parseSuccessor(varSuccessor)) 524 return failure(); 525 varSuccessors.append(2, varSuccessor); 526 return success(); 527 } 528 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 529 IntegerAttr &attr, 530 IntegerAttr &optAttr) { 531 if (parser.parseAttribute(attr)) 532 return failure(); 533 if (succeeded(parser.parseOptionalComma())) { 534 if (parser.parseAttribute(optAttr)) 535 return failure(); 536 } 537 return success(); 538 } 539 540 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 541 NamedAttrList &attrs) { 542 return parser.parseOptionalAttrDict(attrs); 543 } 544 static ParseResult parseCustomDirectiveOptionalOperandRef( 545 OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) { 546 int64_t operandCount = 0; 547 if (parser.parseInteger(operandCount)) 548 return failure(); 549 bool expectedOptionalOperand = operandCount == 0; 550 return success(expectedOptionalOperand != optOperand.hasValue()); 551 } 
552 553 //===----------------------------------------------------------------------===// 554 // Printing 555 556 static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *, 557 Value optOperand) { 558 if (optOperand) 559 printer << "(" << optOperand << ") "; 560 } 561 562 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 563 Value operand, Value optOperand, 564 OperandRange varOperands) { 565 printer << operand; 566 if (optOperand) 567 printer << ", " << optOperand; 568 printer << " -> (" << varOperands << ")"; 569 } 570 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 571 Type operandType, Type optOperandType, 572 TypeRange varOperandTypes) { 573 printer << " : " << operandType; 574 if (optOperandType) 575 printer << ", " << optOperandType; 576 printer << " -> (" << varOperandTypes << ")"; 577 } 578 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 579 Operation *op, Type operandType, 580 Type optOperandType, 581 TypeRange varOperandTypes) { 582 printer << " type_refs_capture "; 583 printCustomDirectiveResults(printer, op, operandType, optOperandType, 584 varOperandTypes); 585 } 586 static void printCustomDirectiveOperandsAndTypes( 587 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 588 OperandRange varOperands, Type operandType, Type optOperandType, 589 TypeRange varOperandTypes) { 590 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 591 printCustomDirectiveResults(printer, op, operandType, optOperandType, 592 varOperandTypes); 593 } 594 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 595 Region ®ion, 596 MutableArrayRef<Region> varRegions) { 597 printer.printRegion(region); 598 if (!varRegions.empty()) { 599 printer << ", "; 600 for (Region ®ion : varRegions) 601 printer.printRegion(region); 602 } 603 } 604 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 605 Block 
*successor, 606 SuccessorRange varSuccessors) { 607 printer << successor; 608 if (!varSuccessors.empty()) 609 printer << ", " << varSuccessors.front(); 610 } 611 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 612 Attribute attribute, 613 Attribute optAttribute) { 614 printer << attribute; 615 if (optAttribute) 616 printer << ", " << optAttribute; 617 } 618 619 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 620 DictionaryAttr attrs) { 621 printer.printOptionalAttrDict(attrs.getValue()); 622 } 623 624 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 625 Operation *op, 626 Value optOperand) { 627 printer << (optOperand ? "1" : "0"); 628 } 629 630 //===----------------------------------------------------------------------===// 631 // Test IsolatedRegionOp - parse passthrough region arguments. 632 //===----------------------------------------------------------------------===// 633 634 ParseResult IsolatedRegionOp::parse(OpAsmParser &parser, 635 OperationState &result) { 636 OpAsmParser::UnresolvedOperand argInfo; 637 Type argType = parser.getBuilder().getIndexType(); 638 639 // Parse the input operand. 640 if (parser.parseOperand(argInfo) || 641 parser.resolveOperand(argInfo, argType, result.operands)) 642 return failure(); 643 644 // Parse the body region, and reuse the operand info as the argument info. 
645 Region *body = result.addRegion(); 646 return parser.parseRegion(*body, argInfo, argType, /*argLocations=*/{}, 647 /*enableNameShadowing=*/true); 648 } 649 650 void IsolatedRegionOp::print(OpAsmPrinter &p) { 651 p << "test.isolated_region "; 652 p.printOperand(getOperand()); 653 p.shadowRegionArgs(getRegion(), getOperand()); 654 p << ' '; 655 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 656 } 657 658 //===----------------------------------------------------------------------===// 659 // Test SSACFGRegionOp 660 //===----------------------------------------------------------------------===// 661 662 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 663 return RegionKind::SSACFG; 664 } 665 666 //===----------------------------------------------------------------------===// 667 // Test GraphRegionOp 668 //===----------------------------------------------------------------------===// 669 670 ParseResult GraphRegionOp::parse(OpAsmParser &parser, OperationState &result) { 671 // Parse the body region, and reuse the operand info as the argument info. 672 Region *body = result.addRegion(); 673 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 674 } 675 676 void GraphRegionOp::print(OpAsmPrinter &p) { 677 p << "test.graph_region "; 678 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 679 } 680 681 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 682 return RegionKind::Graph; 683 } 684 685 //===----------------------------------------------------------------------===// 686 // Test AffineScopeOp 687 //===----------------------------------------------------------------------===// 688 689 ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) { 690 // Parse the body region, and reuse the operand info as the argument info. 
691 Region *body = result.addRegion(); 692 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 693 } 694 695 void AffineScopeOp::print(OpAsmPrinter &p) { 696 p << "test.affine_scope "; 697 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 698 } 699 700 //===----------------------------------------------------------------------===// 701 // Test parser. 702 //===----------------------------------------------------------------------===// 703 704 ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser, 705 OperationState &result) { 706 if (parser.parseOptionalColon()) 707 return success(); 708 uint64_t numResults; 709 if (parser.parseInteger(numResults)) 710 return failure(); 711 712 IndexType type = parser.getBuilder().getIndexType(); 713 for (unsigned i = 0; i < numResults; ++i) 714 result.addTypes(type); 715 return success(); 716 } 717 718 void ParseIntegerLiteralOp::print(OpAsmPrinter &p) { 719 if (unsigned numResults = getNumResults()) 720 p << " : " << numResults; 721 } 722 723 ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser, 724 OperationState &result) { 725 StringRef keyword; 726 if (parser.parseKeyword(&keyword)) 727 return failure(); 728 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 729 return success(); 730 } 731 732 void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); } 733 734 //===----------------------------------------------------------------------===// 735 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 

/// Parses `wraps <generic-op>`. The wrapped op is placed in a new single-block
/// region and followed by a TestReturnOp of its results. The wrapping op takes
/// its result types and its location from the wrapped op.
ParseResult WrappingRegionOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  if (parser.parseKeyword("wraps"))
    return failure();

  // Parse the wrapped op in a region
  Region &body = *result.addRegion();
  body.push_back(new Block);
  Block &block = body.back();
  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
  if (!wrappedOp)
    return failure();

  // Create a return terminator in the inner region, pass as operand to the
  // terminator the returned values from the wrapped operation.
  SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
  OpBuilder builder(parser.getContext());
  builder.setInsertionPointToEnd(&block);
  builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);

  // Get the results type for the wrapping op from the terminator operands.
  Operation &returnOp = body.back().back();
  result.types.append(returnOp.operand_type_begin(),
                      returnOp.operand_type_end());

  // Use the location of the wrapped op for the "test.wrapping_region" op.
  result.location = wrappedOp->getLoc();

  return success();
}

/// Prints ` wraps ` followed by the first op of the region in generic form.
void WrappingRegionOp::print(OpAsmPrinter &p) {
  p << " wraps ";
  p.printGenericOp(&getRegion().front().front());
}

//===----------------------------------------------------------------------===//
// Test PrettyPrintedRegionOp - exercising the following parser APIs
//   parseGenericOperationAfterOpName
//   parseCustomOperationName
//===----------------------------------------------------------------------===//

/// Parses either the pretty form
///   <operands> start <inner-op-name> end : <functional-type> [loc(...)]
/// or, when `start` is absent, the generic form after the op name.
ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser,
                                         OperationState &result) {

  SMLoc loc = parser.getCurrentLocation();
  Location currLocation = parser.getEncodedSourceLoc(loc);

  // Parse the operands.
  SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
  if (parser.parseOperandList(operands))
    return failure();

  // Check if we are parsing the pretty-printed version
  //  test.pretty_printed_region start <inner-op> end : <functional-type>
  // Else fallback to parsing the "non pretty-printed" version.
  if (!succeeded(parser.parseOptionalKeyword("start")))
    return parser.parseGenericOperationAfterOpName(
        result, llvm::makeArrayRef(operands));

  FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName();
  if (failed(parseOpNameInfo))
    return failure();

  StringAttr innerOpName = parseOpNameInfo->getIdentifier();

  FunctionType opFntype;
  Optional<Location> explicitLoc;
  if (parser.parseKeyword("end") || parser.parseColon() ||
      parser.parseType(opFntype) ||
      parser.parseOptionalLocationSpecifier(explicitLoc))
    return failure();

  // If location of the op is explicitly provided, then use it; Else use
  // the parser's current location.
  Location opLoc = explicitLoc.getValueOr(currLocation);

  // Derive the SSA-values for op's operands.
  if (parser.resolveOperands(operands, opFntype.getInputs(), loc,
                             result.operands))
    return failure();

  // Add a region for op.
  Region &region = *result.addRegion();

  // Create a basic-block inside op's region.
  Block &block = region.emplaceBlock();

  // Create and insert an "inner-op" operation in the block.
  // Just for testing purposes, we can assume that inner op is a binary op with
  // result and operand types all same as the test-op's first operand.
  // NOTE(review): getInput(0) is not guarded; a functional type with no inputs
  // (e.g. `() -> ()`) is not rejected before this access.
  Type innerOpType = opFntype.getInput(0);
  Value lhs = block.addArgument(innerOpType, opLoc);
  Value rhs = block.addArgument(innerOpType, opLoc);

  OpBuilder builder(parser.getBuilder().getContext());
  builder.setInsertionPointToStart(&block);

  Operation *innerOp =
      builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType);

  // Insert a return statement in the block returning the inner-op's result.
  builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());

  // Populate the op operation-state with result-type and location.
  result.addTypes(opFntype.getResults());
  result.location = innerOp->getLoc();

  return success();
}

/// Prints the pretty form when the first inner op is "special.op"; otherwise
/// prints the region in parentheses. Both end with ` : <functional-type>`.
void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printOperands(getOperands());

  Operation &innerOp = getRegion().front().front();
  // Assuming that region has a single non-terminator inner-op, if the inner-op
  // meets some criteria (which in this case is a simple one based on the name
  // of inner-op), then we can print the entire region in a succinct way.
  // Here we assume that the prototype of "special.op" can be trivially derived
  // while parsing it back.
  if (innerOp.getName().getStringRef().equals("special.op")) {
    p << " start special.op end";
  } else {
    p << " (";
    p.printRegion(getRegion());
    p << ")";
  }

  p << " : ";
  p.printFunctionalType(*this);
}

//===----------------------------------------------------------------------===//
// Test PolyForOp - parse list of region arguments.
//===----------------------------------------------------------------------===//

/// Parses a list of region arguments followed by the body region. Every
/// argument gets the index type.
ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 4> ivsInfo;
  // Parse list of region arguments without a delimiter.
  if (parser.parseRegionArgumentList(ivsInfo))
    return failure();

  // Parse the body region.
  Region *body = result.addRegion();
  auto &builder = parser.getBuilder();
  SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
  return parser.parseRegion(*body, ivsInfo, argTypes);
}

/// Always prints in generic form.
void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); }

/// Names entry-block arguments from the optional `arg_names` array attribute.
/// Non-string entries and entries beyond the argument count are ignored.
void PolyForOp::getAsmBlockArgumentNames(Region &region,
                                         OpAsmSetValueNameFn setNameFn) {
  auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
  if (!arrayAttr)
    return;
  auto args = getRegion().front().getArguments();
  auto e = std::min(arrayAttr.size(), args.size());
  for (unsigned i = 0; i < e; ++i) {
    if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
      setNameFn(args[i], strAttr.getValue());
  }
}

//===----------------------------------------------------------------------===//
// Test removing op with inner ops.
//===----------------------------------------------------------------------===//

namespace {
/// Unconditionally erases a TestOpWithRegionPattern, including its region.
struct TestRemoveOpWithInnerOps
    : public OpRewritePattern<TestOpWithRegionPattern> {
  using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;

  void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }

  LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
                                PatternRewriter &rewriter) const override {
    rewriter.eraseOp(op);
    return success();
  }
};
} // namespace

void TestOpWithRegionPattern::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<TestRemoveOpWithInnerOps>(context);
}

/// Folds to the op's operand.
OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
  return getOperand();
}

/// Folds to the op's constant value attribute.
OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
  return getValue();
}

/// Folds each result to the corresponding operand, pushed in operand order.
LogicalResult TestOpWithVariadicResultsAndFolder::fold(
    ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
  for (Value input : this->getOperands()) {
    results.push_back(input);
  }
  return success();
}

/// In-place fold: when the single operand is a known constant, stores it in
/// the `attr` attribute and returns the op's own result; otherwise no fold.
OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 1);
  if (operands.front()) {
    (*this)->setAttr("attr", operands.front());
    return getResult();
  }
  return {};
}

/// Folds to the op's operand.
OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
  return getOperand();
}

/// Infers a single result whose type equals the first operand's type. It
/// errors when the first two operand types differ.
/// NOTE(review): operands[0] and operands[1] are accessed without a size
/// check; presumably the op definition guarantees two operands -- confirm.
LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
    MLIRContext *, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType() != operands[1].getType()) {
    return emitOptionalError(location, "operand type mismatch ",
                             operands[0].getType(), " vs ",
                             operands[1].getType());
  }
  inferredReturnTypes.assign({operands[0].getType()});
  return success();
}

LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
    MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Create return type consisting of the last element of the first operand.
  auto operandType = operands.front().getType();
  auto sval = operandType.dyn_cast<ShapedType>();
  if (!sval) {
    return emitOptionalError(location, "only shaped type operands allowed");
  }
  int64_t dim =
      sval.hasRank() ?
sval.getShape().front() : ShapedType::kDynamicSize; 980 auto type = IntegerType::get(context, 17); 981 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 982 return success(); 983 } 984 985 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 986 OpBuilder &builder, ValueRange operands, 987 llvm::SmallVectorImpl<Value> &shapes) { 988 shapes = SmallVector<Value, 1>{ 989 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; 990 return success(); 991 } 992 993 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( 994 OpBuilder &builder, ValueRange operands, 995 llvm::SmallVectorImpl<Value> &shapes) { 996 Location loc = getLoc(); 997 shapes.reserve(operands.size()); 998 for (Value operand : llvm::reverse(operands)) { 999 auto rank = operand.getType().cast<RankedTensorType>().getRank(); 1000 auto currShape = llvm::to_vector<4>( 1001 llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value { 1002 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 1003 })); 1004 shapes.push_back(builder.create<tensor::FromElementsOp>( 1005 getLoc(), RankedTensorType::get({rank}, builder.getIndexType()), 1006 currShape)); 1007 } 1008 return success(); 1009 } 1010 1011 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( 1012 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { 1013 Location loc = getLoc(); 1014 shapes.reserve(getNumOperands()); 1015 for (Value operand : llvm::reverse(getOperands())) { 1016 auto currShape = llvm::to_vector<4>(llvm::map_range( 1017 llvm::seq<int64_t>( 1018 0, operand.getType().cast<RankedTensorType>().getRank()), 1019 [&](int64_t dim) -> Value { 1020 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 1021 })); 1022 shapes.emplace_back(std::move(currShape)); 1023 } 1024 return success(); 1025 } 1026 1027 //===----------------------------------------------------------------------===// 1028 // Test SideEffect interfaces 1029 
//===----------------------------------------------------------------------===// 1030 1031 namespace { 1032 /// A test resource for side effects. 1033 struct TestResource : public SideEffects::Resource::Base<TestResource> { 1034 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource) 1035 1036 StringRef getName() final { return "<Test>"; } 1037 }; 1038 } // namespace 1039 1040 static void testSideEffectOpGetEffect( 1041 Operation *op, 1042 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 1043 &effects) { 1044 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); 1045 if (!effectsAttr) 1046 return; 1047 1048 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 1049 } 1050 1051 void SideEffectOp::getEffects( 1052 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1053 // Check for an effects attribute on the op instance. 1054 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 1055 if (!effectsAttr) 1056 return; 1057 1058 // If there is one, it is an array of dictionary attributes that hold 1059 // information on the effects of this operation. 1060 for (Attribute element : effectsAttr) { 1061 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 1062 1063 // Get the specific memory effect. 1064 MemoryEffects::Effect *effect = 1065 StringSwitch<MemoryEffects::Effect *>( 1066 effectElement.get("effect").cast<StringAttr>().getValue()) 1067 .Case("allocate", MemoryEffects::Allocate::get()) 1068 .Case("free", MemoryEffects::Free::get()) 1069 .Case("read", MemoryEffects::Read::get()) 1070 .Case("write", MemoryEffects::Write::get()); 1071 1072 // Check for a non-default resource to use. 1073 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 1074 if (effectElement.get("test_resource")) 1075 resource = TestResource::get(); 1076 1077 // Check for a result to affect. 
1078 if (effectElement.get("on_result")) 1079 effects.emplace_back(effect, getResult(), resource); 1080 else if (Attribute ref = effectElement.get("on_reference")) 1081 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 1082 else 1083 effects.emplace_back(effect, resource); 1084 } 1085 } 1086 1087 void SideEffectOp::getEffects( 1088 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 1089 testSideEffectOpGetEffect(getOperation(), effects); 1090 } 1091 1092 //===----------------------------------------------------------------------===// 1093 // StringAttrPrettyNameOp 1094 //===----------------------------------------------------------------------===// 1095 1096 // This op has fancy handling of its SSA result name. 1097 ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser, 1098 OperationState &result) { 1099 // Add the result types. 1100 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 1101 result.addTypes(parser.getBuilder().getIntegerType(32)); 1102 1103 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1104 return failure(); 1105 1106 // If the attribute dictionary contains no 'names' attribute, infer it from 1107 // the SSA name (if specified). 1108 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 1109 return attr.getName() == "names"; 1110 }); 1111 1112 // If there was no name specified, check to see if there was a useful name 1113 // specified in the asm file. 
1114 if (hadNames || parser.getNumResults() == 0) 1115 return success(); 1116 1117 SmallVector<StringRef, 4> names; 1118 auto *context = result.getContext(); 1119 1120 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 1121 auto resultName = parser.getResultName(i); 1122 StringRef nameStr; 1123 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 1124 nameStr = resultName.first; 1125 1126 names.push_back(nameStr); 1127 } 1128 1129 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 1130 result.attributes.push_back({StringAttr::get(context, "names"), namesAttr}); 1131 return success(); 1132 } 1133 1134 void StringAttrPrettyNameOp::print(OpAsmPrinter &p) { 1135 // Note that we only need to print the "name" attribute if the asmprinter 1136 // result name disagrees with it. This can happen in strange cases, e.g. 1137 // when there are conflicts. 1138 bool namesDisagree = getNames().size() != getNumResults(); 1139 1140 SmallString<32> resultNameStr; 1141 for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) { 1142 resultNameStr.clear(); 1143 llvm::raw_svector_ostream tmpStream(resultNameStr); 1144 p.printOperand(getResult(i), tmpStream); 1145 1146 auto expectedName = getNames()[i].dyn_cast<StringAttr>(); 1147 if (!expectedName || 1148 tmpStream.str().drop_front() != expectedName.getValue()) { 1149 namesDisagree = true; 1150 } 1151 } 1152 1153 if (namesDisagree) 1154 p.printOptionalAttrDictWithKeyword((*this)->getAttrs()); 1155 else 1156 p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"}); 1157 } 1158 1159 // We set the SSA name in the asm syntax to the contents of the name 1160 // attribute. 
1161 void StringAttrPrettyNameOp::getAsmResultNames( 1162 function_ref<void(Value, StringRef)> setNameFn) { 1163 1164 auto value = getNames(); 1165 for (size_t i = 0, e = value.size(); i != e; ++i) 1166 if (auto str = value[i].dyn_cast<StringAttr>()) 1167 if (!str.getValue().empty()) 1168 setNameFn(getResult(i), str.getValue()); 1169 } 1170 1171 //===----------------------------------------------------------------------===// 1172 // ResultTypeWithTraitOp 1173 //===----------------------------------------------------------------------===// 1174 1175 LogicalResult ResultTypeWithTraitOp::verify() { 1176 if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>()) 1177 return success(); 1178 return emitError("result type should have trait 'TestTypeTrait'"); 1179 } 1180 1181 //===----------------------------------------------------------------------===// 1182 // AttrWithTraitOp 1183 //===----------------------------------------------------------------------===// 1184 1185 LogicalResult AttrWithTraitOp::verify() { 1186 if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>()) 1187 return success(); 1188 return emitError("'attr' attribute should have trait 'TestAttrTrait'"); 1189 } 1190 1191 //===----------------------------------------------------------------------===// 1192 // RegionIfOp 1193 //===----------------------------------------------------------------------===// 1194 1195 void RegionIfOp::print(OpAsmPrinter &p) { 1196 p << " "; 1197 p.printOperands(getOperands()); 1198 p << ": " << getOperandTypes(); 1199 p.printArrowTypeList(getResultTypes()); 1200 p << " then "; 1201 p.printRegion(getThenRegion(), 1202 /*printEntryBlockArgs=*/true, 1203 /*printBlockTerminators=*/true); 1204 p << " else "; 1205 p.printRegion(getElseRegion(), 1206 /*printEntryBlockArgs=*/true, 1207 /*printBlockTerminators=*/true); 1208 p << " join "; 1209 p.printRegion(getJoinRegion(), 1210 /*printEntryBlockArgs=*/true, 1211 /*printBlockTerminators=*/true); 1212 } 1213 1214 
ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) { 1215 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos; 1216 SmallVector<Type, 2> operandTypes; 1217 1218 result.regions.reserve(3); 1219 Region *thenRegion = result.addRegion(); 1220 Region *elseRegion = result.addRegion(); 1221 Region *joinRegion = result.addRegion(); 1222 1223 // Parse operand, type and arrow type lists. 1224 if (parser.parseOperandList(operandInfos) || 1225 parser.parseColonTypeList(operandTypes) || 1226 parser.parseArrowTypeList(result.types)) 1227 return failure(); 1228 1229 // Parse all attached regions. 1230 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 1231 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 1232 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 1233 return failure(); 1234 1235 return parser.resolveOperands(operandInfos, operandTypes, 1236 parser.getCurrentLocation(), result.operands); 1237 } 1238 1239 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 1240 assert(index < 2 && "invalid region index"); 1241 return getOperands(); 1242 } 1243 1244 void RegionIfOp::getSuccessorRegions( 1245 Optional<unsigned> index, ArrayRef<Attribute> operands, 1246 SmallVectorImpl<RegionSuccessor> ®ions) { 1247 // We always branch to the join region. 1248 if (index.hasValue()) { 1249 if (index.getValue() < 2) 1250 regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); 1251 else 1252 regions.push_back(RegionSuccessor(getResults())); 1253 return; 1254 } 1255 1256 // The then and else regions are the entry regions of this op. 
1257 regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs())); 1258 regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs())); 1259 } 1260 1261 void RegionIfOp::getRegionInvocationBounds( 1262 ArrayRef<Attribute> operands, 1263 SmallVectorImpl<InvocationBounds> &invocationBounds) { 1264 // Each region is invoked at most once. 1265 invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1}); 1266 } 1267 1268 //===----------------------------------------------------------------------===// 1269 // AnyCondOp 1270 //===----------------------------------------------------------------------===// 1271 1272 void AnyCondOp::getSuccessorRegions(Optional<unsigned> index, 1273 ArrayRef<Attribute> operands, 1274 SmallVectorImpl<RegionSuccessor> ®ions) { 1275 // The parent op branches into the only region, and the region branches back 1276 // to the parent op. 1277 if (index) 1278 regions.emplace_back(&getRegion()); 1279 else 1280 regions.emplace_back(getResults()); 1281 } 1282 1283 void AnyCondOp::getRegionInvocationBounds( 1284 ArrayRef<Attribute> operands, 1285 SmallVectorImpl<InvocationBounds> &invocationBounds) { 1286 invocationBounds.emplace_back(1, 1); 1287 } 1288 1289 //===----------------------------------------------------------------------===// 1290 // SingleNoTerminatorCustomAsmOp 1291 //===----------------------------------------------------------------------===// 1292 1293 ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser, 1294 OperationState &state) { 1295 Region *body = state.addRegion(); 1296 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) 1297 return failure(); 1298 return success(); 1299 } 1300 1301 void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) { 1302 printer.printRegion( 1303 getRegion(), /*printEntryBlockArgs=*/false, 1304 // This op has a single block without terminators. But explicitly mark 1305 // as not printing block terminators for testing. 
1306 /*printBlockTerminators=*/false); 1307 } 1308 1309 //===----------------------------------------------------------------------===// 1310 // TestVerifiersOp 1311 //===----------------------------------------------------------------------===// 1312 1313 LogicalResult TestVerifiersOp::verify() { 1314 if (!getRegion().hasOneBlock()) 1315 return emitOpError("`hasOneBlock` trait hasn't been verified"); 1316 1317 Operation *definingOp = getInput().getDefiningOp(); 1318 if (definingOp && failed(mlir::verify(definingOp))) 1319 return emitOpError("operand hasn't been verified"); 1320 1321 emitRemark("success run of verifier"); 1322 1323 return success(); 1324 } 1325 1326 LogicalResult TestVerifiersOp::verifyRegions() { 1327 if (!getRegion().hasOneBlock()) 1328 return emitOpError("`hasOneBlock` trait hasn't been verified"); 1329 1330 for (Block &block : getRegion()) 1331 for (Operation &op : block) 1332 if (failed(mlir::verify(&op))) 1333 return emitOpError("nested op hasn't been verified"); 1334 1335 emitRemark("success run of region verifier"); 1336 1337 return success(); 1338 } 1339 1340 #include "TestOpEnums.cpp.inc" 1341 #include "TestOpInterfaces.cpp.inc" 1342 #include "TestOpStructs.cpp.inc" 1343 #include "TestTypeInterfaces.cpp.inc" 1344 1345 #define GET_OP_CLASSES 1346 #include "TestOps.cpp.inc" 1347