1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestInterfaces.h" 12 #include "TestTypes.h" 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/DLTI/DLTI.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/DialectImplementation.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/IR/TypeUtilities.h" 21 #include "mlir/IR/Verifier.h" 22 #include "mlir/Reducer/ReductionPatternInterface.h" 23 #include "mlir/Transforms/FoldUtils.h" 24 #include "mlir/Transforms/InliningUtils.h" 25 #include "llvm/ADT/StringExtras.h" 26 #include "llvm/ADT/StringSwitch.h" 27 28 // Include this before the using namespace lines below to 29 // test that we don't have namespace dependencies. 30 #include "TestOpsDialect.cpp.inc" 31 32 using namespace mlir; 33 using namespace test; 34 35 void test::registerTestDialect(DialectRegistry ®istry) { 36 registry.insert<TestDialect>(); 37 } 38 39 //===----------------------------------------------------------------------===// 40 // TestDialect Interfaces 41 //===----------------------------------------------------------------------===// 42 43 namespace { 44 45 /// Testing the correctness of some traits. 46 static_assert( 47 llvm::is_detected<OpTrait::has_implicit_terminator_t, 48 SingleBlockImplicitTerminatorOp>::value, 49 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); 50 static_assert(OpTrait::hasSingleBlockImplicitTerminator< 51 SingleBlockImplicitTerminatorOp>::value, 52 "hasSingleBlockImplicitTerminator does not match " 53 "SingleBlockImplicitTerminatorOp"); 54 55 // Test support for interacting with the AsmPrinter. 56 struct TestOpAsmInterface : public OpAsmDialectInterface { 57 using OpAsmDialectInterface::OpAsmDialectInterface; 58 59 AliasResult getAlias(Attribute attr, raw_ostream &os) const final { 60 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 61 if (!strAttr) 62 return AliasResult::NoAlias; 63 64 // Check the contents of the string attribute to see what the test alias 65 // should be named. 66 Optional<StringRef> aliasName = 67 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 68 .Case("alias_test:dot_in_name", StringRef("test.alias")) 69 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 70 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 71 .Case("alias_test:sanitize_conflict_a", 72 StringRef("test_alias_conflict0")) 73 .Case("alias_test:sanitize_conflict_b", 74 StringRef("test_alias_conflict0_")) 75 .Case("alias_test:tensor_encoding", StringRef("test_encoding")) 76 .Default(llvm::None); 77 if (!aliasName) 78 return AliasResult::NoAlias; 79 80 os << *aliasName; 81 return AliasResult::FinalAlias; 82 } 83 84 AliasResult getAlias(Type type, raw_ostream &os) const final { 85 if (auto tupleType = type.dyn_cast<TupleType>()) { 86 if (tupleType.size() > 0 && 87 llvm::all_of(tupleType.getTypes(), [](Type elemType) { 88 return elemType.isa<SimpleAType>(); 89 })) { 90 os << "test_tuple"; 91 return AliasResult::FinalAlias; 92 } 93 } 94 if (auto intType = type.dyn_cast<TestIntegerType>()) { 95 if (intType.getSignedness() == 96 TestIntegerType::SignednessSemantics::Unsigned && 97 intType.getWidth() == 8) { 98 os << "test_ui8"; 99 return AliasResult::FinalAlias; 100 } 101 } 102 return AliasResult::NoAlias; 103 } 104 }; 105 106 struct TestDialectFoldInterface : public DialectFoldInterface { 107 using DialectFoldInterface::DialectFoldInterface; 108 109 /// Registered hook to check if the given region, which is attached to an 110 /// operation that is *not* isolated from above, should be used when 111 /// materializing constants. 112 bool shouldMaterializeInto(Region *region) const final { 113 // If this is a one region operation, then insert into it. 114 return isa<OneRegionOp>(region->getParentOp()); 115 } 116 }; 117 118 /// This class defines the interface for handling inlining with standard 119 /// operations. 120 struct TestInlinerInterface : public DialectInlinerInterface { 121 using DialectInlinerInterface::DialectInlinerInterface; 122 123 //===--------------------------------------------------------------------===// 124 // Analysis Hooks 125 //===--------------------------------------------------------------------===// 126 127 bool isLegalToInline(Operation *call, Operation *callable, 128 bool wouldBeCloned) const final { 129 // Don't allow inlining calls that are marked `noinline`. 130 return !call->hasAttr("noinline"); 131 } 132 bool isLegalToInline(Region *, Region *, bool, 133 BlockAndValueMapping &) const final { 134 // Inlining into test dialect regions is legal. 135 return true; 136 } 137 bool isLegalToInline(Operation *, Region *, bool, 138 BlockAndValueMapping &) const final { 139 return true; 140 } 141 142 bool shouldAnalyzeRecursively(Operation *op) const final { 143 // Analyze recursively if this is not a functional region operation, it 144 // froms a separate functional scope. 145 return !isa<FunctionalRegionOp>(op); 146 } 147 148 //===--------------------------------------------------------------------===// 149 // Transformation Hooks 150 //===--------------------------------------------------------------------===// 151 152 /// Handle the given inlined terminator by replacing it with a new operation 153 /// as necessary. 154 void handleTerminator(Operation *op, 155 ArrayRef<Value> valuesToRepl) const final { 156 // Only handle "test.return" here. 157 auto returnOp = dyn_cast<TestReturnOp>(op); 158 if (!returnOp) 159 return; 160 161 // Replace the values directly with the return operands. 162 assert(returnOp.getNumOperands() == valuesToRepl.size()); 163 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 164 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 165 } 166 167 /// Attempt to materialize a conversion for a type mismatch between a call 168 /// from this dialect, and a callable region. This method should generate an 169 /// operation that takes 'input' as the only operand, and produces a single 170 /// result of 'resultType'. If a conversion can not be generated, nullptr 171 /// should be returned. 172 Operation *materializeCallConversion(OpBuilder &builder, Value input, 173 Type resultType, 174 Location conversionLoc) const final { 175 // Only allow conversion for i16/i32 types. 176 if (!(resultType.isSignlessInteger(16) || 177 resultType.isSignlessInteger(32)) || 178 !(input.getType().isSignlessInteger(16) || 179 input.getType().isSignlessInteger(32))) 180 return nullptr; 181 return builder.create<TestCastOp>(conversionLoc, resultType, input); 182 } 183 184 void processInlinedCallBlocks( 185 Operation *call, 186 iterator_range<Region::iterator> inlinedBlocks) const final { 187 if (!isa<ConversionCallOp>(call)) 188 return; 189 190 // Set attributed on all ops in the inlined blocks. 191 for (Block &block : inlinedBlocks) { 192 block.walk([&](Operation *op) { 193 op->setAttr("inlined_conversion", UnitAttr::get(call->getContext())); 194 }); 195 } 196 } 197 }; 198 199 struct TestReductionPatternInterface : public DialectReductionPatternInterface { 200 public: 201 TestReductionPatternInterface(Dialect *dialect) 202 : DialectReductionPatternInterface(dialect) {} 203 204 void populateReductionPatterns(RewritePatternSet &patterns) const final { 205 populateTestReductionPatterns(patterns); 206 } 207 }; 208 209 } // namespace 210 211 //===----------------------------------------------------------------------===// 212 // TestDialect 213 //===----------------------------------------------------------------------===// 214 215 static void testSideEffectOpGetEffect( 216 Operation *op, 217 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects); 218 219 // This is the implementation of a dialect fallback for `TestEffectOpInterface`. 220 struct TestOpEffectInterfaceFallback 221 : public TestEffectOpInterface::FallbackModel< 222 TestOpEffectInterfaceFallback> { 223 static bool classof(Operation *op) { 224 bool isSupportedOp = 225 op->getName().getStringRef() == "test.unregistered_side_effect_op"; 226 assert(isSupportedOp && "Unexpected dispatch"); 227 return isSupportedOp; 228 } 229 230 void 231 getEffects(Operation *op, 232 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 233 &effects) const { 234 testSideEffectOpGetEffect(op, effects); 235 } 236 }; 237 238 void TestDialect::initialize() { 239 registerAttributes(); 240 registerTypes(); 241 addOperations< 242 #define GET_OP_LIST 243 #include "TestOps.cpp.inc" 244 >(); 245 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 246 TestInlinerInterface, TestReductionPatternInterface>(); 247 allowUnknownOperations(); 248 249 // Instantiate our fallback op interface that we'll use on specific 250 // unregistered op. 251 fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback; 252 } 253 TestDialect::~TestDialect() { 254 delete static_cast<TestOpEffectInterfaceFallback *>( 255 fallbackEffectOpInterfaces); 256 } 257 258 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 259 Type type, Location loc) { 260 return builder.create<TestOpConstant>(loc, type, value); 261 } 262 263 ::mlir::LogicalResult FormatInferType2Op::inferReturnTypes( 264 ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, 265 ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, 266 ::mlir::RegionRange regions, 267 ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { 268 inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)}); 269 return ::mlir::success(); 270 } 271 272 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, 273 OperationName opName) { 274 if (opName.getIdentifier() == "test.unregistered_side_effect_op" && 275 typeID == TypeID::get<TestEffectOpInterface>()) 276 return fallbackEffectOpInterfaces; 277 return nullptr; 278 } 279 280 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 281 NamedAttribute namedAttr) { 282 if (namedAttr.getName() == "test.invalid_attr") 283 return op->emitError() << "invalid to use 'test.invalid_attr'"; 284 return success(); 285 } 286 287 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 288 unsigned regionIndex, 289 unsigned argIndex, 290 NamedAttribute namedAttr) { 291 if (namedAttr.getName() == "test.invalid_attr") 292 return op->emitError() << "invalid to use 'test.invalid_attr'"; 293 return success(); 294 } 295 296 LogicalResult 297 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 298 unsigned resultIndex, 299 NamedAttribute namedAttr) { 300 if (namedAttr.getName() == "test.invalid_attr") 301 return op->emitError() << "invalid to use 'test.invalid_attr'"; 302 return success(); 303 } 304 305 Optional<Dialect::ParseOpHook> 306 TestDialect::getParseOperationHook(StringRef opName) const { 307 if (opName == "test.dialect_custom_printer") { 308 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 309 return parser.parseKeyword("custom_format"); 310 }}; 311 } 312 if (opName == "test.dialect_custom_format_fallback") { 313 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 314 return parser.parseKeyword("custom_format_fallback"); 315 }}; 316 } 317 return None; 318 } 319 320 llvm::unique_function<void(Operation *, OpAsmPrinter &)> 321 TestDialect::getOperationPrinter(Operation *op) const { 322 StringRef opName = op->getName().getStringRef(); 323 if (opName == "test.dialect_custom_printer") { 324 return [](Operation *op, OpAsmPrinter &printer) { 325 printer.getStream() << " custom_format"; 326 }; 327 } 328 if (opName == "test.dialect_custom_format_fallback") { 329 return [](Operation *op, OpAsmPrinter &printer) { 330 printer.getStream() << " custom_format_fallback"; 331 }; 332 } 333 return {}; 334 } 335 336 //===----------------------------------------------------------------------===// 337 // TestBranchOp 338 //===----------------------------------------------------------------------===// 339 340 SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) { 341 assert(index == 0 && "invalid successor index"); 342 return SuccessorOperands(getTargetOperandsMutable()); 343 } 344 345 //===----------------------------------------------------------------------===// 346 // TestProducingBranchOp 347 //===----------------------------------------------------------------------===// 348 349 SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) { 350 assert(index <= 1 && "invalid successor index"); 351 if (index == 1) 352 return SuccessorOperands(getFirstOperandsMutable()); 353 return SuccessorOperands(getSecondOperandsMutable()); 354 } 355 356 //===----------------------------------------------------------------------===// 357 // TestProducingBranchOp 358 //===----------------------------------------------------------------------===// 359 360 SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) { 361 assert(index <= 1 && "invalid successor index"); 362 if (index == 0) 363 return SuccessorOperands(0, getSuccessOperandsMutable()); 364 return SuccessorOperands(1, getErrorOperandsMutable()); 365 } 366 367 //===----------------------------------------------------------------------===// 368 // TestDialectCanonicalizerOp 369 //===----------------------------------------------------------------------===// 370 371 static LogicalResult 372 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, 373 PatternRewriter &rewriter) { 374 rewriter.replaceOpWithNewOp<arith::ConstantOp>( 375 op, rewriter.getI32IntegerAttr(42)); 376 return success(); 377 } 378 379 void TestDialect::getCanonicalizationPatterns( 380 RewritePatternSet &results) const { 381 results.add(&dialectCanonicalizationPattern); 382 } 383 384 //===----------------------------------------------------------------------===// 385 // TestCallOp 386 //===----------------------------------------------------------------------===// 387 388 LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 389 // Check that the callee attribute was specified. 390 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee"); 391 if (!fnAttr) 392 return emitOpError("requires a 'callee' symbol reference attribute"); 393 if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr)) 394 return emitOpError() << "'" << fnAttr.getValue() 395 << "' does not reference a valid function"; 396 return success(); 397 } 398 399 //===----------------------------------------------------------------------===// 400 // TestFoldToCallOp 401 //===----------------------------------------------------------------------===// 402 403 namespace { 404 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 405 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 406 407 LogicalResult matchAndRewrite(FoldToCallOp op, 408 PatternRewriter &rewriter) const override { 409 rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), 410 op.getCalleeAttr(), ValueRange()); 411 return success(); 412 } 413 }; 414 } // namespace 415 416 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 417 MLIRContext *context) { 418 results.add<FoldToCallOpPattern>(context); 419 } 420 421 //===----------------------------------------------------------------------===// 422 // Test Format* operations 423 //===----------------------------------------------------------------------===// 424 425 //===----------------------------------------------------------------------===// 426 // Parsing 427 428 static ParseResult parseCustomOptionalOperand( 429 OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) { 430 if (succeeded(parser.parseOptionalLParen())) { 431 optOperand.emplace(); 432 if (parser.parseOperand(*optOperand) || parser.parseRParen()) 433 return failure(); 434 } 435 return success(); 436 } 437 438 static ParseResult parseCustomDirectiveOperands( 439 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, 440 Optional<OpAsmParser::UnresolvedOperand> &optOperand, 441 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) { 442 if (parser.parseOperand(operand)) 443 return failure(); 444 if (succeeded(parser.parseOptionalComma())) { 445 optOperand.emplace(); 446 if (parser.parseOperand(*optOperand)) 447 return failure(); 448 } 449 if (parser.parseArrow() || parser.parseLParen() || 450 parser.parseOperandList(varOperands) || parser.parseRParen()) 451 return failure(); 452 return success(); 453 } 454 static ParseResult 455 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 456 Type &optOperandType, 457 SmallVectorImpl<Type> &varOperandTypes) { 458 if (parser.parseColon()) 459 return failure(); 460 461 if (parser.parseType(operandType)) 462 return failure(); 463 if (succeeded(parser.parseOptionalComma())) { 464 if (parser.parseType(optOperandType)) 465 return failure(); 466 } 467 if (parser.parseArrow() || parser.parseLParen() || 468 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 469 return failure(); 470 return success(); 471 } 472 static ParseResult 473 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 474 Type optOperandType, 475 const SmallVectorImpl<Type> &varOperandTypes) { 476 if (parser.parseKeyword("type_refs_capture")) 477 return failure(); 478 479 Type operandType2, optOperandType2; 480 SmallVector<Type, 1> varOperandTypes2; 481 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 482 varOperandTypes2)) 483 return failure(); 484 485 if (operandType != operandType2 || optOperandType != optOperandType2 || 486 varOperandTypes != varOperandTypes2) 487 return failure(); 488 489 return success(); 490 } 491 static ParseResult parseCustomDirectiveOperandsAndTypes( 492 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, 493 Optional<OpAsmParser::UnresolvedOperand> &optOperand, 494 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands, 495 Type &operandType, Type &optOperandType, 496 SmallVectorImpl<Type> &varOperandTypes) { 497 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 498 parseCustomDirectiveResults(parser, operandType, optOperandType, 499 varOperandTypes)) 500 return failure(); 501 return success(); 502 } 503 static ParseResult parseCustomDirectiveRegions( 504 OpAsmParser &parser, Region ®ion, 505 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 506 if (parser.parseRegion(region)) 507 return failure(); 508 if (failed(parser.parseOptionalComma())) 509 return success(); 510 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 511 if (parser.parseRegion(*varRegion)) 512 return failure(); 513 varRegions.emplace_back(std::move(varRegion)); 514 return success(); 515 } 516 static ParseResult 517 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 518 SmallVectorImpl<Block *> &varSuccessors) { 519 if (parser.parseSuccessor(successor)) 520 return failure(); 521 if (failed(parser.parseOptionalComma())) 522 return success(); 523 Block *varSuccessor; 524 if (parser.parseSuccessor(varSuccessor)) 525 return failure(); 526 varSuccessors.append(2, varSuccessor); 527 return success(); 528 } 529 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 530 IntegerAttr &attr, 531 IntegerAttr &optAttr) { 532 if (parser.parseAttribute(attr)) 533 return failure(); 534 if (succeeded(parser.parseOptionalComma())) { 535 if (parser.parseAttribute(optAttr)) 536 return failure(); 537 } 538 return success(); 539 } 540 541 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 542 NamedAttrList &attrs) { 543 return parser.parseOptionalAttrDict(attrs); 544 } 545 static ParseResult parseCustomDirectiveOptionalOperandRef( 546 OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) { 547 int64_t operandCount = 0; 548 if (parser.parseInteger(operandCount)) 549 return failure(); 550 bool expectedOptionalOperand = operandCount == 0; 551 return success(expectedOptionalOperand != optOperand.hasValue()); 552 } 553 554 //===----------------------------------------------------------------------===// 555 // Printing 556 557 static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *, 558 Value optOperand) { 559 if (optOperand) 560 printer << "(" << optOperand << ") "; 561 } 562 563 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 564 Value operand, Value optOperand, 565 OperandRange varOperands) { 566 printer << operand; 567 if (optOperand) 568 printer << ", " << optOperand; 569 printer << " -> (" << varOperands << ")"; 570 } 571 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 572 Type operandType, Type optOperandType, 573 TypeRange varOperandTypes) { 574 printer << " : " << operandType; 575 if (optOperandType) 576 printer << ", " << optOperandType; 577 printer << " -> (" << varOperandTypes << ")"; 578 } 579 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 580 Operation *op, Type operandType, 581 Type optOperandType, 582 TypeRange varOperandTypes) { 583 printer << " type_refs_capture "; 584 printCustomDirectiveResults(printer, op, operandType, optOperandType, 585 varOperandTypes); 586 } 587 static void printCustomDirectiveOperandsAndTypes( 588 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 589 OperandRange varOperands, Type operandType, Type optOperandType, 590 TypeRange varOperandTypes) { 591 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 592 printCustomDirectiveResults(printer, op, operandType, optOperandType, 593 varOperandTypes); 594 } 595 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 596 Region ®ion, 597 MutableArrayRef<Region> varRegions) { 598 printer.printRegion(region); 599 if (!varRegions.empty()) { 600 printer << ", "; 601 for (Region ®ion : varRegions) 602 printer.printRegion(region); 603 } 604 } 605 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 606 Block *successor, 607 SuccessorRange varSuccessors) { 608 printer << successor; 609 if (!varSuccessors.empty()) 610 printer << ", " << varSuccessors.front(); 611 } 612 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 613 Attribute attribute, 614 Attribute optAttribute) { 615 printer << attribute; 616 if (optAttribute) 617 printer << ", " << optAttribute; 618 } 619 620 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 621 DictionaryAttr attrs) { 622 printer.printOptionalAttrDict(attrs.getValue()); 623 } 624 625 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 626 Operation *op, 627 Value optOperand) { 628 printer << (optOperand ? "1" : "0"); 629 } 630 631 //===----------------------------------------------------------------------===// 632 // Test IsolatedRegionOp - parse passthrough region arguments. 633 //===----------------------------------------------------------------------===// 634 635 ParseResult IsolatedRegionOp::parse(OpAsmParser &parser, 636 OperationState &result) { 637 OpAsmParser::UnresolvedOperand argInfo; 638 Type argType = parser.getBuilder().getIndexType(); 639 640 // Parse the input operand. 641 if (parser.parseOperand(argInfo) || 642 parser.resolveOperand(argInfo, argType, result.operands)) 643 return failure(); 644 645 // Parse the body region, and reuse the operand info as the argument info. 646 Region *body = result.addRegion(); 647 return parser.parseRegion(*body, argInfo, argType, 648 /*enableNameShadowing=*/true); 649 } 650 651 void IsolatedRegionOp::print(OpAsmPrinter &p) { 652 p << "test.isolated_region "; 653 p.printOperand(getOperand()); 654 p.shadowRegionArgs(getRegion(), getOperand()); 655 p << ' '; 656 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 657 } 658 659 //===----------------------------------------------------------------------===// 660 // Test SSACFGRegionOp 661 //===----------------------------------------------------------------------===// 662 663 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 664 return RegionKind::SSACFG; 665 } 666 667 //===----------------------------------------------------------------------===// 668 // Test GraphRegionOp 669 //===----------------------------------------------------------------------===// 670 671 ParseResult GraphRegionOp::parse(OpAsmParser &parser, OperationState &result) { 672 // Parse the body region, and reuse the operand info as the argument info. 673 Region *body = result.addRegion(); 674 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 675 } 676 677 void GraphRegionOp::print(OpAsmPrinter &p) { 678 p << "test.graph_region "; 679 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 680 } 681 682 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 683 return RegionKind::Graph; 684 } 685 686 //===----------------------------------------------------------------------===// 687 // Test AffineScopeOp 688 //===----------------------------------------------------------------------===// 689 690 ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) { 691 // Parse the body region, and reuse the operand info as the argument info. 692 Region *body = result.addRegion(); 693 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 694 } 695 696 void AffineScopeOp::print(OpAsmPrinter &p) { 697 p << "test.affine_scope "; 698 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 699 } 700 701 //===----------------------------------------------------------------------===// 702 // Test parser. 703 //===----------------------------------------------------------------------===// 704 705 ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser, 706 OperationState &result) { 707 if (parser.parseOptionalColon()) 708 return success(); 709 uint64_t numResults; 710 if (parser.parseInteger(numResults)) 711 return failure(); 712 713 IndexType type = parser.getBuilder().getIndexType(); 714 for (unsigned i = 0; i < numResults; ++i) 715 result.addTypes(type); 716 return success(); 717 } 718 719 void ParseIntegerLiteralOp::print(OpAsmPrinter &p) { 720 if (unsigned numResults = getNumResults()) 721 p << " : " << numResults; 722 } 723 724 ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser, 725 OperationState &result) { 726 StringRef keyword; 727 if (parser.parseKeyword(&keyword)) 728 return failure(); 729 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 730 return success(); 731 } 732 733 void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); } 734 735 //===----------------------------------------------------------------------===// 736 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 737 738 ParseResult WrappingRegionOp::parse(OpAsmParser &parser, 739 OperationState &result) { 740 if (parser.parseKeyword("wraps")) 741 return failure(); 742 743 // Parse the wrapped op in a region 744 Region &body = *result.addRegion(); 745 body.push_back(new Block); 746 Block &block = body.back(); 747 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin()); 748 if (!wrappedOp) 749 return failure(); 750 751 // Create a return terminator in the inner region, pass as operand to the 752 // terminator the returned values from the wrapped operation. 753 SmallVector<Value, 8> returnOperands(wrappedOp->getResults()); 754 OpBuilder builder(parser.getContext()); 755 builder.setInsertionPointToEnd(&block); 756 builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands); 757 758 // Get the results type for the wrapping op from the terminator operands. 759 Operation &returnOp = body.back().back(); 760 result.types.append(returnOp.operand_type_begin(), 761 returnOp.operand_type_end()); 762 763 // Use the location of the wrapped op for the "test.wrapping_region" op. 764 result.location = wrappedOp->getLoc(); 765 766 return success(); 767 } 768 769 void WrappingRegionOp::print(OpAsmPrinter &p) { 770 p << " wraps "; 771 p.printGenericOp(&getRegion().front().front()); 772 } 773 774 //===----------------------------------------------------------------------===// 775 // Test PrettyPrintedRegionOp - exercising the following parser APIs 776 // parseGenericOperationAfterOpName 777 // parseCustomOperationName 778 //===----------------------------------------------------------------------===// 779 780 ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser, 781 OperationState &result) { 782 783 SMLoc loc = parser.getCurrentLocation(); 784 Location currLocation = parser.getEncodedSourceLoc(loc); 785 786 // Parse the operands. 787 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands; 788 if (parser.parseOperandList(operands)) 789 return failure(); 790 791 // Check if we are parsing the pretty-printed version 792 // test.pretty_printed_region start <inner-op> end : <functional-type> 793 // Else fallback to parsing the "non pretty-printed" version. 794 if (!succeeded(parser.parseOptionalKeyword("start"))) 795 return parser.parseGenericOperationAfterOpName( 796 result, llvm::makeArrayRef(operands)); 797 798 FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName(); 799 if (failed(parseOpNameInfo)) 800 return failure(); 801 802 StringAttr innerOpName = parseOpNameInfo->getIdentifier(); 803 804 FunctionType opFntype; 805 Optional<Location> explicitLoc; 806 if (parser.parseKeyword("end") || parser.parseColon() || 807 parser.parseType(opFntype) || 808 parser.parseOptionalLocationSpecifier(explicitLoc)) 809 return failure(); 810 811 // If location of the op is explicitly provided, then use it; Else use 812 // the parser's current location. 813 Location opLoc = explicitLoc.getValueOr(currLocation); 814 815 // Derive the SSA-values for op's operands. 816 if (parser.resolveOperands(operands, opFntype.getInputs(), loc, 817 result.operands)) 818 return failure(); 819 820 // Add a region for op. 821 Region ®ion = *result.addRegion(); 822 823 // Create a basic-block inside op's region. 824 Block &block = region.emplaceBlock(); 825 826 // Create and insert an "inner-op" operation in the block. 827 // Just for testing purposes, we can assume that inner op is a binary op with 828 // result and operand types all same as the test-op's first operand. 829 Type innerOpType = opFntype.getInput(0); 830 Value lhs = block.addArgument(innerOpType, opLoc); 831 Value rhs = block.addArgument(innerOpType, opLoc); 832 833 OpBuilder builder(parser.getBuilder().getContext()); 834 builder.setInsertionPointToStart(&block); 835 836 Operation *innerOp = 837 builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType); 838 839 // Insert a return statement in the block returning the inner-op's result. 840 builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults()); 841 842 // Populate the op operation-state with result-type and location. 843 result.addTypes(opFntype.getResults()); 844 result.location = innerOp->getLoc(); 845 846 return success(); 847 } 848 849 void PrettyPrintedRegionOp::print(OpAsmPrinter &p) { 850 p << ' '; 851 p.printOperands(getOperands()); 852 853 Operation &innerOp = getRegion().front().front(); 854 // Assuming that region has a single non-terminator inner-op, if the inner-op 855 // meets some criteria (which in this case is a simple one based on the name 856 // of inner-op), then we can print the entire region in a succinct way. 857 // Here we assume that the prototype of "special.op" can be trivially derived 858 // while parsing it back. 859 if (innerOp.getName().getStringRef().equals("special.op")) { 860 p << " start special.op end"; 861 } else { 862 p << " ("; 863 p.printRegion(getRegion()); 864 p << ")"; 865 } 866 867 p << " : "; 868 p.printFunctionalType(*this); 869 } 870 871 //===----------------------------------------------------------------------===// 872 // Test PolyForOp - parse list of region arguments. 873 //===----------------------------------------------------------------------===// 874 875 ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) { 876 SmallVector<OpAsmParser::UnresolvedOperand, 4> ivsInfo; 877 // Parse list of region arguments without a delimiter. 878 if (parser.parseRegionArgumentList(ivsInfo)) 879 return failure(); 880 881 // Parse the body region. 882 Region *body = result.addRegion(); 883 auto &builder = parser.getBuilder(); 884 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 885 return parser.parseRegion(*body, ivsInfo, argTypes); 886 } 887 888 void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); } 889 890 void PolyForOp::getAsmBlockArgumentNames(Region ®ion, 891 OpAsmSetValueNameFn setNameFn) { 892 auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names"); 893 if (!arrayAttr) 894 return; 895 auto args = getRegion().front().getArguments(); 896 auto e = std::min(arrayAttr.size(), args.size()); 897 for (unsigned i = 0; i < e; ++i) { 898 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 899 setNameFn(args[i], strAttr.getValue()); 900 } 901 } 902 903 //===----------------------------------------------------------------------===// 904 // Test removing op with inner ops. 905 //===----------------------------------------------------------------------===// 906 907 namespace { 908 struct TestRemoveOpWithInnerOps 909 : public OpRewritePattern<TestOpWithRegionPattern> { 910 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 911 912 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } 913 914 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 915 PatternRewriter &rewriter) const override { 916 rewriter.eraseOp(op); 917 return success(); 918 } 919 }; 920 } // namespace 921 922 void TestOpWithRegionPattern::getCanonicalizationPatterns( 923 RewritePatternSet &results, MLIRContext *context) { 924 results.add<TestRemoveOpWithInnerOps>(context); 925 } 926 927 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 928 return getOperand(); 929 } 930 931 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 932 return getValue(); 933 } 934 935 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 936 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 937 for (Value input : this->getOperands()) { 938 results.push_back(input); 939 } 940 return success(); 941 } 942 943 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 944 assert(operands.size() == 1); 945 if (operands.front()) { 946 (*this)->setAttr("attr", operands.front()); 947 return getResult(); 948 } 949 return {}; 950 } 951 952 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 953 return getOperand(); 954 } 955 956 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 957 MLIRContext *, Optional<Location> location, ValueRange operands, 958 DictionaryAttr attributes, RegionRange regions, 959 SmallVectorImpl<Type> &inferredReturnTypes) { 960 if (operands[0].getType() != operands[1].getType()) { 961 return emitOptionalError(location, "operand type mismatch ", 962 operands[0].getType(), " vs ", 963 operands[1].getType()); 964 } 965 inferredReturnTypes.assign({operands[0].getType()}); 966 return success(); 967 } 968 969 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 970 MLIRContext *context, Optional<Location> location, ValueShapeRange operands, 971 DictionaryAttr attributes, RegionRange regions, 972 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 973 // Create return type consisting of the last element of the first operand. 974 auto operandType = operands.front().getType(); 975 auto sval = operandType.dyn_cast<ShapedType>(); 976 if (!sval) { 977 return emitOptionalError(location, "only shaped type operands allowed"); 978 } 979 int64_t dim = 980 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 981 auto type = IntegerType::get(context, 17); 982 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 983 return success(); 984 } 985 986 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 987 OpBuilder &builder, ValueRange operands, 988 llvm::SmallVectorImpl<Value> &shapes) { 989 shapes = SmallVector<Value, 1>{ 990 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; 991 return success(); 992 } 993 994 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( 995 OpBuilder &builder, ValueRange operands, 996 llvm::SmallVectorImpl<Value> &shapes) { 997 Location loc = getLoc(); 998 shapes.reserve(operands.size()); 999 for (Value operand : llvm::reverse(operands)) { 1000 auto rank = operand.getType().cast<RankedTensorType>().getRank(); 1001 auto currShape = llvm::to_vector<4>( 1002 llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value { 1003 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 1004 })); 1005 shapes.push_back(builder.create<tensor::FromElementsOp>( 1006 getLoc(), RankedTensorType::get({rank}, builder.getIndexType()), 1007 currShape)); 1008 } 1009 return success(); 1010 } 1011 1012 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( 1013 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { 1014 Location loc = getLoc(); 1015 shapes.reserve(getNumOperands()); 1016 for (Value operand : llvm::reverse(getOperands())) { 1017 auto currShape = llvm::to_vector<4>(llvm::map_range( 1018 llvm::seq<int64_t>( 1019 0, operand.getType().cast<RankedTensorType>().getRank()), 1020 [&](int64_t dim) -> Value { 1021 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 1022 })); 1023 shapes.emplace_back(std::move(currShape)); 1024 } 1025 return success(); 1026 } 1027 1028 //===----------------------------------------------------------------------===// 1029 // Test SideEffect interfaces 1030 //===----------------------------------------------------------------------===// 1031 1032 namespace { 1033 /// A test resource for side effects. 1034 struct TestResource : public SideEffects::Resource::Base<TestResource> { 1035 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource) 1036 1037 StringRef getName() final { return "<Test>"; } 1038 }; 1039 } // namespace 1040 1041 static void testSideEffectOpGetEffect( 1042 Operation *op, 1043 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 1044 &effects) { 1045 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); 1046 if (!effectsAttr) 1047 return; 1048 1049 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 1050 } 1051 1052 void SideEffectOp::getEffects( 1053 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1054 // Check for an effects attribute on the op instance. 1055 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 1056 if (!effectsAttr) 1057 return; 1058 1059 // If there is one, it is an array of dictionary attributes that hold 1060 // information on the effects of this operation. 1061 for (Attribute element : effectsAttr) { 1062 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 1063 1064 // Get the specific memory effect. 1065 MemoryEffects::Effect *effect = 1066 StringSwitch<MemoryEffects::Effect *>( 1067 effectElement.get("effect").cast<StringAttr>().getValue()) 1068 .Case("allocate", MemoryEffects::Allocate::get()) 1069 .Case("free", MemoryEffects::Free::get()) 1070 .Case("read", MemoryEffects::Read::get()) 1071 .Case("write", MemoryEffects::Write::get()); 1072 1073 // Check for a non-default resource to use. 1074 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 1075 if (effectElement.get("test_resource")) 1076 resource = TestResource::get(); 1077 1078 // Check for a result to affect. 1079 if (effectElement.get("on_result")) 1080 effects.emplace_back(effect, getResult(), resource); 1081 else if (Attribute ref = effectElement.get("on_reference")) 1082 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 1083 else 1084 effects.emplace_back(effect, resource); 1085 } 1086 } 1087 1088 void SideEffectOp::getEffects( 1089 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 1090 testSideEffectOpGetEffect(getOperation(), effects); 1091 } 1092 1093 //===----------------------------------------------------------------------===// 1094 // StringAttrPrettyNameOp 1095 //===----------------------------------------------------------------------===// 1096 1097 // This op has fancy handling of its SSA result name. 1098 ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser, 1099 OperationState &result) { 1100 // Add the result types. 1101 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 1102 result.addTypes(parser.getBuilder().getIntegerType(32)); 1103 1104 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1105 return failure(); 1106 1107 // If the attribute dictionary contains no 'names' attribute, infer it from 1108 // the SSA name (if specified). 1109 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 1110 return attr.getName() == "names"; 1111 }); 1112 1113 // If there was no name specified, check to see if there was a useful name 1114 // specified in the asm file. 1115 if (hadNames || parser.getNumResults() == 0) 1116 return success(); 1117 1118 SmallVector<StringRef, 4> names; 1119 auto *context = result.getContext(); 1120 1121 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 1122 auto resultName = parser.getResultName(i); 1123 StringRef nameStr; 1124 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 1125 nameStr = resultName.first; 1126 1127 names.push_back(nameStr); 1128 } 1129 1130 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 1131 result.attributes.push_back({StringAttr::get(context, "names"), namesAttr}); 1132 return success(); 1133 } 1134 1135 void StringAttrPrettyNameOp::print(OpAsmPrinter &p) { 1136 // Note that we only need to print the "name" attribute if the asmprinter 1137 // result name disagrees with it. This can happen in strange cases, e.g. 1138 // when there are conflicts. 1139 bool namesDisagree = getNames().size() != getNumResults(); 1140 1141 SmallString<32> resultNameStr; 1142 for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) { 1143 resultNameStr.clear(); 1144 llvm::raw_svector_ostream tmpStream(resultNameStr); 1145 p.printOperand(getResult(i), tmpStream); 1146 1147 auto expectedName = getNames()[i].dyn_cast<StringAttr>(); 1148 if (!expectedName || 1149 tmpStream.str().drop_front() != expectedName.getValue()) { 1150 namesDisagree = true; 1151 } 1152 } 1153 1154 if (namesDisagree) 1155 p.printOptionalAttrDictWithKeyword((*this)->getAttrs()); 1156 else 1157 p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"}); 1158 } 1159 1160 // We set the SSA name in the asm syntax to the contents of the name 1161 // attribute. 1162 void StringAttrPrettyNameOp::getAsmResultNames( 1163 function_ref<void(Value, StringRef)> setNameFn) { 1164 1165 auto value = getNames(); 1166 for (size_t i = 0, e = value.size(); i != e; ++i) 1167 if (auto str = value[i].dyn_cast<StringAttr>()) 1168 if (!str.getValue().empty()) 1169 setNameFn(getResult(i), str.getValue()); 1170 } 1171 1172 void CustomResultsNameOp::getAsmResultNames( 1173 function_ref<void(Value, StringRef)> setNameFn) { 1174 ArrayAttr value = getNames(); 1175 for (size_t i = 0, e = value.size(); i != e; ++i) 1176 if (auto str = value[i].dyn_cast<StringAttr>()) 1177 if (!str.getValue().empty()) 1178 setNameFn(getResult(i), str.getValue()); 1179 } 1180 1181 //===----------------------------------------------------------------------===// 1182 // ResultTypeWithTraitOp 1183 //===----------------------------------------------------------------------===// 1184 1185 LogicalResult ResultTypeWithTraitOp::verify() { 1186 if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>()) 1187 return success(); 1188 return emitError("result type should have trait 'TestTypeTrait'"); 1189 } 1190 1191 //===----------------------------------------------------------------------===// 1192 // AttrWithTraitOp 1193 //===----------------------------------------------------------------------===// 1194 1195 LogicalResult AttrWithTraitOp::verify() { 1196 if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>()) 1197 return success(); 1198 return emitError("'attr' attribute should have trait 'TestAttrTrait'"); 1199 } 1200 1201 //===----------------------------------------------------------------------===// 1202 // RegionIfOp 1203 //===----------------------------------------------------------------------===// 1204 1205 void RegionIfOp::print(OpAsmPrinter &p) { 1206 p << " "; 1207 p.printOperands(getOperands()); 1208 p << ": " << getOperandTypes(); 1209 p.printArrowTypeList(getResultTypes()); 1210 p << " then "; 1211 p.printRegion(getThenRegion(), 1212 /*printEntryBlockArgs=*/true, 1213 /*printBlockTerminators=*/true); 1214 p << " else "; 1215 p.printRegion(getElseRegion(), 1216 /*printEntryBlockArgs=*/true, 1217 /*printBlockTerminators=*/true); 1218 p << " join "; 1219 p.printRegion(getJoinRegion(), 1220 /*printEntryBlockArgs=*/true, 1221 /*printBlockTerminators=*/true); 1222 } 1223 1224 ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) { 1225 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos; 1226 SmallVector<Type, 2> operandTypes; 1227 1228 result.regions.reserve(3); 1229 Region *thenRegion = result.addRegion(); 1230 Region *elseRegion = result.addRegion(); 1231 Region *joinRegion = result.addRegion(); 1232 1233 // Parse operand, type and arrow type lists. 1234 if (parser.parseOperandList(operandInfos) || 1235 parser.parseColonTypeList(operandTypes) || 1236 parser.parseArrowTypeList(result.types)) 1237 return failure(); 1238 1239 // Parse all attached regions. 1240 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 1241 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 1242 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 1243 return failure(); 1244 1245 return parser.resolveOperands(operandInfos, operandTypes, 1246 parser.getCurrentLocation(), result.operands); 1247 } 1248 1249 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 1250 assert(index < 2 && "invalid region index"); 1251 return getOperands(); 1252 } 1253 1254 void RegionIfOp::getSuccessorRegions( 1255 Optional<unsigned> index, ArrayRef<Attribute> operands, 1256 SmallVectorImpl<RegionSuccessor> ®ions) { 1257 // We always branch to the join region. 1258 if (index.hasValue()) { 1259 if (index.getValue() < 2) 1260 regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); 1261 else 1262 regions.push_back(RegionSuccessor(getResults())); 1263 return; 1264 } 1265 1266 // The then and else regions are the entry regions of this op. 1267 regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs())); 1268 regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs())); 1269 } 1270 1271 void RegionIfOp::getRegionInvocationBounds( 1272 ArrayRef<Attribute> operands, 1273 SmallVectorImpl<InvocationBounds> &invocationBounds) { 1274 // Each region is invoked at most once. 1275 invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1}); 1276 } 1277 1278 //===----------------------------------------------------------------------===// 1279 // AnyCondOp 1280 //===----------------------------------------------------------------------===// 1281 1282 void AnyCondOp::getSuccessorRegions(Optional<unsigned> index, 1283 ArrayRef<Attribute> operands, 1284 SmallVectorImpl<RegionSuccessor> ®ions) { 1285 // The parent op branches into the only region, and the region branches back 1286 // to the parent op. 1287 if (index) 1288 regions.emplace_back(&getRegion()); 1289 else 1290 regions.emplace_back(getResults()); 1291 } 1292 1293 void AnyCondOp::getRegionInvocationBounds( 1294 ArrayRef<Attribute> operands, 1295 SmallVectorImpl<InvocationBounds> &invocationBounds) { 1296 invocationBounds.emplace_back(1, 1); 1297 } 1298 1299 //===----------------------------------------------------------------------===// 1300 // SingleNoTerminatorCustomAsmOp 1301 //===----------------------------------------------------------------------===// 1302 1303 ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser, 1304 OperationState &state) { 1305 Region *body = state.addRegion(); 1306 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) 1307 return failure(); 1308 return success(); 1309 } 1310 1311 void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) { 1312 printer.printRegion( 1313 getRegion(), /*printEntryBlockArgs=*/false, 1314 // This op has a single block without terminators. But explicitly mark 1315 // as not printing block terminators for testing. 1316 /*printBlockTerminators=*/false); 1317 } 1318 1319 //===----------------------------------------------------------------------===// 1320 // TestVerifiersOp 1321 //===----------------------------------------------------------------------===// 1322 1323 LogicalResult TestVerifiersOp::verify() { 1324 if (!getRegion().hasOneBlock()) 1325 return emitOpError("`hasOneBlock` trait hasn't been verified"); 1326 1327 Operation *definingOp = getInput().getDefiningOp(); 1328 if (definingOp && failed(mlir::verify(definingOp))) 1329 return emitOpError("operand hasn't been verified"); 1330 1331 emitRemark("success run of verifier"); 1332 1333 return success(); 1334 } 1335 1336 LogicalResult TestVerifiersOp::verifyRegions() { 1337 if (!getRegion().hasOneBlock()) 1338 return emitOpError("`hasOneBlock` trait hasn't been verified"); 1339 1340 for (Block &block : getRegion()) 1341 for (Operation &op : block) 1342 if (failed(mlir::verify(&op))) 1343 return emitOpError("nested op hasn't been verified"); 1344 1345 emitRemark("success run of region verifier"); 1346 1347 return success(); 1348 } 1349 1350 #include "TestOpEnums.cpp.inc" 1351 #include "TestOpInterfaces.cpp.inc" 1352 #include "TestOpStructs.cpp.inc" 1353 #include "TestTypeInterfaces.cpp.inc" 1354 1355 #define GET_OP_CLASSES 1356 #include "TestOps.cpp.inc" 1357