1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestInterfaces.h" 12 #include "TestTypes.h" 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/DLTI/DLTI.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/DialectImplementation.h" 19 #include "mlir/IR/ExtensibleDialect.h" 20 #include "mlir/IR/PatternMatch.h" 21 #include "mlir/IR/TypeUtilities.h" 22 #include "mlir/IR/Verifier.h" 23 #include "mlir/Reducer/ReductionPatternInterface.h" 24 #include "mlir/Transforms/FoldUtils.h" 25 #include "mlir/Transforms/InliningUtils.h" 26 #include "llvm/ADT/StringExtras.h" 27 #include "llvm/ADT/StringSwitch.h" 28 29 // Include this before the using namespace lines below to 30 // test that we don't have namespace dependencies. 31 #include "TestOpsDialect.cpp.inc" 32 33 using namespace mlir; 34 using namespace test; 35 36 void test::registerTestDialect(DialectRegistry ®istry) { 37 registry.insert<TestDialect>(); 38 } 39 40 //===----------------------------------------------------------------------===// 41 // TestDialect Interfaces 42 //===----------------------------------------------------------------------===// 43 44 namespace { 45 46 /// Testing the correctness of some traits. 47 static_assert( 48 llvm::is_detected<OpTrait::has_implicit_terminator_t, 49 SingleBlockImplicitTerminatorOp>::value, 50 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); 51 static_assert(OpTrait::hasSingleBlockImplicitTerminator< 52 SingleBlockImplicitTerminatorOp>::value, 53 "hasSingleBlockImplicitTerminator does not match " 54 "SingleBlockImplicitTerminatorOp"); 55 56 // Test support for interacting with the AsmPrinter. 57 struct TestOpAsmInterface : public OpAsmDialectInterface { 58 using OpAsmDialectInterface::OpAsmDialectInterface; 59 60 AliasResult getAlias(Attribute attr, raw_ostream &os) const final { 61 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 62 if (!strAttr) 63 return AliasResult::NoAlias; 64 65 // Check the contents of the string attribute to see what the test alias 66 // should be named. 67 Optional<StringRef> aliasName = 68 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 69 .Case("alias_test:dot_in_name", StringRef("test.alias")) 70 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 71 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 72 .Case("alias_test:sanitize_conflict_a", 73 StringRef("test_alias_conflict0")) 74 .Case("alias_test:sanitize_conflict_b", 75 StringRef("test_alias_conflict0_")) 76 .Case("alias_test:tensor_encoding", StringRef("test_encoding")) 77 .Default(llvm::None); 78 if (!aliasName) 79 return AliasResult::NoAlias; 80 81 os << *aliasName; 82 return AliasResult::FinalAlias; 83 } 84 85 AliasResult getAlias(Type type, raw_ostream &os) const final { 86 if (auto tupleType = type.dyn_cast<TupleType>()) { 87 if (tupleType.size() > 0 && 88 llvm::all_of(tupleType.getTypes(), [](Type elemType) { 89 return elemType.isa<SimpleAType>(); 90 })) { 91 os << "test_tuple"; 92 return AliasResult::FinalAlias; 93 } 94 } 95 if (auto intType = type.dyn_cast<TestIntegerType>()) { 96 if (intType.getSignedness() == 97 TestIntegerType::SignednessSemantics::Unsigned && 98 intType.getWidth() == 8) { 99 os << "test_ui8"; 100 return AliasResult::FinalAlias; 101 } 102 } 103 return AliasResult::NoAlias; 104 } 105 }; 106 107 struct TestDialectFoldInterface : public DialectFoldInterface { 108 using DialectFoldInterface::DialectFoldInterface; 109 110 /// Registered hook to check if the given region, which is attached to an 111 /// operation that is *not* isolated from above, should be used when 112 /// materializing constants. 113 bool shouldMaterializeInto(Region *region) const final { 114 // If this is a one region operation, then insert into it. 115 return isa<OneRegionOp>(region->getParentOp()); 116 } 117 }; 118 119 /// This class defines the interface for handling inlining with standard 120 /// operations. 121 struct TestInlinerInterface : public DialectInlinerInterface { 122 using DialectInlinerInterface::DialectInlinerInterface; 123 124 //===--------------------------------------------------------------------===// 125 // Analysis Hooks 126 //===--------------------------------------------------------------------===// 127 128 bool isLegalToInline(Operation *call, Operation *callable, 129 bool wouldBeCloned) const final { 130 // Don't allow inlining calls that are marked `noinline`. 131 return !call->hasAttr("noinline"); 132 } 133 bool isLegalToInline(Region *, Region *, bool, 134 BlockAndValueMapping &) const final { 135 // Inlining into test dialect regions is legal. 136 return true; 137 } 138 bool isLegalToInline(Operation *, Region *, bool, 139 BlockAndValueMapping &) const final { 140 return true; 141 } 142 143 bool shouldAnalyzeRecursively(Operation *op) const final { 144 // Analyze recursively if this is not a functional region operation, it 145 // froms a separate functional scope. 146 return !isa<FunctionalRegionOp>(op); 147 } 148 149 //===--------------------------------------------------------------------===// 150 // Transformation Hooks 151 //===--------------------------------------------------------------------===// 152 153 /// Handle the given inlined terminator by replacing it with a new operation 154 /// as necessary. 155 void handleTerminator(Operation *op, 156 ArrayRef<Value> valuesToRepl) const final { 157 // Only handle "test.return" here. 158 auto returnOp = dyn_cast<TestReturnOp>(op); 159 if (!returnOp) 160 return; 161 162 // Replace the values directly with the return operands. 163 assert(returnOp.getNumOperands() == valuesToRepl.size()); 164 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 165 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 166 } 167 168 /// Attempt to materialize a conversion for a type mismatch between a call 169 /// from this dialect, and a callable region. This method should generate an 170 /// operation that takes 'input' as the only operand, and produces a single 171 /// result of 'resultType'. If a conversion can not be generated, nullptr 172 /// should be returned. 173 Operation *materializeCallConversion(OpBuilder &builder, Value input, 174 Type resultType, 175 Location conversionLoc) const final { 176 // Only allow conversion for i16/i32 types. 177 if (!(resultType.isSignlessInteger(16) || 178 resultType.isSignlessInteger(32)) || 179 !(input.getType().isSignlessInteger(16) || 180 input.getType().isSignlessInteger(32))) 181 return nullptr; 182 return builder.create<TestCastOp>(conversionLoc, resultType, input); 183 } 184 185 void processInlinedCallBlocks( 186 Operation *call, 187 iterator_range<Region::iterator> inlinedBlocks) const final { 188 if (!isa<ConversionCallOp>(call)) 189 return; 190 191 // Set attributed on all ops in the inlined blocks. 192 for (Block &block : inlinedBlocks) { 193 block.walk([&](Operation *op) { 194 op->setAttr("inlined_conversion", UnitAttr::get(call->getContext())); 195 }); 196 } 197 } 198 }; 199 200 struct TestReductionPatternInterface : public DialectReductionPatternInterface { 201 public: 202 TestReductionPatternInterface(Dialect *dialect) 203 : DialectReductionPatternInterface(dialect) {} 204 205 void populateReductionPatterns(RewritePatternSet &patterns) const final { 206 populateTestReductionPatterns(patterns); 207 } 208 }; 209 210 } // namespace 211 212 //===----------------------------------------------------------------------===// 213 // Dynamic operations 214 //===----------------------------------------------------------------------===// 215 216 std::unique_ptr<DynamicOpDefinition> getDynamicGenericOp(TestDialect *dialect) { 217 return DynamicOpDefinition::get( 218 "dynamic_generic", dialect, [](Operation *op) { return success(); }, 219 [](Operation *op) { return success(); }); 220 } 221 222 std::unique_ptr<DynamicOpDefinition> 223 getDynamicOneOperandTwoResultsOp(TestDialect *dialect) { 224 return DynamicOpDefinition::get( 225 "dynamic_one_operand_two_results", dialect, 226 [](Operation *op) { 227 if (op->getNumOperands() != 1) { 228 op->emitOpError() 229 << "expected 1 operand, but had " << op->getNumOperands(); 230 return failure(); 231 } 232 if (op->getNumResults() != 2) { 233 op->emitOpError() 234 << "expected 2 results, but had " << op->getNumResults(); 235 return failure(); 236 } 237 return success(); 238 }, 239 [](Operation *op) { return success(); }); 240 } 241 242 std::unique_ptr<DynamicOpDefinition> 243 getDynamicCustomParserPrinterOp(TestDialect *dialect) { 244 auto verifier = [](Operation *op) { 245 if (op->getNumOperands() == 0 && op->getNumResults() == 0) 246 return success(); 247 op->emitError() << "operation should have no operands and no results"; 248 return failure(); 249 }; 250 auto regionVerifier = [](Operation *op) { return success(); }; 251 252 auto parser = [](OpAsmParser &parser, OperationState &state) { 253 return parser.parseKeyword("custom_keyword"); 254 }; 255 256 auto printer = [](Operation *op, OpAsmPrinter &printer, llvm::StringRef) { 257 printer << op->getName() << " custom_keyword"; 258 }; 259 260 return DynamicOpDefinition::get("dynamic_custom_parser_printer", dialect, 261 verifier, regionVerifier, parser, printer); 262 } 263 264 //===----------------------------------------------------------------------===// 265 // TestDialect 266 //===----------------------------------------------------------------------===// 267 268 static void testSideEffectOpGetEffect( 269 Operation *op, 270 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects); 271 272 // This is the implementation of a dialect fallback for `TestEffectOpInterface`. 273 struct TestOpEffectInterfaceFallback 274 : public TestEffectOpInterface::FallbackModel< 275 TestOpEffectInterfaceFallback> { 276 static bool classof(Operation *op) { 277 bool isSupportedOp = 278 op->getName().getStringRef() == "test.unregistered_side_effect_op"; 279 assert(isSupportedOp && "Unexpected dispatch"); 280 return isSupportedOp; 281 } 282 283 void 284 getEffects(Operation *op, 285 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 286 &effects) const { 287 testSideEffectOpGetEffect(op, effects); 288 } 289 }; 290 291 void TestDialect::initialize() { 292 registerAttributes(); 293 registerTypes(); 294 addOperations< 295 #define GET_OP_LIST 296 #include "TestOps.cpp.inc" 297 >(); 298 registerDynamicOp(getDynamicGenericOp(this)); 299 registerDynamicOp(getDynamicOneOperandTwoResultsOp(this)); 300 registerDynamicOp(getDynamicCustomParserPrinterOp(this)); 301 302 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 303 TestInlinerInterface, TestReductionPatternInterface>(); 304 allowUnknownOperations(); 305 306 // Instantiate our fallback op interface that we'll use on specific 307 // unregistered op. 308 fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback; 309 } 310 TestDialect::~TestDialect() { 311 delete static_cast<TestOpEffectInterfaceFallback *>( 312 fallbackEffectOpInterfaces); 313 } 314 315 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 316 Type type, Location loc) { 317 return builder.create<TestOpConstant>(loc, type, value); 318 } 319 320 ::mlir::LogicalResult FormatInferType2Op::inferReturnTypes( 321 ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, 322 ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, 323 ::mlir::RegionRange regions, 324 ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { 325 inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)}); 326 return ::mlir::success(); 327 } 328 329 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, 330 OperationName opName) { 331 if (opName.getIdentifier() == "test.unregistered_side_effect_op" && 332 typeID == TypeID::get<TestEffectOpInterface>()) 333 return fallbackEffectOpInterfaces; 334 return nullptr; 335 } 336 337 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 338 NamedAttribute namedAttr) { 339 if (namedAttr.getName() == "test.invalid_attr") 340 return op->emitError() << "invalid to use 'test.invalid_attr'"; 341 return success(); 342 } 343 344 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 345 unsigned regionIndex, 346 unsigned argIndex, 347 NamedAttribute namedAttr) { 348 if (namedAttr.getName() == "test.invalid_attr") 349 return op->emitError() << "invalid to use 'test.invalid_attr'"; 350 return success(); 351 } 352 353 LogicalResult 354 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 355 unsigned resultIndex, 356 NamedAttribute namedAttr) { 357 if (namedAttr.getName() == "test.invalid_attr") 358 return op->emitError() << "invalid to use 'test.invalid_attr'"; 359 return success(); 360 } 361 362 Optional<Dialect::ParseOpHook> 363 TestDialect::getParseOperationHook(StringRef opName) const { 364 if (opName == "test.dialect_custom_printer") { 365 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 366 return parser.parseKeyword("custom_format"); 367 }}; 368 } 369 if (opName == "test.dialect_custom_format_fallback") { 370 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 371 return parser.parseKeyword("custom_format_fallback"); 372 }}; 373 } 374 return None; 375 } 376 377 llvm::unique_function<void(Operation *, OpAsmPrinter &)> 378 TestDialect::getOperationPrinter(Operation *op) const { 379 StringRef opName = op->getName().getStringRef(); 380 if (opName == "test.dialect_custom_printer") { 381 return [](Operation *op, OpAsmPrinter &printer) { 382 printer.getStream() << " custom_format"; 383 }; 384 } 385 if (opName == "test.dialect_custom_format_fallback") { 386 return [](Operation *op, OpAsmPrinter &printer) { 387 printer.getStream() << " custom_format_fallback"; 388 }; 389 } 390 return {}; 391 } 392 393 //===----------------------------------------------------------------------===// 394 // TestBranchOp 395 //===----------------------------------------------------------------------===// 396 397 SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) { 398 assert(index == 0 && "invalid successor index"); 399 return SuccessorOperands(getTargetOperandsMutable()); 400 } 401 402 //===----------------------------------------------------------------------===// 403 // TestProducingBranchOp 404 //===----------------------------------------------------------------------===// 405 406 SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) { 407 assert(index <= 1 && "invalid successor index"); 408 if (index == 1) 409 return SuccessorOperands(getFirstOperandsMutable()); 410 return SuccessorOperands(getSecondOperandsMutable()); 411 } 412 413 //===----------------------------------------------------------------------===// 414 // TestProducingBranchOp 415 //===----------------------------------------------------------------------===// 416 417 SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) { 418 assert(index <= 1 && "invalid successor index"); 419 if (index == 0) 420 return SuccessorOperands(0, getSuccessOperandsMutable()); 421 return SuccessorOperands(1, getErrorOperandsMutable()); 422 } 423 424 //===----------------------------------------------------------------------===// 425 // TestDialectCanonicalizerOp 426 //===----------------------------------------------------------------------===// 427 428 static LogicalResult 429 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, 430 PatternRewriter &rewriter) { 431 rewriter.replaceOpWithNewOp<arith::ConstantOp>( 432 op, rewriter.getI32IntegerAttr(42)); 433 return success(); 434 } 435 436 void TestDialect::getCanonicalizationPatterns( 437 RewritePatternSet &results) const { 438 results.add(&dialectCanonicalizationPattern); 439 } 440 441 //===----------------------------------------------------------------------===// 442 // TestCallOp 443 //===----------------------------------------------------------------------===// 444 445 LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 446 // Check that the callee attribute was specified. 447 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee"); 448 if (!fnAttr) 449 return emitOpError("requires a 'callee' symbol reference attribute"); 450 if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr)) 451 return emitOpError() << "'" << fnAttr.getValue() 452 << "' does not reference a valid function"; 453 return success(); 454 } 455 456 //===----------------------------------------------------------------------===// 457 // TestFoldToCallOp 458 //===----------------------------------------------------------------------===// 459 460 namespace { 461 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 462 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 463 464 LogicalResult matchAndRewrite(FoldToCallOp op, 465 PatternRewriter &rewriter) const override { 466 rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), 467 op.getCalleeAttr(), ValueRange()); 468 return success(); 469 } 470 }; 471 } // namespace 472 473 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 474 MLIRContext *context) { 475 results.add<FoldToCallOpPattern>(context); 476 } 477 478 //===----------------------------------------------------------------------===// 479 // Test Format* operations 480 //===----------------------------------------------------------------------===// 481 482 //===----------------------------------------------------------------------===// 483 // Parsing 484 485 static ParseResult parseCustomOptionalOperand( 486 OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) { 487 if (succeeded(parser.parseOptionalLParen())) { 488 optOperand.emplace(); 489 if (parser.parseOperand(*optOperand) || parser.parseRParen()) 490 return failure(); 491 } 492 return success(); 493 } 494 495 static ParseResult parseCustomDirectiveOperands( 496 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, 497 Optional<OpAsmParser::UnresolvedOperand> &optOperand, 498 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) { 499 if (parser.parseOperand(operand)) 500 return failure(); 501 if (succeeded(parser.parseOptionalComma())) { 502 optOperand.emplace(); 503 if (parser.parseOperand(*optOperand)) 504 return failure(); 505 } 506 if (parser.parseArrow() || parser.parseLParen() || 507 parser.parseOperandList(varOperands) || parser.parseRParen()) 508 return failure(); 509 return success(); 510 } 511 static ParseResult 512 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 513 Type &optOperandType, 514 SmallVectorImpl<Type> &varOperandTypes) { 515 if (parser.parseColon()) 516 return failure(); 517 518 if (parser.parseType(operandType)) 519 return failure(); 520 if (succeeded(parser.parseOptionalComma())) { 521 if (parser.parseType(optOperandType)) 522 return failure(); 523 } 524 if (parser.parseArrow() || parser.parseLParen() || 525 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 526 return failure(); 527 return success(); 528 } 529 static ParseResult 530 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 531 Type optOperandType, 532 const SmallVectorImpl<Type> &varOperandTypes) { 533 if (parser.parseKeyword("type_refs_capture")) 534 return failure(); 535 536 Type operandType2, optOperandType2; 537 SmallVector<Type, 1> varOperandTypes2; 538 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 539 varOperandTypes2)) 540 return failure(); 541 542 if (operandType != operandType2 || optOperandType != optOperandType2 || 543 varOperandTypes != varOperandTypes2) 544 return failure(); 545 546 return success(); 547 } 548 static ParseResult parseCustomDirectiveOperandsAndTypes( 549 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, 550 Optional<OpAsmParser::UnresolvedOperand> &optOperand, 551 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands, 552 Type &operandType, Type &optOperandType, 553 SmallVectorImpl<Type> &varOperandTypes) { 554 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 555 parseCustomDirectiveResults(parser, operandType, optOperandType, 556 varOperandTypes)) 557 return failure(); 558 return success(); 559 } 560 static ParseResult parseCustomDirectiveRegions( 561 OpAsmParser &parser, Region ®ion, 562 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 563 if (parser.parseRegion(region)) 564 return failure(); 565 if (failed(parser.parseOptionalComma())) 566 return success(); 567 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 568 if (parser.parseRegion(*varRegion)) 569 return failure(); 570 varRegions.emplace_back(std::move(varRegion)); 571 return success(); 572 } 573 static ParseResult 574 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 575 SmallVectorImpl<Block *> &varSuccessors) { 576 if (parser.parseSuccessor(successor)) 577 return failure(); 578 if (failed(parser.parseOptionalComma())) 579 return success(); 580 Block *varSuccessor; 581 if (parser.parseSuccessor(varSuccessor)) 582 return failure(); 583 varSuccessors.append(2, varSuccessor); 584 return success(); 585 } 586 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 587 IntegerAttr &attr, 588 IntegerAttr &optAttr) { 589 if (parser.parseAttribute(attr)) 590 return failure(); 591 if (succeeded(parser.parseOptionalComma())) { 592 if (parser.parseAttribute(optAttr)) 593 return failure(); 594 } 595 return success(); 596 } 597 598 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 599 NamedAttrList &attrs) { 600 return parser.parseOptionalAttrDict(attrs); 601 } 602 static ParseResult parseCustomDirectiveOptionalOperandRef( 603 OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) { 604 int64_t operandCount = 0; 605 if (parser.parseInteger(operandCount)) 606 return failure(); 607 bool expectedOptionalOperand = operandCount == 0; 608 return success(expectedOptionalOperand != optOperand.hasValue()); 609 } 610 611 //===----------------------------------------------------------------------===// 612 // Printing 613 614 static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *, 615 Value optOperand) { 616 if (optOperand) 617 printer << "(" << optOperand << ") "; 618 } 619 620 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 621 Value operand, Value optOperand, 622 OperandRange varOperands) { 623 printer << operand; 624 if (optOperand) 625 printer << ", " << optOperand; 626 printer << " -> (" << varOperands << ")"; 627 } 628 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 629 Type operandType, Type optOperandType, 630 TypeRange varOperandTypes) { 631 printer << " : " << operandType; 632 if (optOperandType) 633 printer << ", " << optOperandType; 634 printer << " -> (" << varOperandTypes << ")"; 635 } 636 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 637 Operation *op, Type operandType, 638 Type optOperandType, 639 TypeRange varOperandTypes) { 640 printer << " type_refs_capture "; 641 printCustomDirectiveResults(printer, op, operandType, optOperandType, 642 varOperandTypes); 643 } 644 static void printCustomDirectiveOperandsAndTypes( 645 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 646 OperandRange varOperands, Type operandType, Type optOperandType, 647 TypeRange varOperandTypes) { 648 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 649 printCustomDirectiveResults(printer, op, operandType, optOperandType, 650 varOperandTypes); 651 } 652 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 653 Region ®ion, 654 MutableArrayRef<Region> varRegions) { 655 printer.printRegion(region); 656 if (!varRegions.empty()) { 657 printer << ", "; 658 for (Region ®ion : varRegions) 659 printer.printRegion(region); 660 } 661 } 662 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 663 Block *successor, 664 SuccessorRange varSuccessors) { 665 printer << successor; 666 if (!varSuccessors.empty()) 667 printer << ", " << varSuccessors.front(); 668 } 669 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 670 Attribute attribute, 671 Attribute optAttribute) { 672 printer << attribute; 673 if (optAttribute) 674 printer << ", " << optAttribute; 675 } 676 677 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 678 DictionaryAttr attrs) { 679 printer.printOptionalAttrDict(attrs.getValue()); 680 } 681 682 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 683 Operation *op, 684 Value optOperand) { 685 printer << (optOperand ? "1" : "0"); 686 } 687 688 //===----------------------------------------------------------------------===// 689 // Test IsolatedRegionOp - parse passthrough region arguments. 690 //===----------------------------------------------------------------------===// 691 692 ParseResult IsolatedRegionOp::parse(OpAsmParser &parser, 693 OperationState &result) { 694 OpAsmParser::UnresolvedOperand argInfo; 695 Type argType = parser.getBuilder().getIndexType(); 696 697 // Parse the input operand. 698 if (parser.parseOperand(argInfo) || 699 parser.resolveOperand(argInfo, argType, result.operands)) 700 return failure(); 701 702 // Parse the body region, and reuse the operand info as the argument info. 703 Region *body = result.addRegion(); 704 return parser.parseRegion(*body, argInfo, argType, 705 /*enableNameShadowing=*/true); 706 } 707 708 void IsolatedRegionOp::print(OpAsmPrinter &p) { 709 p << "test.isolated_region "; 710 p.printOperand(getOperand()); 711 p.shadowRegionArgs(getRegion(), getOperand()); 712 p << ' '; 713 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 714 } 715 716 //===----------------------------------------------------------------------===// 717 // Test SSACFGRegionOp 718 //===----------------------------------------------------------------------===// 719 720 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 721 return RegionKind::SSACFG; 722 } 723 724 //===----------------------------------------------------------------------===// 725 // Test GraphRegionOp 726 //===----------------------------------------------------------------------===// 727 728 ParseResult GraphRegionOp::parse(OpAsmParser &parser, OperationState &result) { 729 // Parse the body region, and reuse the operand info as the argument info. 730 Region *body = result.addRegion(); 731 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 732 } 733 734 void GraphRegionOp::print(OpAsmPrinter &p) { 735 p << "test.graph_region "; 736 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 737 } 738 739 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 740 return RegionKind::Graph; 741 } 742 743 //===----------------------------------------------------------------------===// 744 // Test AffineScopeOp 745 //===----------------------------------------------------------------------===// 746 747 ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) { 748 // Parse the body region, and reuse the operand info as the argument info. 749 Region *body = result.addRegion(); 750 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 751 } 752 753 void AffineScopeOp::print(OpAsmPrinter &p) { 754 p << "test.affine_scope "; 755 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 756 } 757 758 //===----------------------------------------------------------------------===// 759 // Test parser. 760 //===----------------------------------------------------------------------===// 761 762 ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser, 763 OperationState &result) { 764 if (parser.parseOptionalColon()) 765 return success(); 766 uint64_t numResults; 767 if (parser.parseInteger(numResults)) 768 return failure(); 769 770 IndexType type = parser.getBuilder().getIndexType(); 771 for (unsigned i = 0; i < numResults; ++i) 772 result.addTypes(type); 773 return success(); 774 } 775 776 void ParseIntegerLiteralOp::print(OpAsmPrinter &p) { 777 if (unsigned numResults = getNumResults()) 778 p << " : " << numResults; 779 } 780 781 ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser, 782 OperationState &result) { 783 StringRef keyword; 784 if (parser.parseKeyword(&keyword)) 785 return failure(); 786 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 787 return success(); 788 } 789 790 void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); } 791 792 //===----------------------------------------------------------------------===// 793 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 794 795 ParseResult WrappingRegionOp::parse(OpAsmParser &parser, 796 OperationState &result) { 797 if (parser.parseKeyword("wraps")) 798 return failure(); 799 800 // Parse the wrapped op in a region 801 Region &body = *result.addRegion(); 802 body.push_back(new Block); 803 Block &block = body.back(); 804 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin()); 805 if (!wrappedOp) 806 return failure(); 807 808 // Create a return terminator in the inner region, pass as operand to the 809 // terminator the returned values from the wrapped operation. 810 SmallVector<Value, 8> returnOperands(wrappedOp->getResults()); 811 OpBuilder builder(parser.getContext()); 812 builder.setInsertionPointToEnd(&block); 813 builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands); 814 815 // Get the results type for the wrapping op from the terminator operands. 816 Operation &returnOp = body.back().back(); 817 result.types.append(returnOp.operand_type_begin(), 818 returnOp.operand_type_end()); 819 820 // Use the location of the wrapped op for the "test.wrapping_region" op. 821 result.location = wrappedOp->getLoc(); 822 823 return success(); 824 } 825 826 void WrappingRegionOp::print(OpAsmPrinter &p) { 827 p << " wraps "; 828 p.printGenericOp(&getRegion().front().front()); 829 } 830 831 //===----------------------------------------------------------------------===// 832 // Test PrettyPrintedRegionOp - exercising the following parser APIs 833 // parseGenericOperationAfterOpName 834 // parseCustomOperationName 835 //===----------------------------------------------------------------------===// 836 837 ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser, 838 OperationState &result) { 839 840 SMLoc loc = parser.getCurrentLocation(); 841 Location currLocation = parser.getEncodedSourceLoc(loc); 842 843 // Parse the operands. 844 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands; 845 if (parser.parseOperandList(operands)) 846 return failure(); 847 848 // Check if we are parsing the pretty-printed version 849 // test.pretty_printed_region start <inner-op> end : <functional-type> 850 // Else fallback to parsing the "non pretty-printed" version. 851 if (!succeeded(parser.parseOptionalKeyword("start"))) 852 return parser.parseGenericOperationAfterOpName( 853 result, llvm::makeArrayRef(operands)); 854 855 FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName(); 856 if (failed(parseOpNameInfo)) 857 return failure(); 858 859 StringAttr innerOpName = parseOpNameInfo->getIdentifier(); 860 861 FunctionType opFntype; 862 Optional<Location> explicitLoc; 863 if (parser.parseKeyword("end") || parser.parseColon() || 864 parser.parseType(opFntype) || 865 parser.parseOptionalLocationSpecifier(explicitLoc)) 866 return failure(); 867 868 // If location of the op is explicitly provided, then use it; Else use 869 // the parser's current location. 870 Location opLoc = explicitLoc.getValueOr(currLocation); 871 872 // Derive the SSA-values for op's operands. 873 if (parser.resolveOperands(operands, opFntype.getInputs(), loc, 874 result.operands)) 875 return failure(); 876 877 // Add a region for op. 878 Region ®ion = *result.addRegion(); 879 880 // Create a basic-block inside op's region. 881 Block &block = region.emplaceBlock(); 882 883 // Create and insert an "inner-op" operation in the block. 884 // Just for testing purposes, we can assume that inner op is a binary op with 885 // result and operand types all same as the test-op's first operand. 886 Type innerOpType = opFntype.getInput(0); 887 Value lhs = block.addArgument(innerOpType, opLoc); 888 Value rhs = block.addArgument(innerOpType, opLoc); 889 890 OpBuilder builder(parser.getBuilder().getContext()); 891 builder.setInsertionPointToStart(&block); 892 893 Operation *innerOp = 894 builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType); 895 896 // Insert a return statement in the block returning the inner-op's result. 897 builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults()); 898 899 // Populate the op operation-state with result-type and location. 900 result.addTypes(opFntype.getResults()); 901 result.location = innerOp->getLoc(); 902 903 return success(); 904 } 905 906 void PrettyPrintedRegionOp::print(OpAsmPrinter &p) { 907 p << ' '; 908 p.printOperands(getOperands()); 909 910 Operation &innerOp = getRegion().front().front(); 911 // Assuming that region has a single non-terminator inner-op, if the inner-op 912 // meets some criteria (which in this case is a simple one based on the name 913 // of inner-op), then we can print the entire region in a succinct way. 914 // Here we assume that the prototype of "special.op" can be trivially derived 915 // while parsing it back. 916 if (innerOp.getName().getStringRef().equals("special.op")) { 917 p << " start special.op end"; 918 } else { 919 p << " ("; 920 p.printRegion(getRegion()); 921 p << ")"; 922 } 923 924 p << " : "; 925 p.printFunctionalType(*this); 926 } 927 928 //===----------------------------------------------------------------------===// 929 // Test PolyForOp - parse list of region arguments. 930 //===----------------------------------------------------------------------===// 931 932 ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) { 933 SmallVector<OpAsmParser::UnresolvedOperand, 4> ivsInfo; 934 // Parse list of region arguments without a delimiter. 935 if (parser.parseOperandList(ivsInfo, OpAsmParser::Delimiter::None, 936 /*allowResultNumber=*/false)) 937 return failure(); 938 939 // Parse the body region. 940 Region *body = result.addRegion(); 941 auto &builder = parser.getBuilder(); 942 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 943 return parser.parseRegion(*body, ivsInfo, argTypes); 944 } 945 946 void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); } 947 948 void PolyForOp::getAsmBlockArgumentNames(Region ®ion, 949 OpAsmSetValueNameFn setNameFn) { 950 auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names"); 951 if (!arrayAttr) 952 return; 953 auto args = getRegion().front().getArguments(); 954 auto e = std::min(arrayAttr.size(), args.size()); 955 for (unsigned i = 0; i < e; ++i) { 956 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 957 setNameFn(args[i], strAttr.getValue()); 958 } 959 } 960 961 //===----------------------------------------------------------------------===// 962 // Test removing op with inner ops. 963 //===----------------------------------------------------------------------===// 964 965 namespace { 966 struct TestRemoveOpWithInnerOps 967 : public OpRewritePattern<TestOpWithRegionPattern> { 968 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 969 970 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } 971 972 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 973 PatternRewriter &rewriter) const override { 974 rewriter.eraseOp(op); 975 return success(); 976 } 977 }; 978 } // namespace 979 980 void TestOpWithRegionPattern::getCanonicalizationPatterns( 981 RewritePatternSet &results, MLIRContext *context) { 982 results.add<TestRemoveOpWithInnerOps>(context); 983 } 984 985 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 986 return getOperand(); 987 } 988 989 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 990 return getValue(); 991 } 992 993 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 994 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 995 for (Value input : this->getOperands()) { 996 results.push_back(input); 997 } 998 return success(); 999 } 1000 1001 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 1002 assert(operands.size() == 1); 1003 if (operands.front()) { 1004 (*this)->setAttr("attr", operands.front()); 1005 return getResult(); 1006 } 1007 return {}; 1008 } 1009 1010 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 1011 return getOperand(); 1012 } 1013 1014 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 1015 MLIRContext *, Optional<Location> location, ValueRange operands, 1016 DictionaryAttr attributes, RegionRange regions, 1017 SmallVectorImpl<Type> &inferredReturnTypes) { 1018 if (operands[0].getType() != operands[1].getType()) { 1019 return emitOptionalError(location, "operand type mismatch ", 1020 operands[0].getType(), " vs ", 1021 operands[1].getType()); 1022 } 1023 inferredReturnTypes.assign({operands[0].getType()}); 1024 return success(); 1025 } 1026 1027 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 1028 MLIRContext *context, Optional<Location> location, ValueShapeRange operands, 1029 DictionaryAttr attributes, RegionRange regions, 1030 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1031 // Create return type consisting of the last element of the first operand. 1032 auto operandType = operands.front().getType(); 1033 auto sval = operandType.dyn_cast<ShapedType>(); 1034 if (!sval) { 1035 return emitOptionalError(location, "only shaped type operands allowed"); 1036 } 1037 int64_t dim = 1038 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 1039 auto type = IntegerType::get(context, 17); 1040 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 1041 return success(); 1042 } 1043 1044 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 1045 OpBuilder &builder, ValueRange operands, 1046 llvm::SmallVectorImpl<Value> &shapes) { 1047 shapes = SmallVector<Value, 1>{ 1048 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; 1049 return success(); 1050 } 1051 1052 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( 1053 OpBuilder &builder, ValueRange operands, 1054 llvm::SmallVectorImpl<Value> &shapes) { 1055 Location loc = getLoc(); 1056 shapes.reserve(operands.size()); 1057 for (Value operand : llvm::reverse(operands)) { 1058 auto rank = operand.getType().cast<RankedTensorType>().getRank(); 1059 auto currShape = llvm::to_vector<4>( 1060 llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value { 1061 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 1062 })); 1063 shapes.push_back(builder.create<tensor::FromElementsOp>( 1064 getLoc(), RankedTensorType::get({rank}, builder.getIndexType()), 1065 currShape)); 1066 } 1067 return success(); 1068 } 1069 1070 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( 1071 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { 1072 Location loc = getLoc(); 1073 shapes.reserve(getNumOperands()); 1074 for (Value operand : llvm::reverse(getOperands())) { 1075 auto currShape = llvm::to_vector<4>(llvm::map_range( 1076 llvm::seq<int64_t>( 1077 0, operand.getType().cast<RankedTensorType>().getRank()), 1078 [&](int64_t dim) -> Value { 1079 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 1080 })); 1081 shapes.emplace_back(std::move(currShape)); 1082 } 1083 return success(); 1084 } 1085 1086 //===----------------------------------------------------------------------===// 1087 // Test SideEffect interfaces 1088 //===----------------------------------------------------------------------===// 1089 1090 namespace { 1091 /// A test resource for side effects. 1092 struct TestResource : public SideEffects::Resource::Base<TestResource> { 1093 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource) 1094 1095 StringRef getName() final { return "<Test>"; } 1096 }; 1097 } // namespace 1098 1099 static void testSideEffectOpGetEffect( 1100 Operation *op, 1101 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 1102 &effects) { 1103 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); 1104 if (!effectsAttr) 1105 return; 1106 1107 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 1108 } 1109 1110 void SideEffectOp::getEffects( 1111 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1112 // Check for an effects attribute on the op instance. 1113 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 1114 if (!effectsAttr) 1115 return; 1116 1117 // If there is one, it is an array of dictionary attributes that hold 1118 // information on the effects of this operation. 1119 for (Attribute element : effectsAttr) { 1120 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 1121 1122 // Get the specific memory effect. 1123 MemoryEffects::Effect *effect = 1124 StringSwitch<MemoryEffects::Effect *>( 1125 effectElement.get("effect").cast<StringAttr>().getValue()) 1126 .Case("allocate", MemoryEffects::Allocate::get()) 1127 .Case("free", MemoryEffects::Free::get()) 1128 .Case("read", MemoryEffects::Read::get()) 1129 .Case("write", MemoryEffects::Write::get()); 1130 1131 // Check for a non-default resource to use. 1132 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 1133 if (effectElement.get("test_resource")) 1134 resource = TestResource::get(); 1135 1136 // Check for a result to affect. 1137 if (effectElement.get("on_result")) 1138 effects.emplace_back(effect, getResult(), resource); 1139 else if (Attribute ref = effectElement.get("on_reference")) 1140 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 1141 else 1142 effects.emplace_back(effect, resource); 1143 } 1144 } 1145 1146 void SideEffectOp::getEffects( 1147 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 1148 testSideEffectOpGetEffect(getOperation(), effects); 1149 } 1150 1151 //===----------------------------------------------------------------------===// 1152 // StringAttrPrettyNameOp 1153 //===----------------------------------------------------------------------===// 1154 1155 // This op has fancy handling of its SSA result name. 1156 ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser, 1157 OperationState &result) { 1158 // Add the result types. 1159 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 1160 result.addTypes(parser.getBuilder().getIntegerType(32)); 1161 1162 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1163 return failure(); 1164 1165 // If the attribute dictionary contains no 'names' attribute, infer it from 1166 // the SSA name (if specified). 1167 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 1168 return attr.getName() == "names"; 1169 }); 1170 1171 // If there was no name specified, check to see if there was a useful name 1172 // specified in the asm file. 1173 if (hadNames || parser.getNumResults() == 0) 1174 return success(); 1175 1176 SmallVector<StringRef, 4> names; 1177 auto *context = result.getContext(); 1178 1179 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 1180 auto resultName = parser.getResultName(i); 1181 StringRef nameStr; 1182 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 1183 nameStr = resultName.first; 1184 1185 names.push_back(nameStr); 1186 } 1187 1188 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 1189 result.attributes.push_back({StringAttr::get(context, "names"), namesAttr}); 1190 return success(); 1191 } 1192 1193 void StringAttrPrettyNameOp::print(OpAsmPrinter &p) { 1194 // Note that we only need to print the "name" attribute if the asmprinter 1195 // result name disagrees with it. This can happen in strange cases, e.g. 1196 // when there are conflicts. 1197 bool namesDisagree = getNames().size() != getNumResults(); 1198 1199 SmallString<32> resultNameStr; 1200 for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) { 1201 resultNameStr.clear(); 1202 llvm::raw_svector_ostream tmpStream(resultNameStr); 1203 p.printOperand(getResult(i), tmpStream); 1204 1205 auto expectedName = getNames()[i].dyn_cast<StringAttr>(); 1206 if (!expectedName || 1207 tmpStream.str().drop_front() != expectedName.getValue()) { 1208 namesDisagree = true; 1209 } 1210 } 1211 1212 if (namesDisagree) 1213 p.printOptionalAttrDictWithKeyword((*this)->getAttrs()); 1214 else 1215 p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"}); 1216 } 1217 1218 // We set the SSA name in the asm syntax to the contents of the name 1219 // attribute. 1220 void StringAttrPrettyNameOp::getAsmResultNames( 1221 function_ref<void(Value, StringRef)> setNameFn) { 1222 1223 auto value = getNames(); 1224 for (size_t i = 0, e = value.size(); i != e; ++i) 1225 if (auto str = value[i].dyn_cast<StringAttr>()) 1226 if (!str.getValue().empty()) 1227 setNameFn(getResult(i), str.getValue()); 1228 } 1229 1230 void CustomResultsNameOp::getAsmResultNames( 1231 function_ref<void(Value, StringRef)> setNameFn) { 1232 ArrayAttr value = getNames(); 1233 for (size_t i = 0, e = value.size(); i != e; ++i) 1234 if (auto str = value[i].dyn_cast<StringAttr>()) 1235 if (!str.getValue().empty()) 1236 setNameFn(getResult(i), str.getValue()); 1237 } 1238 1239 //===----------------------------------------------------------------------===// 1240 // ResultTypeWithTraitOp 1241 //===----------------------------------------------------------------------===// 1242 1243 LogicalResult ResultTypeWithTraitOp::verify() { 1244 if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>()) 1245 return success(); 1246 return emitError("result type should have trait 'TestTypeTrait'"); 1247 } 1248 1249 //===----------------------------------------------------------------------===// 1250 // AttrWithTraitOp 1251 //===----------------------------------------------------------------------===// 1252 1253 LogicalResult AttrWithTraitOp::verify() { 1254 if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>()) 1255 return success(); 1256 return emitError("'attr' attribute should have trait 'TestAttrTrait'"); 1257 } 1258 1259 //===----------------------------------------------------------------------===// 1260 // RegionIfOp 1261 //===----------------------------------------------------------------------===// 1262 1263 void RegionIfOp::print(OpAsmPrinter &p) { 1264 p << " "; 1265 p.printOperands(getOperands()); 1266 p << ": " << getOperandTypes(); 1267 p.printArrowTypeList(getResultTypes()); 1268 p << " then "; 1269 p.printRegion(getThenRegion(), 1270 /*printEntryBlockArgs=*/true, 1271 /*printBlockTerminators=*/true); 1272 p << " else "; 1273 p.printRegion(getElseRegion(), 1274 /*printEntryBlockArgs=*/true, 1275 /*printBlockTerminators=*/true); 1276 p << " join "; 1277 p.printRegion(getJoinRegion(), 1278 /*printEntryBlockArgs=*/true, 1279 /*printBlockTerminators=*/true); 1280 } 1281 1282 ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) { 1283 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos; 1284 SmallVector<Type, 2> operandTypes; 1285 1286 result.regions.reserve(3); 1287 Region *thenRegion = result.addRegion(); 1288 Region *elseRegion = result.addRegion(); 1289 Region *joinRegion = result.addRegion(); 1290 1291 // Parse operand, type and arrow type lists. 1292 if (parser.parseOperandList(operandInfos) || 1293 parser.parseColonTypeList(operandTypes) || 1294 parser.parseArrowTypeList(result.types)) 1295 return failure(); 1296 1297 // Parse all attached regions. 1298 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 1299 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 1300 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 1301 return failure(); 1302 1303 return parser.resolveOperands(operandInfos, operandTypes, 1304 parser.getCurrentLocation(), result.operands); 1305 } 1306 1307 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 1308 assert(index < 2 && "invalid region index"); 1309 return getOperands(); 1310 } 1311 1312 void RegionIfOp::getSuccessorRegions( 1313 Optional<unsigned> index, ArrayRef<Attribute> operands, 1314 SmallVectorImpl<RegionSuccessor> ®ions) { 1315 // We always branch to the join region. 1316 if (index.hasValue()) { 1317 if (index.getValue() < 2) 1318 regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); 1319 else 1320 regions.push_back(RegionSuccessor(getResults())); 1321 return; 1322 } 1323 1324 // The then and else regions are the entry regions of this op. 1325 regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs())); 1326 regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs())); 1327 } 1328 1329 void RegionIfOp::getRegionInvocationBounds( 1330 ArrayRef<Attribute> operands, 1331 SmallVectorImpl<InvocationBounds> &invocationBounds) { 1332 // Each region is invoked at most once. 1333 invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1}); 1334 } 1335 1336 //===----------------------------------------------------------------------===// 1337 // AnyCondOp 1338 //===----------------------------------------------------------------------===// 1339 1340 void AnyCondOp::getSuccessorRegions(Optional<unsigned> index, 1341 ArrayRef<Attribute> operands, 1342 SmallVectorImpl<RegionSuccessor> ®ions) { 1343 // The parent op branches into the only region, and the region branches back 1344 // to the parent op. 1345 if (index) 1346 regions.emplace_back(&getRegion()); 1347 else 1348 regions.emplace_back(getResults()); 1349 } 1350 1351 void AnyCondOp::getRegionInvocationBounds( 1352 ArrayRef<Attribute> operands, 1353 SmallVectorImpl<InvocationBounds> &invocationBounds) { 1354 invocationBounds.emplace_back(1, 1); 1355 } 1356 1357 //===----------------------------------------------------------------------===// 1358 // SingleNoTerminatorCustomAsmOp 1359 //===----------------------------------------------------------------------===// 1360 1361 ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser, 1362 OperationState &state) { 1363 Region *body = state.addRegion(); 1364 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) 1365 return failure(); 1366 return success(); 1367 } 1368 1369 void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) { 1370 printer.printRegion( 1371 getRegion(), /*printEntryBlockArgs=*/false, 1372 // This op has a single block without terminators. But explicitly mark 1373 // as not printing block terminators for testing. 1374 /*printBlockTerminators=*/false); 1375 } 1376 1377 //===----------------------------------------------------------------------===// 1378 // TestVerifiersOp 1379 //===----------------------------------------------------------------------===// 1380 1381 LogicalResult TestVerifiersOp::verify() { 1382 if (!getRegion().hasOneBlock()) 1383 return emitOpError("`hasOneBlock` trait hasn't been verified"); 1384 1385 Operation *definingOp = getInput().getDefiningOp(); 1386 if (definingOp && failed(mlir::verify(definingOp))) 1387 return emitOpError("operand hasn't been verified"); 1388 1389 emitRemark("success run of verifier"); 1390 1391 return success(); 1392 } 1393 1394 LogicalResult TestVerifiersOp::verifyRegions() { 1395 if (!getRegion().hasOneBlock()) 1396 return emitOpError("`hasOneBlock` trait hasn't been verified"); 1397 1398 for (Block &block : getRegion()) 1399 for (Operation &op : block) 1400 if (failed(mlir::verify(&op))) 1401 return emitOpError("nested op hasn't been verified"); 1402 1403 emitRemark("success run of region verifier"); 1404 1405 return success(); 1406 } 1407 1408 #include "TestOpEnums.cpp.inc" 1409 #include "TestOpInterfaces.cpp.inc" 1410 #include "TestOpStructs.cpp.inc" 1411 #include "TestTypeInterfaces.cpp.inc" 1412 1413 #define GET_OP_CLASSES 1414 #include "TestOps.cpp.inc" 1415