1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestInterfaces.h" 12 #include "TestTypes.h" 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/DLTI/DLTI.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/DialectImplementation.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/IR/TypeUtilities.h" 21 #include "mlir/Reducer/ReductionPatternInterface.h" 22 #include "mlir/Transforms/FoldUtils.h" 23 #include "mlir/Transforms/InliningUtils.h" 24 #include "llvm/ADT/StringSwitch.h" 25 26 // Include this before the using namespace lines below to 27 // test that we don't have namespace dependencies. 28 #include "TestOpsDialect.cpp.inc" 29 30 using namespace mlir; 31 using namespace test; 32 33 void test::registerTestDialect(DialectRegistry ®istry) { 34 registry.insert<TestDialect>(); 35 } 36 37 //===----------------------------------------------------------------------===// 38 // TestDialect Interfaces 39 //===----------------------------------------------------------------------===// 40 41 namespace { 42 43 /// Testing the correctness of some traits. 44 static_assert( 45 llvm::is_detected<OpTrait::has_implicit_terminator_t, 46 SingleBlockImplicitTerminatorOp>::value, 47 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); 48 static_assert(OpTrait::hasSingleBlockImplicitTerminator< 49 SingleBlockImplicitTerminatorOp>::value, 50 "hasSingleBlockImplicitTerminator does not match " 51 "SingleBlockImplicitTerminatorOp"); 52 53 // Test support for interacting with the AsmPrinter. 54 struct TestOpAsmInterface : public OpAsmDialectInterface { 55 using OpAsmDialectInterface::OpAsmDialectInterface; 56 57 AliasResult getAlias(Attribute attr, raw_ostream &os) const final { 58 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 59 if (!strAttr) 60 return AliasResult::NoAlias; 61 62 // Check the contents of the string attribute to see what the test alias 63 // should be named. 64 Optional<StringRef> aliasName = 65 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 66 .Case("alias_test:dot_in_name", StringRef("test.alias")) 67 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 68 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 69 .Case("alias_test:sanitize_conflict_a", 70 StringRef("test_alias_conflict0")) 71 .Case("alias_test:sanitize_conflict_b", 72 StringRef("test_alias_conflict0_")) 73 .Case("alias_test:tensor_encoding", StringRef("test_encoding")) 74 .Default(llvm::None); 75 if (!aliasName) 76 return AliasResult::NoAlias; 77 78 os << *aliasName; 79 return AliasResult::FinalAlias; 80 } 81 82 AliasResult getAlias(Type type, raw_ostream &os) const final { 83 if (auto tupleType = type.dyn_cast<TupleType>()) { 84 if (tupleType.size() > 0 && 85 llvm::all_of(tupleType.getTypes(), [](Type elemType) { 86 return elemType.isa<SimpleAType>(); 87 })) { 88 os << "test_tuple"; 89 return AliasResult::FinalAlias; 90 } 91 } 92 if (auto intType = type.dyn_cast<TestIntegerType>()) { 93 if (intType.getSignedness() == 94 TestIntegerType::SignednessSemantics::Unsigned && 95 intType.getWidth() == 8) { 96 os << "test_ui8"; 97 return AliasResult::FinalAlias; 98 } 99 } 100 return AliasResult::NoAlias; 101 } 102 103 void getAsmResultNames(Operation *op, 104 OpAsmSetValueNameFn setNameFn) const final { 105 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 106 setNameFn(asmOp, "result"); 107 } 108 109 void getAsmBlockArgumentNames(Block *block, 110 OpAsmSetValueNameFn setNameFn) const final { 111 auto op = block->getParentOp(); 112 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 113 if (!arrayAttr) 114 return; 115 auto args = block->getArguments(); 116 auto e = std::min(arrayAttr.size(), args.size()); 117 for (unsigned i = 0; i < e; ++i) { 118 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 119 setNameFn(args[i], strAttr.getValue()); 120 } 121 } 122 }; 123 124 struct TestDialectFoldInterface : public DialectFoldInterface { 125 using DialectFoldInterface::DialectFoldInterface; 126 127 /// Registered hook to check if the given region, which is attached to an 128 /// operation that is *not* isolated from above, should be used when 129 /// materializing constants. 130 bool shouldMaterializeInto(Region *region) const final { 131 // If this is a one region operation, then insert into it. 132 return isa<OneRegionOp>(region->getParentOp()); 133 } 134 }; 135 136 /// This class defines the interface for handling inlining with standard 137 /// operations. 138 struct TestInlinerInterface : public DialectInlinerInterface { 139 using DialectInlinerInterface::DialectInlinerInterface; 140 141 //===--------------------------------------------------------------------===// 142 // Analysis Hooks 143 //===--------------------------------------------------------------------===// 144 145 bool isLegalToInline(Operation *call, Operation *callable, 146 bool wouldBeCloned) const final { 147 // Don't allow inlining calls that are marked `noinline`. 148 return !call->hasAttr("noinline"); 149 } 150 bool isLegalToInline(Region *, Region *, bool, 151 BlockAndValueMapping &) const final { 152 // Inlining into test dialect regions is legal. 153 return true; 154 } 155 bool isLegalToInline(Operation *, Region *, bool, 156 BlockAndValueMapping &) const final { 157 return true; 158 } 159 160 bool shouldAnalyzeRecursively(Operation *op) const final { 161 // Analyze recursively if this is not a functional region operation, it 162 // froms a separate functional scope. 163 return !isa<FunctionalRegionOp>(op); 164 } 165 166 //===--------------------------------------------------------------------===// 167 // Transformation Hooks 168 //===--------------------------------------------------------------------===// 169 170 /// Handle the given inlined terminator by replacing it with a new operation 171 /// as necessary. 172 void handleTerminator(Operation *op, 173 ArrayRef<Value> valuesToRepl) const final { 174 // Only handle "test.return" here. 175 auto returnOp = dyn_cast<TestReturnOp>(op); 176 if (!returnOp) 177 return; 178 179 // Replace the values directly with the return operands. 180 assert(returnOp.getNumOperands() == valuesToRepl.size()); 181 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 182 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 183 } 184 185 /// Attempt to materialize a conversion for a type mismatch between a call 186 /// from this dialect, and a callable region. This method should generate an 187 /// operation that takes 'input' as the only operand, and produces a single 188 /// result of 'resultType'. If a conversion can not be generated, nullptr 189 /// should be returned. 190 Operation *materializeCallConversion(OpBuilder &builder, Value input, 191 Type resultType, 192 Location conversionLoc) const final { 193 // Only allow conversion for i16/i32 types. 194 if (!(resultType.isSignlessInteger(16) || 195 resultType.isSignlessInteger(32)) || 196 !(input.getType().isSignlessInteger(16) || 197 input.getType().isSignlessInteger(32))) 198 return nullptr; 199 return builder.create<TestCastOp>(conversionLoc, resultType, input); 200 } 201 202 void processInlinedCallBlocks( 203 Operation *call, 204 iterator_range<Region::iterator> inlinedBlocks) const final { 205 if (!isa<ConversionCallOp>(call)) 206 return; 207 208 // Set attributed on all ops in the inlined blocks. 209 for (Block &block : inlinedBlocks) { 210 block.walk([&](Operation *op) { 211 op->setAttr("inlined_conversion", UnitAttr::get(call->getContext())); 212 }); 213 } 214 } 215 }; 216 217 struct TestReductionPatternInterface : public DialectReductionPatternInterface { 218 public: 219 TestReductionPatternInterface(Dialect *dialect) 220 : DialectReductionPatternInterface(dialect) {} 221 222 void populateReductionPatterns(RewritePatternSet &patterns) const final { 223 populateTestReductionPatterns(patterns); 224 } 225 }; 226 227 } // namespace 228 229 //===----------------------------------------------------------------------===// 230 // TestDialect 231 //===----------------------------------------------------------------------===// 232 233 static void testSideEffectOpGetEffect( 234 Operation *op, 235 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects); 236 237 // This is the implementation of a dialect fallback for `TestEffectOpInterface`. 238 struct TestOpEffectInterfaceFallback 239 : public TestEffectOpInterface::FallbackModel< 240 TestOpEffectInterfaceFallback> { 241 static bool classof(Operation *op) { 242 bool isSupportedOp = 243 op->getName().getStringRef() == "test.unregistered_side_effect_op"; 244 assert(isSupportedOp && "Unexpected dispatch"); 245 return isSupportedOp; 246 } 247 248 void 249 getEffects(Operation *op, 250 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 251 &effects) const { 252 testSideEffectOpGetEffect(op, effects); 253 } 254 }; 255 256 void TestDialect::initialize() { 257 registerAttributes(); 258 registerTypes(); 259 addOperations< 260 #define GET_OP_LIST 261 #include "TestOps.cpp.inc" 262 >(); 263 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 264 TestInlinerInterface, TestReductionPatternInterface>(); 265 allowUnknownOperations(); 266 267 // Instantiate our fallback op interface that we'll use on specific 268 // unregistered op. 269 fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback; 270 } 271 TestDialect::~TestDialect() { 272 delete static_cast<TestOpEffectInterfaceFallback *>( 273 fallbackEffectOpInterfaces); 274 } 275 276 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 277 Type type, Location loc) { 278 return builder.create<TestOpConstant>(loc, type, value); 279 } 280 281 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, 282 OperationName opName) { 283 if (opName.getIdentifier() == "test.unregistered_side_effect_op" && 284 typeID == TypeID::get<TestEffectOpInterface>()) 285 return fallbackEffectOpInterfaces; 286 return nullptr; 287 } 288 289 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 290 NamedAttribute namedAttr) { 291 if (namedAttr.getName() == "test.invalid_attr") 292 return op->emitError() << "invalid to use 'test.invalid_attr'"; 293 return success(); 294 } 295 296 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 297 unsigned regionIndex, 298 unsigned argIndex, 299 NamedAttribute namedAttr) { 300 if (namedAttr.getName() == "test.invalid_attr") 301 return op->emitError() << "invalid to use 'test.invalid_attr'"; 302 return success(); 303 } 304 305 LogicalResult 306 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 307 unsigned resultIndex, 308 NamedAttribute namedAttr) { 309 if (namedAttr.getName() == "test.invalid_attr") 310 return op->emitError() << "invalid to use 'test.invalid_attr'"; 311 return success(); 312 } 313 314 Optional<Dialect::ParseOpHook> 315 TestDialect::getParseOperationHook(StringRef opName) const { 316 if (opName == "test.dialect_custom_printer") { 317 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 318 return parser.parseKeyword("custom_format"); 319 }}; 320 } 321 if (opName == "test.dialect_custom_format_fallback") { 322 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 323 return parser.parseKeyword("custom_format_fallback"); 324 }}; 325 } 326 return None; 327 } 328 329 llvm::unique_function<void(Operation *, OpAsmPrinter &)> 330 TestDialect::getOperationPrinter(Operation *op) const { 331 StringRef opName = op->getName().getStringRef(); 332 if (opName == "test.dialect_custom_printer") { 333 return [](Operation *op, OpAsmPrinter &printer) { 334 printer.getStream() << " custom_format"; 335 }; 336 } 337 if (opName == "test.dialect_custom_format_fallback") { 338 return [](Operation *op, OpAsmPrinter &printer) { 339 printer.getStream() << " custom_format_fallback"; 340 }; 341 } 342 return {}; 343 } 344 345 //===----------------------------------------------------------------------===// 346 // TestBranchOp 347 //===----------------------------------------------------------------------===// 348 349 Optional<MutableOperandRange> 350 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 351 assert(index == 0 && "invalid successor index"); 352 return getTargetOperandsMutable(); 353 } 354 355 //===----------------------------------------------------------------------===// 356 // TestDialectCanonicalizerOp 357 //===----------------------------------------------------------------------===// 358 359 static LogicalResult 360 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, 361 PatternRewriter &rewriter) { 362 rewriter.replaceOpWithNewOp<arith::ConstantOp>( 363 op, rewriter.getI32IntegerAttr(42)); 364 return success(); 365 } 366 367 void TestDialect::getCanonicalizationPatterns( 368 RewritePatternSet &results) const { 369 results.add(&dialectCanonicalizationPattern); 370 } 371 372 //===----------------------------------------------------------------------===// 373 // TestFoldToCallOp 374 //===----------------------------------------------------------------------===// 375 376 namespace { 377 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 378 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 379 380 LogicalResult matchAndRewrite(FoldToCallOp op, 381 PatternRewriter &rewriter) const override { 382 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.getCalleeAttr(), 383 ValueRange()); 384 return success(); 385 } 386 }; 387 } // namespace 388 389 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 390 MLIRContext *context) { 391 results.add<FoldToCallOpPattern>(context); 392 } 393 394 //===----------------------------------------------------------------------===// 395 // Test Format* operations 396 //===----------------------------------------------------------------------===// 397 398 //===----------------------------------------------------------------------===// 399 // Parsing 400 401 static ParseResult parseCustomDirectiveOperands( 402 OpAsmParser &parser, OpAsmParser::OperandType &operand, 403 Optional<OpAsmParser::OperandType> &optOperand, 404 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 405 if (parser.parseOperand(operand)) 406 return failure(); 407 if (succeeded(parser.parseOptionalComma())) { 408 optOperand.emplace(); 409 if (parser.parseOperand(*optOperand)) 410 return failure(); 411 } 412 if (parser.parseArrow() || parser.parseLParen() || 413 parser.parseOperandList(varOperands) || parser.parseRParen()) 414 return failure(); 415 return success(); 416 } 417 static ParseResult 418 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 419 Type &optOperandType, 420 SmallVectorImpl<Type> &varOperandTypes) { 421 if (parser.parseColon()) 422 return failure(); 423 424 if (parser.parseType(operandType)) 425 return failure(); 426 if (succeeded(parser.parseOptionalComma())) { 427 if (parser.parseType(optOperandType)) 428 return failure(); 429 } 430 if (parser.parseArrow() || parser.parseLParen() || 431 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 432 return failure(); 433 return success(); 434 } 435 static ParseResult 436 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 437 Type optOperandType, 438 const SmallVectorImpl<Type> &varOperandTypes) { 439 if (parser.parseKeyword("type_refs_capture")) 440 return failure(); 441 442 Type operandType2, optOperandType2; 443 SmallVector<Type, 1> varOperandTypes2; 444 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 445 varOperandTypes2)) 446 return failure(); 447 448 if (operandType != operandType2 || optOperandType != optOperandType2 || 449 varOperandTypes != varOperandTypes2) 450 return failure(); 451 452 return success(); 453 } 454 static ParseResult parseCustomDirectiveOperandsAndTypes( 455 OpAsmParser &parser, OpAsmParser::OperandType &operand, 456 Optional<OpAsmParser::OperandType> &optOperand, 457 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 458 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 459 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 460 parseCustomDirectiveResults(parser, operandType, optOperandType, 461 varOperandTypes)) 462 return failure(); 463 return success(); 464 } 465 static ParseResult parseCustomDirectiveRegions( 466 OpAsmParser &parser, Region ®ion, 467 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 468 if (parser.parseRegion(region)) 469 return failure(); 470 if (failed(parser.parseOptionalComma())) 471 return success(); 472 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 473 if (parser.parseRegion(*varRegion)) 474 return failure(); 475 varRegions.emplace_back(std::move(varRegion)); 476 return success(); 477 } 478 static ParseResult 479 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 480 SmallVectorImpl<Block *> &varSuccessors) { 481 if (parser.parseSuccessor(successor)) 482 return failure(); 483 if (failed(parser.parseOptionalComma())) 484 return success(); 485 Block *varSuccessor; 486 if (parser.parseSuccessor(varSuccessor)) 487 return failure(); 488 varSuccessors.append(2, varSuccessor); 489 return success(); 490 } 491 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 492 IntegerAttr &attr, 493 IntegerAttr &optAttr) { 494 if (parser.parseAttribute(attr)) 495 return failure(); 496 if (succeeded(parser.parseOptionalComma())) { 497 if (parser.parseAttribute(optAttr)) 498 return failure(); 499 } 500 return success(); 501 } 502 503 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 504 NamedAttrList &attrs) { 505 return parser.parseOptionalAttrDict(attrs); 506 } 507 static ParseResult parseCustomDirectiveOptionalOperandRef( 508 OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) { 509 int64_t operandCount = 0; 510 if (parser.parseInteger(operandCount)) 511 return failure(); 512 bool expectedOptionalOperand = operandCount == 0; 513 return success(expectedOptionalOperand != optOperand.hasValue()); 514 } 515 516 //===----------------------------------------------------------------------===// 517 // Printing 518 519 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 520 Value operand, Value optOperand, 521 OperandRange varOperands) { 522 printer << operand; 523 if (optOperand) 524 printer << ", " << optOperand; 525 printer << " -> (" << varOperands << ")"; 526 } 527 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 528 Type operandType, Type optOperandType, 529 TypeRange varOperandTypes) { 530 printer << " : " << operandType; 531 if (optOperandType) 532 printer << ", " << optOperandType; 533 printer << " -> (" << varOperandTypes << ")"; 534 } 535 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 536 Operation *op, Type operandType, 537 Type optOperandType, 538 TypeRange varOperandTypes) { 539 printer << " type_refs_capture "; 540 printCustomDirectiveResults(printer, op, operandType, optOperandType, 541 varOperandTypes); 542 } 543 static void printCustomDirectiveOperandsAndTypes( 544 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 545 OperandRange varOperands, Type operandType, Type optOperandType, 546 TypeRange varOperandTypes) { 547 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 548 printCustomDirectiveResults(printer, op, operandType, optOperandType, 549 varOperandTypes); 550 } 551 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 552 Region ®ion, 553 MutableArrayRef<Region> varRegions) { 554 printer.printRegion(region); 555 if (!varRegions.empty()) { 556 printer << ", "; 557 for (Region ®ion : varRegions) 558 printer.printRegion(region); 559 } 560 } 561 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 562 Block *successor, 563 SuccessorRange varSuccessors) { 564 printer << successor; 565 if (!varSuccessors.empty()) 566 printer << ", " << varSuccessors.front(); 567 } 568 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 569 Attribute attribute, 570 Attribute optAttribute) { 571 printer << attribute; 572 if (optAttribute) 573 printer << ", " << optAttribute; 574 } 575 576 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 577 DictionaryAttr attrs) { 578 printer.printOptionalAttrDict(attrs.getValue()); 579 } 580 581 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 582 Operation *op, 583 Value optOperand) { 584 printer << (optOperand ? "1" : "0"); 585 } 586 587 //===----------------------------------------------------------------------===// 588 // Test IsolatedRegionOp - parse passthrough region arguments. 589 //===----------------------------------------------------------------------===// 590 591 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 592 OperationState &result) { 593 OpAsmParser::OperandType argInfo; 594 Type argType = parser.getBuilder().getIndexType(); 595 596 // Parse the input operand. 597 if (parser.parseOperand(argInfo) || 598 parser.resolveOperand(argInfo, argType, result.operands)) 599 return failure(); 600 601 // Parse the body region, and reuse the operand info as the argument info. 602 Region *body = result.addRegion(); 603 return parser.parseRegion(*body, argInfo, argType, 604 /*enableNameShadowing=*/true); 605 } 606 607 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 608 p << "test.isolated_region "; 609 p.printOperand(op.getOperand()); 610 p.shadowRegionArgs(op.getRegion(), op.getOperand()); 611 p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false); 612 } 613 614 //===----------------------------------------------------------------------===// 615 // Test SSACFGRegionOp 616 //===----------------------------------------------------------------------===// 617 618 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 619 return RegionKind::SSACFG; 620 } 621 622 //===----------------------------------------------------------------------===// 623 // Test GraphRegionOp 624 //===----------------------------------------------------------------------===// 625 626 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 627 OperationState &result) { 628 // Parse the body region, and reuse the operand info as the argument info. 629 Region *body = result.addRegion(); 630 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 631 } 632 633 static void print(OpAsmPrinter &p, GraphRegionOp op) { 634 p << "test.graph_region "; 635 p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false); 636 } 637 638 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 639 return RegionKind::Graph; 640 } 641 642 //===----------------------------------------------------------------------===// 643 // Test AffineScopeOp 644 //===----------------------------------------------------------------------===// 645 646 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 647 OperationState &result) { 648 // Parse the body region, and reuse the operand info as the argument info. 649 Region *body = result.addRegion(); 650 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 651 } 652 653 static void print(OpAsmPrinter &p, AffineScopeOp op) { 654 p << "test.affine_scope "; 655 p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false); 656 } 657 658 //===----------------------------------------------------------------------===// 659 // Test parser. 660 //===----------------------------------------------------------------------===// 661 662 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, 663 OperationState &result) { 664 if (parser.parseOptionalColon()) 665 return success(); 666 uint64_t numResults; 667 if (parser.parseInteger(numResults)) 668 return failure(); 669 670 IndexType type = parser.getBuilder().getIndexType(); 671 for (unsigned i = 0; i < numResults; ++i) 672 result.addTypes(type); 673 return success(); 674 } 675 676 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { 677 if (unsigned numResults = op->getNumResults()) 678 p << " : " << numResults; 679 } 680 681 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, 682 OperationState &result) { 683 StringRef keyword; 684 if (parser.parseKeyword(&keyword)) 685 return failure(); 686 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 687 return success(); 688 } 689 690 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { 691 p << " " << op.getKeyword(); 692 } 693 694 //===----------------------------------------------------------------------===// 695 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 696 697 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 698 OperationState &result) { 699 if (parser.parseKeyword("wraps")) 700 return failure(); 701 702 // Parse the wrapped op in a region 703 Region &body = *result.addRegion(); 704 body.push_back(new Block); 705 Block &block = body.back(); 706 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 707 if (!wrapped_op) 708 return failure(); 709 710 // Create a return terminator in the inner region, pass as operand to the 711 // terminator the returned values from the wrapped operation. 712 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 713 OpBuilder builder(parser.getContext()); 714 builder.setInsertionPointToEnd(&block); 715 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 716 717 // Get the results type for the wrapping op from the terminator operands. 718 Operation &return_op = body.back().back(); 719 result.types.append(return_op.operand_type_begin(), 720 return_op.operand_type_end()); 721 722 // Use the location of the wrapped op for the "test.wrapping_region" op. 723 result.location = wrapped_op->getLoc(); 724 725 return success(); 726 } 727 728 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 729 p << " wraps "; 730 p.printGenericOp(&op.getRegion().front().front()); 731 } 732 733 //===----------------------------------------------------------------------===// 734 // Test PrettyPrintedRegionOp - exercising the following parser APIs 735 // parseGenericOperationAfterOpName 736 // parseCustomOperationName 737 //===----------------------------------------------------------------------===// 738 739 static ParseResult parsePrettyPrintedRegionOp(OpAsmParser &parser, 740 OperationState &result) { 741 742 llvm::SMLoc loc = parser.getCurrentLocation(); 743 Location currLocation = parser.getEncodedSourceLoc(loc); 744 745 // Parse the operands. 746 SmallVector<OpAsmParser::OperandType, 2> operands; 747 if (parser.parseOperandList(operands)) 748 return failure(); 749 750 // Check if we are parsing the pretty-printed version 751 // test.pretty_printed_region start <inner-op> end : <functional-type> 752 // Else fallback to parsing the "non pretty-printed" version. 753 if (!succeeded(parser.parseOptionalKeyword("start"))) 754 return parser.parseGenericOperationAfterOpName( 755 result, llvm::makeArrayRef(operands)); 756 757 FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName(); 758 if (failed(parseOpNameInfo)) 759 return failure(); 760 761 StringRef innerOpName = parseOpNameInfo->getStringRef(); 762 763 FunctionType opFntype; 764 Optional<Location> explicitLoc; 765 if (parser.parseKeyword("end") || parser.parseColon() || 766 parser.parseType(opFntype) || 767 parser.parseOptionalLocationSpecifier(explicitLoc)) 768 return failure(); 769 770 // If location of the op is explicitly provided, then use it; Else use 771 // the parser's current location. 772 Location opLoc = explicitLoc.getValueOr(currLocation); 773 774 // Derive the SSA-values for op's operands. 775 if (parser.resolveOperands(operands, opFntype.getInputs(), loc, 776 result.operands)) 777 return failure(); 778 779 // Add a region for op. 780 Region ®ion = *result.addRegion(); 781 782 // Create a basic-block inside op's region. 783 Block &block = region.emplaceBlock(); 784 785 // Create and insert an "inner-op" operation in the block. 786 // Just for testing purposes, we can assume that inner op is a binary op with 787 // result and operand types all same as the test-op's first operand. 788 Type innerOpType = opFntype.getInput(0); 789 Value lhs = block.addArgument(innerOpType, opLoc); 790 Value rhs = block.addArgument(innerOpType, opLoc); 791 792 OpBuilder builder(parser.getBuilder().getContext()); 793 builder.setInsertionPointToStart(&block); 794 795 OperationState innerOpState(opLoc, innerOpName); 796 innerOpState.operands.push_back(lhs); 797 innerOpState.operands.push_back(rhs); 798 innerOpState.addTypes(innerOpType); 799 800 Operation *innerOp = builder.createOperation(innerOpState); 801 802 // Insert a return statement in the block returning the inner-op's result. 803 builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults()); 804 805 // Populate the op operation-state with result-type and location. 806 result.addTypes(opFntype.getResults()); 807 result.location = innerOp->getLoc(); 808 809 return success(); 810 } 811 812 static void print(OpAsmPrinter &p, PrettyPrintedRegionOp op) { 813 p << ' '; 814 p.printOperands(op.getOperands()); 815 816 Operation &innerOp = op.getRegion().front().front(); 817 // Assuming that region has a single non-terminator inner-op, if the inner-op 818 // meets some criteria (which in this case is a simple one based on the name 819 // of inner-op), then we can print the entire region in a succinct way. 820 // Here we assume that the prototype of "special.op" can be trivially derived 821 // while parsing it back. 822 if (innerOp.getName().getStringRef().equals("special.op")) { 823 p << " start special.op end"; 824 } else { 825 p << " ("; 826 p.printRegion(op.getRegion()); 827 p << ")"; 828 } 829 830 p << " : "; 831 p.printFunctionalType(op); 832 } 833 834 //===----------------------------------------------------------------------===// 835 // Test PolyForOp - parse list of region arguments. 836 //===----------------------------------------------------------------------===// 837 838 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 839 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 840 // Parse list of region arguments without a delimiter. 841 if (parser.parseRegionArgumentList(ivsInfo)) 842 return failure(); 843 844 // Parse the body region. 845 Region *body = result.addRegion(); 846 auto &builder = parser.getBuilder(); 847 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 848 return parser.parseRegion(*body, ivsInfo, argTypes); 849 } 850 851 //===----------------------------------------------------------------------===// 852 // Test removing op with inner ops. 853 //===----------------------------------------------------------------------===// 854 855 namespace { 856 struct TestRemoveOpWithInnerOps 857 : public OpRewritePattern<TestOpWithRegionPattern> { 858 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 859 860 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } 861 862 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 863 PatternRewriter &rewriter) const override { 864 rewriter.eraseOp(op); 865 return success(); 866 } 867 }; 868 } // namespace 869 870 void TestOpWithRegionPattern::getCanonicalizationPatterns( 871 RewritePatternSet &results, MLIRContext *context) { 872 results.add<TestRemoveOpWithInnerOps>(context); 873 } 874 875 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 876 return getOperand(); 877 } 878 879 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 880 return getValue(); 881 } 882 883 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 884 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 885 for (Value input : this->getOperands()) { 886 results.push_back(input); 887 } 888 return success(); 889 } 890 891 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 892 assert(operands.size() == 1); 893 if (operands.front()) { 894 (*this)->setAttr("attr", operands.front()); 895 return getResult(); 896 } 897 return {}; 898 } 899 900 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 901 return getOperand(); 902 } 903 904 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 905 MLIRContext *, Optional<Location> location, ValueRange operands, 906 DictionaryAttr attributes, RegionRange regions, 907 SmallVectorImpl<Type> &inferredReturnTypes) { 908 if (operands[0].getType() != operands[1].getType()) { 909 return emitOptionalError(location, "operand type mismatch ", 910 operands[0].getType(), " vs ", 911 operands[1].getType()); 912 } 913 inferredReturnTypes.assign({operands[0].getType()}); 914 return success(); 915 } 916 917 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 918 MLIRContext *context, Optional<Location> location, ValueShapeRange operands, 919 DictionaryAttr attributes, RegionRange regions, 920 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 921 // Create return type consisting of the last element of the first operand. 922 auto operandType = operands.front().getType(); 923 auto sval = operandType.dyn_cast<ShapedType>(); 924 if (!sval) { 925 return emitOptionalError(location, "only shaped type operands allowed"); 926 } 927 int64_t dim = 928 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 929 auto type = IntegerType::get(context, 17); 930 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 931 return success(); 932 } 933 934 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 935 OpBuilder &builder, ValueRange operands, 936 llvm::SmallVectorImpl<Value> &shapes) { 937 shapes = SmallVector<Value, 1>{ 938 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; 939 return success(); 940 } 941 942 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( 943 OpBuilder &builder, ValueRange operands, 944 llvm::SmallVectorImpl<Value> &shapes) { 945 Location loc = getLoc(); 946 shapes.reserve(operands.size()); 947 for (Value operand : llvm::reverse(operands)) { 948 auto currShape = llvm::to_vector<4>(llvm::map_range( 949 llvm::seq<int64_t>( 950 0, operand.getType().cast<RankedTensorType>().getRank()), 951 [&](int64_t dim) -> Value { 952 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 953 })); 954 shapes.push_back(builder.create<tensor::FromElementsOp>( 955 getLoc(), builder.getIndexType(), currShape)); 956 } 957 return success(); 958 } 959 960 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( 961 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { 962 Location loc = getLoc(); 963 shapes.reserve(getNumOperands()); 964 for (Value operand : llvm::reverse(getOperands())) { 965 auto currShape = llvm::to_vector<4>(llvm::map_range( 966 llvm::seq<int64_t>( 967 0, operand.getType().cast<RankedTensorType>().getRank()), 968 [&](int64_t dim) -> Value { 969 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 970 })); 971 shapes.emplace_back(std::move(currShape)); 972 } 973 return success(); 974 } 975 976 //===----------------------------------------------------------------------===// 977 // Test SideEffect interfaces 978 //===----------------------------------------------------------------------===// 979 980 namespace { 981 /// A test resource for side effects. 982 struct TestResource : public SideEffects::Resource::Base<TestResource> { 983 StringRef getName() final { return "<Test>"; } 984 }; 985 } // namespace 986 987 static void testSideEffectOpGetEffect( 988 Operation *op, 989 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 990 &effects) { 991 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); 992 if (!effectsAttr) 993 return; 994 995 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 996 } 997 998 void SideEffectOp::getEffects( 999 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1000 // Check for an effects attribute on the op instance. 1001 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 1002 if (!effectsAttr) 1003 return; 1004 1005 // If there is one, it is an array of dictionary attributes that hold 1006 // information on the effects of this operation. 1007 for (Attribute element : effectsAttr) { 1008 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 1009 1010 // Get the specific memory effect. 1011 MemoryEffects::Effect *effect = 1012 StringSwitch<MemoryEffects::Effect *>( 1013 effectElement.get("effect").cast<StringAttr>().getValue()) 1014 .Case("allocate", MemoryEffects::Allocate::get()) 1015 .Case("free", MemoryEffects::Free::get()) 1016 .Case("read", MemoryEffects::Read::get()) 1017 .Case("write", MemoryEffects::Write::get()); 1018 1019 // Check for a non-default resource to use. 1020 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 1021 if (effectElement.get("test_resource")) 1022 resource = TestResource::get(); 1023 1024 // Check for a result to affect. 1025 if (effectElement.get("on_result")) 1026 effects.emplace_back(effect, getResult(), resource); 1027 else if (Attribute ref = effectElement.get("on_reference")) 1028 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 1029 else 1030 effects.emplace_back(effect, resource); 1031 } 1032 } 1033 1034 void SideEffectOp::getEffects( 1035 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 1036 testSideEffectOpGetEffect(getOperation(), effects); 1037 } 1038 1039 //===----------------------------------------------------------------------===// 1040 // StringAttrPrettyNameOp 1041 //===----------------------------------------------------------------------===// 1042 1043 // This op has fancy handling of its SSA result name. 1044 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 1045 OperationState &result) { 1046 // Add the result types. 1047 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 1048 result.addTypes(parser.getBuilder().getIntegerType(32)); 1049 1050 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1051 return failure(); 1052 1053 // If the attribute dictionary contains no 'names' attribute, infer it from 1054 // the SSA name (if specified). 1055 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 1056 return attr.getName() == "names"; 1057 }); 1058 1059 // If there was no name specified, check to see if there was a useful name 1060 // specified in the asm file. 1061 if (hadNames || parser.getNumResults() == 0) 1062 return success(); 1063 1064 SmallVector<StringRef, 4> names; 1065 auto *context = result.getContext(); 1066 1067 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 1068 auto resultName = parser.getResultName(i); 1069 StringRef nameStr; 1070 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 1071 nameStr = resultName.first; 1072 1073 names.push_back(nameStr); 1074 } 1075 1076 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 1077 result.attributes.push_back({StringAttr::get(context, "names"), namesAttr}); 1078 return success(); 1079 } 1080 1081 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 1082 // Note that we only need to print the "name" attribute if the asmprinter 1083 // result name disagrees with it. This can happen in strange cases, e.g. 1084 // when there are conflicts. 1085 bool namesDisagree = op.getNames().size() != op.getNumResults(); 1086 1087 SmallString<32> resultNameStr; 1088 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 1089 resultNameStr.clear(); 1090 llvm::raw_svector_ostream tmpStream(resultNameStr); 1091 p.printOperand(op.getResult(i), tmpStream); 1092 1093 auto expectedName = op.getNames()[i].dyn_cast<StringAttr>(); 1094 if (!expectedName || 1095 tmpStream.str().drop_front() != expectedName.getValue()) { 1096 namesDisagree = true; 1097 } 1098 } 1099 1100 if (namesDisagree) 1101 p.printOptionalAttrDictWithKeyword(op->getAttrs()); 1102 else 1103 p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"}); 1104 } 1105 1106 // We set the SSA name in the asm syntax to the contents of the name 1107 // attribute. 1108 void StringAttrPrettyNameOp::getAsmResultNames( 1109 function_ref<void(Value, StringRef)> setNameFn) { 1110 1111 auto value = getNames(); 1112 for (size_t i = 0, e = value.size(); i != e; ++i) 1113 if (auto str = value[i].dyn_cast<StringAttr>()) 1114 if (!str.getValue().empty()) 1115 setNameFn(getResult(i), str.getValue()); 1116 } 1117 1118 //===----------------------------------------------------------------------===// 1119 // RegionIfOp 1120 //===----------------------------------------------------------------------===// 1121 1122 static void print(OpAsmPrinter &p, RegionIfOp op) { 1123 p << " "; 1124 p.printOperands(op.getOperands()); 1125 p << ": " << op.getOperandTypes(); 1126 p.printArrowTypeList(op.getResultTypes()); 1127 p << " then"; 1128 p.printRegion(op.getThenRegion(), 1129 /*printEntryBlockArgs=*/true, 1130 /*printBlockTerminators=*/true); 1131 p << " else"; 1132 p.printRegion(op.getElseRegion(), 1133 /*printEntryBlockArgs=*/true, 1134 /*printBlockTerminators=*/true); 1135 p << " join"; 1136 p.printRegion(op.getJoinRegion(), 1137 /*printEntryBlockArgs=*/true, 1138 /*printBlockTerminators=*/true); 1139 } 1140 1141 static ParseResult parseRegionIfOp(OpAsmParser &parser, 1142 OperationState &result) { 1143 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 1144 SmallVector<Type, 2> operandTypes; 1145 1146 result.regions.reserve(3); 1147 Region *thenRegion = result.addRegion(); 1148 Region *elseRegion = result.addRegion(); 1149 Region *joinRegion = result.addRegion(); 1150 1151 // Parse operand, type and arrow type lists. 1152 if (parser.parseOperandList(operandInfos) || 1153 parser.parseColonTypeList(operandTypes) || 1154 parser.parseArrowTypeList(result.types)) 1155 return failure(); 1156 1157 // Parse all attached regions. 1158 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 1159 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 1160 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 1161 return failure(); 1162 1163 return parser.resolveOperands(operandInfos, operandTypes, 1164 parser.getCurrentLocation(), result.operands); 1165 } 1166 1167 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 1168 assert(index < 2 && "invalid region index"); 1169 return getOperands(); 1170 } 1171 1172 void RegionIfOp::getSuccessorRegions( 1173 Optional<unsigned> index, ArrayRef<Attribute> operands, 1174 SmallVectorImpl<RegionSuccessor> ®ions) { 1175 // We always branch to the join region. 1176 if (index.hasValue()) { 1177 if (index.getValue() < 2) 1178 regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); 1179 else 1180 regions.push_back(RegionSuccessor(getResults())); 1181 return; 1182 } 1183 1184 // The then and else regions are the entry regions of this op. 1185 regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs())); 1186 regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs())); 1187 } 1188 1189 //===----------------------------------------------------------------------===// 1190 // SingleNoTerminatorCustomAsmOp 1191 //===----------------------------------------------------------------------===// 1192 1193 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser, 1194 OperationState &state) { 1195 Region *body = state.addRegion(); 1196 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) 1197 return failure(); 1198 return success(); 1199 } 1200 1201 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) { 1202 printer.printRegion( 1203 op.getRegion(), /*printEntryBlockArgs=*/false, 1204 // This op has a single block without terminators. But explicitly mark 1205 // as not printing block terminators for testing. 1206 /*printBlockTerminators=*/false); 1207 } 1208 1209 #include "TestOpEnums.cpp.inc" 1210 #include "TestOpInterfaces.cpp.inc" 1211 #include "TestOpStructs.cpp.inc" 1212 #include "TestTypeInterfaces.cpp.inc" 1213 1214 #define GET_OP_CLASSES 1215 #include "TestOps.cpp.inc" 1216