1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestInterfaces.h" 12 #include "TestTypes.h" 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/DLTI/DLTI.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/DialectImplementation.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/IR/TypeUtilities.h" 21 #include "mlir/Reducer/ReductionPatternInterface.h" 22 #include "mlir/Transforms/FoldUtils.h" 23 #include "mlir/Transforms/InliningUtils.h" 24 #include "llvm/ADT/StringSwitch.h" 25 26 // Include this before the using namespace lines below to 27 // test that we don't have namespace dependencies. 28 #include "TestOpsDialect.cpp.inc" 29 30 using namespace mlir; 31 using namespace test; 32 33 void test::registerTestDialect(DialectRegistry ®istry) { 34 registry.insert<TestDialect>(); 35 } 36 37 //===----------------------------------------------------------------------===// 38 // TestDialect Interfaces 39 //===----------------------------------------------------------------------===// 40 41 namespace { 42 43 /// Testing the correctness of some traits. 44 static_assert( 45 llvm::is_detected<OpTrait::has_implicit_terminator_t, 46 SingleBlockImplicitTerminatorOp>::value, 47 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); 48 static_assert(OpTrait::hasSingleBlockImplicitTerminator< 49 SingleBlockImplicitTerminatorOp>::value, 50 "hasSingleBlockImplicitTerminator does not match " 51 "SingleBlockImplicitTerminatorOp"); 52 53 // Test support for interacting with the AsmPrinter. 54 struct TestOpAsmInterface : public OpAsmDialectInterface { 55 using OpAsmDialectInterface::OpAsmDialectInterface; 56 57 AliasResult getAlias(Attribute attr, raw_ostream &os) const final { 58 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 59 if (!strAttr) 60 return AliasResult::NoAlias; 61 62 // Check the contents of the string attribute to see what the test alias 63 // should be named. 64 Optional<StringRef> aliasName = 65 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 66 .Case("alias_test:dot_in_name", StringRef("test.alias")) 67 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 68 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 69 .Case("alias_test:sanitize_conflict_a", 70 StringRef("test_alias_conflict0")) 71 .Case("alias_test:sanitize_conflict_b", 72 StringRef("test_alias_conflict0_")) 73 .Case("alias_test:tensor_encoding", StringRef("test_encoding")) 74 .Default(llvm::None); 75 if (!aliasName) 76 return AliasResult::NoAlias; 77 78 os << *aliasName; 79 return AliasResult::FinalAlias; 80 } 81 82 AliasResult getAlias(Type type, raw_ostream &os) const final { 83 if (auto tupleType = type.dyn_cast<TupleType>()) { 84 if (tupleType.size() > 0 && 85 llvm::all_of(tupleType.getTypes(), [](Type elemType) { 86 return elemType.isa<SimpleAType>(); 87 })) { 88 os << "test_tuple"; 89 return AliasResult::FinalAlias; 90 } 91 } 92 if (auto intType = type.dyn_cast<TestIntegerType>()) { 93 if (intType.getSignedness() == 94 TestIntegerType::SignednessSemantics::Unsigned && 95 intType.getWidth() == 8) { 96 os << "test_ui8"; 97 return AliasResult::FinalAlias; 98 } 99 } 100 return AliasResult::NoAlias; 101 } 102 103 void getAsmResultNames(Operation *op, 104 OpAsmSetValueNameFn setNameFn) const final { 105 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 106 setNameFn(asmOp, "result"); 107 } 108 }; 109 110 struct TestDialectFoldInterface : public DialectFoldInterface { 111 using DialectFoldInterface::DialectFoldInterface; 112 113 /// Registered hook to check if the given region, which is attached to an 114 /// operation that is *not* isolated from above, should be used when 115 /// materializing constants. 116 bool shouldMaterializeInto(Region *region) const final { 117 // If this is a one region operation, then insert into it. 118 return isa<OneRegionOp>(region->getParentOp()); 119 } 120 }; 121 122 /// This class defines the interface for handling inlining with standard 123 /// operations. 124 struct TestInlinerInterface : public DialectInlinerInterface { 125 using DialectInlinerInterface::DialectInlinerInterface; 126 127 //===--------------------------------------------------------------------===// 128 // Analysis Hooks 129 //===--------------------------------------------------------------------===// 130 131 bool isLegalToInline(Operation *call, Operation *callable, 132 bool wouldBeCloned) const final { 133 // Don't allow inlining calls that are marked `noinline`. 134 return !call->hasAttr("noinline"); 135 } 136 bool isLegalToInline(Region *, Region *, bool, 137 BlockAndValueMapping &) const final { 138 // Inlining into test dialect regions is legal. 139 return true; 140 } 141 bool isLegalToInline(Operation *, Region *, bool, 142 BlockAndValueMapping &) const final { 143 return true; 144 } 145 146 bool shouldAnalyzeRecursively(Operation *op) const final { 147 // Analyze recursively if this is not a functional region operation, it 148 // froms a separate functional scope. 149 return !isa<FunctionalRegionOp>(op); 150 } 151 152 //===--------------------------------------------------------------------===// 153 // Transformation Hooks 154 //===--------------------------------------------------------------------===// 155 156 /// Handle the given inlined terminator by replacing it with a new operation 157 /// as necessary. 158 void handleTerminator(Operation *op, 159 ArrayRef<Value> valuesToRepl) const final { 160 // Only handle "test.return" here. 161 auto returnOp = dyn_cast<TestReturnOp>(op); 162 if (!returnOp) 163 return; 164 165 // Replace the values directly with the return operands. 166 assert(returnOp.getNumOperands() == valuesToRepl.size()); 167 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 168 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 169 } 170 171 /// Attempt to materialize a conversion for a type mismatch between a call 172 /// from this dialect, and a callable region. This method should generate an 173 /// operation that takes 'input' as the only operand, and produces a single 174 /// result of 'resultType'. If a conversion can not be generated, nullptr 175 /// should be returned. 176 Operation *materializeCallConversion(OpBuilder &builder, Value input, 177 Type resultType, 178 Location conversionLoc) const final { 179 // Only allow conversion for i16/i32 types. 180 if (!(resultType.isSignlessInteger(16) || 181 resultType.isSignlessInteger(32)) || 182 !(input.getType().isSignlessInteger(16) || 183 input.getType().isSignlessInteger(32))) 184 return nullptr; 185 return builder.create<TestCastOp>(conversionLoc, resultType, input); 186 } 187 188 void processInlinedCallBlocks( 189 Operation *call, 190 iterator_range<Region::iterator> inlinedBlocks) const final { 191 if (!isa<ConversionCallOp>(call)) 192 return; 193 194 // Set attributed on all ops in the inlined blocks. 195 for (Block &block : inlinedBlocks) { 196 block.walk([&](Operation *op) { 197 op->setAttr("inlined_conversion", UnitAttr::get(call->getContext())); 198 }); 199 } 200 } 201 }; 202 203 struct TestReductionPatternInterface : public DialectReductionPatternInterface { 204 public: 205 TestReductionPatternInterface(Dialect *dialect) 206 : DialectReductionPatternInterface(dialect) {} 207 208 void populateReductionPatterns(RewritePatternSet &patterns) const final { 209 populateTestReductionPatterns(patterns); 210 } 211 }; 212 213 } // namespace 214 215 //===----------------------------------------------------------------------===// 216 // TestDialect 217 //===----------------------------------------------------------------------===// 218 219 static void testSideEffectOpGetEffect( 220 Operation *op, 221 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects); 222 223 // This is the implementation of a dialect fallback for `TestEffectOpInterface`. 224 struct TestOpEffectInterfaceFallback 225 : public TestEffectOpInterface::FallbackModel< 226 TestOpEffectInterfaceFallback> { 227 static bool classof(Operation *op) { 228 bool isSupportedOp = 229 op->getName().getStringRef() == "test.unregistered_side_effect_op"; 230 assert(isSupportedOp && "Unexpected dispatch"); 231 return isSupportedOp; 232 } 233 234 void 235 getEffects(Operation *op, 236 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 237 &effects) const { 238 testSideEffectOpGetEffect(op, effects); 239 } 240 }; 241 242 void TestDialect::initialize() { 243 registerAttributes(); 244 registerTypes(); 245 addOperations< 246 #define GET_OP_LIST 247 #include "TestOps.cpp.inc" 248 >(); 249 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 250 TestInlinerInterface, TestReductionPatternInterface>(); 251 allowUnknownOperations(); 252 253 // Instantiate our fallback op interface that we'll use on specific 254 // unregistered op. 255 fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback; 256 } 257 TestDialect::~TestDialect() { 258 delete static_cast<TestOpEffectInterfaceFallback *>( 259 fallbackEffectOpInterfaces); 260 } 261 262 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 263 Type type, Location loc) { 264 return builder.create<TestOpConstant>(loc, type, value); 265 } 266 267 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, 268 OperationName opName) { 269 if (opName.getIdentifier() == "test.unregistered_side_effect_op" && 270 typeID == TypeID::get<TestEffectOpInterface>()) 271 return fallbackEffectOpInterfaces; 272 return nullptr; 273 } 274 275 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 276 NamedAttribute namedAttr) { 277 if (namedAttr.getName() == "test.invalid_attr") 278 return op->emitError() << "invalid to use 'test.invalid_attr'"; 279 return success(); 280 } 281 282 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 283 unsigned regionIndex, 284 unsigned argIndex, 285 NamedAttribute namedAttr) { 286 if (namedAttr.getName() == "test.invalid_attr") 287 return op->emitError() << "invalid to use 'test.invalid_attr'"; 288 return success(); 289 } 290 291 LogicalResult 292 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 293 unsigned resultIndex, 294 NamedAttribute namedAttr) { 295 if (namedAttr.getName() == "test.invalid_attr") 296 return op->emitError() << "invalid to use 'test.invalid_attr'"; 297 return success(); 298 } 299 300 Optional<Dialect::ParseOpHook> 301 TestDialect::getParseOperationHook(StringRef opName) const { 302 if (opName == "test.dialect_custom_printer") { 303 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 304 return parser.parseKeyword("custom_format"); 305 }}; 306 } 307 if (opName == "test.dialect_custom_format_fallback") { 308 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 309 return parser.parseKeyword("custom_format_fallback"); 310 }}; 311 } 312 return None; 313 } 314 315 llvm::unique_function<void(Operation *, OpAsmPrinter &)> 316 TestDialect::getOperationPrinter(Operation *op) const { 317 StringRef opName = op->getName().getStringRef(); 318 if (opName == "test.dialect_custom_printer") { 319 return [](Operation *op, OpAsmPrinter &printer) { 320 printer.getStream() << " custom_format"; 321 }; 322 } 323 if (opName == "test.dialect_custom_format_fallback") { 324 return [](Operation *op, OpAsmPrinter &printer) { 325 printer.getStream() << " custom_format_fallback"; 326 }; 327 } 328 return {}; 329 } 330 331 //===----------------------------------------------------------------------===// 332 // TestBranchOp 333 //===----------------------------------------------------------------------===// 334 335 Optional<MutableOperandRange> 336 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 337 assert(index == 0 && "invalid successor index"); 338 return getTargetOperandsMutable(); 339 } 340 341 //===----------------------------------------------------------------------===// 342 // TestDialectCanonicalizerOp 343 //===----------------------------------------------------------------------===// 344 345 static LogicalResult 346 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, 347 PatternRewriter &rewriter) { 348 rewriter.replaceOpWithNewOp<arith::ConstantOp>( 349 op, rewriter.getI32IntegerAttr(42)); 350 return success(); 351 } 352 353 void TestDialect::getCanonicalizationPatterns( 354 RewritePatternSet &results) const { 355 results.add(&dialectCanonicalizationPattern); 356 } 357 358 //===----------------------------------------------------------------------===// 359 // TestFoldToCallOp 360 //===----------------------------------------------------------------------===// 361 362 namespace { 363 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 364 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 365 366 LogicalResult matchAndRewrite(FoldToCallOp op, 367 PatternRewriter &rewriter) const override { 368 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.getCalleeAttr(), 369 ValueRange()); 370 return success(); 371 } 372 }; 373 } // namespace 374 375 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 376 MLIRContext *context) { 377 results.add<FoldToCallOpPattern>(context); 378 } 379 380 //===----------------------------------------------------------------------===// 381 // Test Format* operations 382 //===----------------------------------------------------------------------===// 383 384 //===----------------------------------------------------------------------===// 385 // Parsing 386 387 static ParseResult parseCustomDirectiveOperands( 388 OpAsmParser &parser, OpAsmParser::OperandType &operand, 389 Optional<OpAsmParser::OperandType> &optOperand, 390 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 391 if (parser.parseOperand(operand)) 392 return failure(); 393 if (succeeded(parser.parseOptionalComma())) { 394 optOperand.emplace(); 395 if (parser.parseOperand(*optOperand)) 396 return failure(); 397 } 398 if (parser.parseArrow() || parser.parseLParen() || 399 parser.parseOperandList(varOperands) || parser.parseRParen()) 400 return failure(); 401 return success(); 402 } 403 static ParseResult 404 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 405 Type &optOperandType, 406 SmallVectorImpl<Type> &varOperandTypes) { 407 if (parser.parseColon()) 408 return failure(); 409 410 if (parser.parseType(operandType)) 411 return failure(); 412 if (succeeded(parser.parseOptionalComma())) { 413 if (parser.parseType(optOperandType)) 414 return failure(); 415 } 416 if (parser.parseArrow() || parser.parseLParen() || 417 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 418 return failure(); 419 return success(); 420 } 421 static ParseResult 422 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 423 Type optOperandType, 424 const SmallVectorImpl<Type> &varOperandTypes) { 425 if (parser.parseKeyword("type_refs_capture")) 426 return failure(); 427 428 Type operandType2, optOperandType2; 429 SmallVector<Type, 1> varOperandTypes2; 430 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 431 varOperandTypes2)) 432 return failure(); 433 434 if (operandType != operandType2 || optOperandType != optOperandType2 || 435 varOperandTypes != varOperandTypes2) 436 return failure(); 437 438 return success(); 439 } 440 static ParseResult parseCustomDirectiveOperandsAndTypes( 441 OpAsmParser &parser, OpAsmParser::OperandType &operand, 442 Optional<OpAsmParser::OperandType> &optOperand, 443 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 444 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 445 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 446 parseCustomDirectiveResults(parser, operandType, optOperandType, 447 varOperandTypes)) 448 return failure(); 449 return success(); 450 } 451 static ParseResult parseCustomDirectiveRegions( 452 OpAsmParser &parser, Region ®ion, 453 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 454 if (parser.parseRegion(region)) 455 return failure(); 456 if (failed(parser.parseOptionalComma())) 457 return success(); 458 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 459 if (parser.parseRegion(*varRegion)) 460 return failure(); 461 varRegions.emplace_back(std::move(varRegion)); 462 return success(); 463 } 464 static ParseResult 465 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 466 SmallVectorImpl<Block *> &varSuccessors) { 467 if (parser.parseSuccessor(successor)) 468 return failure(); 469 if (failed(parser.parseOptionalComma())) 470 return success(); 471 Block *varSuccessor; 472 if (parser.parseSuccessor(varSuccessor)) 473 return failure(); 474 varSuccessors.append(2, varSuccessor); 475 return success(); 476 } 477 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 478 IntegerAttr &attr, 479 IntegerAttr &optAttr) { 480 if (parser.parseAttribute(attr)) 481 return failure(); 482 if (succeeded(parser.parseOptionalComma())) { 483 if (parser.parseAttribute(optAttr)) 484 return failure(); 485 } 486 return success(); 487 } 488 489 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 490 NamedAttrList &attrs) { 491 return parser.parseOptionalAttrDict(attrs); 492 } 493 static ParseResult parseCustomDirectiveOptionalOperandRef( 494 OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) { 495 int64_t operandCount = 0; 496 if (parser.parseInteger(operandCount)) 497 return failure(); 498 bool expectedOptionalOperand = operandCount == 0; 499 return success(expectedOptionalOperand != optOperand.hasValue()); 500 } 501 502 //===----------------------------------------------------------------------===// 503 // Printing 504 505 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 506 Value operand, Value optOperand, 507 OperandRange varOperands) { 508 printer << operand; 509 if (optOperand) 510 printer << ", " << optOperand; 511 printer << " -> (" << varOperands << ")"; 512 } 513 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 514 Type operandType, Type optOperandType, 515 TypeRange varOperandTypes) { 516 printer << " : " << operandType; 517 if (optOperandType) 518 printer << ", " << optOperandType; 519 printer << " -> (" << varOperandTypes << ")"; 520 } 521 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 522 Operation *op, Type operandType, 523 Type optOperandType, 524 TypeRange varOperandTypes) { 525 printer << " type_refs_capture "; 526 printCustomDirectiveResults(printer, op, operandType, optOperandType, 527 varOperandTypes); 528 } 529 static void printCustomDirectiveOperandsAndTypes( 530 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 531 OperandRange varOperands, Type operandType, Type optOperandType, 532 TypeRange varOperandTypes) { 533 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 534 printCustomDirectiveResults(printer, op, operandType, optOperandType, 535 varOperandTypes); 536 } 537 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 538 Region ®ion, 539 MutableArrayRef<Region> varRegions) { 540 printer.printRegion(region); 541 if (!varRegions.empty()) { 542 printer << ", "; 543 for (Region ®ion : varRegions) 544 printer.printRegion(region); 545 } 546 } 547 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 548 Block *successor, 549 SuccessorRange varSuccessors) { 550 printer << successor; 551 if (!varSuccessors.empty()) 552 printer << ", " << varSuccessors.front(); 553 } 554 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 555 Attribute attribute, 556 Attribute optAttribute) { 557 printer << attribute; 558 if (optAttribute) 559 printer << ", " << optAttribute; 560 } 561 562 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 563 DictionaryAttr attrs) { 564 printer.printOptionalAttrDict(attrs.getValue()); 565 } 566 567 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 568 Operation *op, 569 Value optOperand) { 570 printer << (optOperand ? "1" : "0"); 571 } 572 573 //===----------------------------------------------------------------------===// 574 // Test IsolatedRegionOp - parse passthrough region arguments. 575 //===----------------------------------------------------------------------===// 576 577 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 578 OperationState &result) { 579 OpAsmParser::OperandType argInfo; 580 Type argType = parser.getBuilder().getIndexType(); 581 582 // Parse the input operand. 583 if (parser.parseOperand(argInfo) || 584 parser.resolveOperand(argInfo, argType, result.operands)) 585 return failure(); 586 587 // Parse the body region, and reuse the operand info as the argument info. 588 Region *body = result.addRegion(); 589 return parser.parseRegion(*body, argInfo, argType, 590 /*enableNameShadowing=*/true); 591 } 592 593 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 594 p << "test.isolated_region "; 595 p.printOperand(op.getOperand()); 596 p.shadowRegionArgs(op.getRegion(), op.getOperand()); 597 p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false); 598 } 599 600 //===----------------------------------------------------------------------===// 601 // Test SSACFGRegionOp 602 //===----------------------------------------------------------------------===// 603 604 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 605 return RegionKind::SSACFG; 606 } 607 608 //===----------------------------------------------------------------------===// 609 // Test GraphRegionOp 610 //===----------------------------------------------------------------------===// 611 612 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 613 OperationState &result) { 614 // Parse the body region, and reuse the operand info as the argument info. 615 Region *body = result.addRegion(); 616 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 617 } 618 619 static void print(OpAsmPrinter &p, GraphRegionOp op) { 620 p << "test.graph_region "; 621 p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false); 622 } 623 624 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 625 return RegionKind::Graph; 626 } 627 628 //===----------------------------------------------------------------------===// 629 // Test AffineScopeOp 630 //===----------------------------------------------------------------------===// 631 632 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 633 OperationState &result) { 634 // Parse the body region, and reuse the operand info as the argument info. 635 Region *body = result.addRegion(); 636 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 637 } 638 639 static void print(OpAsmPrinter &p, AffineScopeOp op) { 640 p << "test.affine_scope "; 641 p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false); 642 } 643 644 //===----------------------------------------------------------------------===// 645 // Test parser. 646 //===----------------------------------------------------------------------===// 647 648 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, 649 OperationState &result) { 650 if (parser.parseOptionalColon()) 651 return success(); 652 uint64_t numResults; 653 if (parser.parseInteger(numResults)) 654 return failure(); 655 656 IndexType type = parser.getBuilder().getIndexType(); 657 for (unsigned i = 0; i < numResults; ++i) 658 result.addTypes(type); 659 return success(); 660 } 661 662 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { 663 if (unsigned numResults = op->getNumResults()) 664 p << " : " << numResults; 665 } 666 667 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, 668 OperationState &result) { 669 StringRef keyword; 670 if (parser.parseKeyword(&keyword)) 671 return failure(); 672 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 673 return success(); 674 } 675 676 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { 677 p << " " << op.getKeyword(); 678 } 679 680 //===----------------------------------------------------------------------===// 681 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 682 683 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 684 OperationState &result) { 685 if (parser.parseKeyword("wraps")) 686 return failure(); 687 688 // Parse the wrapped op in a region 689 Region &body = *result.addRegion(); 690 body.push_back(new Block); 691 Block &block = body.back(); 692 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin()); 693 if (!wrappedOp) 694 return failure(); 695 696 // Create a return terminator in the inner region, pass as operand to the 697 // terminator the returned values from the wrapped operation. 698 SmallVector<Value, 8> returnOperands(wrappedOp->getResults()); 699 OpBuilder builder(parser.getContext()); 700 builder.setInsertionPointToEnd(&block); 701 builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands); 702 703 // Get the results type for the wrapping op from the terminator operands. 704 Operation &returnOp = body.back().back(); 705 result.types.append(returnOp.operand_type_begin(), 706 returnOp.operand_type_end()); 707 708 // Use the location of the wrapped op for the "test.wrapping_region" op. 709 result.location = wrappedOp->getLoc(); 710 711 return success(); 712 } 713 714 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 715 p << " wraps "; 716 p.printGenericOp(&op.getRegion().front().front()); 717 } 718 719 //===----------------------------------------------------------------------===// 720 // Test PrettyPrintedRegionOp - exercising the following parser APIs 721 // parseGenericOperationAfterOpName 722 // parseCustomOperationName 723 //===----------------------------------------------------------------------===// 724 725 static ParseResult parsePrettyPrintedRegionOp(OpAsmParser &parser, 726 OperationState &result) { 727 728 llvm::SMLoc loc = parser.getCurrentLocation(); 729 Location currLocation = parser.getEncodedSourceLoc(loc); 730 731 // Parse the operands. 732 SmallVector<OpAsmParser::OperandType, 2> operands; 733 if (parser.parseOperandList(operands)) 734 return failure(); 735 736 // Check if we are parsing the pretty-printed version 737 // test.pretty_printed_region start <inner-op> end : <functional-type> 738 // Else fallback to parsing the "non pretty-printed" version. 739 if (!succeeded(parser.parseOptionalKeyword("start"))) 740 return parser.parseGenericOperationAfterOpName( 741 result, llvm::makeArrayRef(operands)); 742 743 FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName(); 744 if (failed(parseOpNameInfo)) 745 return failure(); 746 747 StringRef innerOpName = parseOpNameInfo->getStringRef(); 748 749 FunctionType opFntype; 750 Optional<Location> explicitLoc; 751 if (parser.parseKeyword("end") || parser.parseColon() || 752 parser.parseType(opFntype) || 753 parser.parseOptionalLocationSpecifier(explicitLoc)) 754 return failure(); 755 756 // If location of the op is explicitly provided, then use it; Else use 757 // the parser's current location. 758 Location opLoc = explicitLoc.getValueOr(currLocation); 759 760 // Derive the SSA-values for op's operands. 761 if (parser.resolveOperands(operands, opFntype.getInputs(), loc, 762 result.operands)) 763 return failure(); 764 765 // Add a region for op. 766 Region ®ion = *result.addRegion(); 767 768 // Create a basic-block inside op's region. 769 Block &block = region.emplaceBlock(); 770 771 // Create and insert an "inner-op" operation in the block. 772 // Just for testing purposes, we can assume that inner op is a binary op with 773 // result and operand types all same as the test-op's first operand. 774 Type innerOpType = opFntype.getInput(0); 775 Value lhs = block.addArgument(innerOpType, opLoc); 776 Value rhs = block.addArgument(innerOpType, opLoc); 777 778 OpBuilder builder(parser.getBuilder().getContext()); 779 builder.setInsertionPointToStart(&block); 780 781 OperationState innerOpState(opLoc, innerOpName); 782 innerOpState.operands.push_back(lhs); 783 innerOpState.operands.push_back(rhs); 784 innerOpState.addTypes(innerOpType); 785 786 Operation *innerOp = builder.createOperation(innerOpState); 787 788 // Insert a return statement in the block returning the inner-op's result. 789 builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults()); 790 791 // Populate the op operation-state with result-type and location. 792 result.addTypes(opFntype.getResults()); 793 result.location = innerOp->getLoc(); 794 795 return success(); 796 } 797 798 static void print(OpAsmPrinter &p, PrettyPrintedRegionOp op) { 799 p << ' '; 800 p.printOperands(op.getOperands()); 801 802 Operation &innerOp = op.getRegion().front().front(); 803 // Assuming that region has a single non-terminator inner-op, if the inner-op 804 // meets some criteria (which in this case is a simple one based on the name 805 // of inner-op), then we can print the entire region in a succinct way. 806 // Here we assume that the prototype of "special.op" can be trivially derived 807 // while parsing it back. 808 if (innerOp.getName().getStringRef().equals("special.op")) { 809 p << " start special.op end"; 810 } else { 811 p << " ("; 812 p.printRegion(op.getRegion()); 813 p << ")"; 814 } 815 816 p << " : "; 817 p.printFunctionalType(op); 818 } 819 820 //===----------------------------------------------------------------------===// 821 // Test PolyForOp - parse list of region arguments. 822 //===----------------------------------------------------------------------===// 823 824 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 825 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 826 // Parse list of region arguments without a delimiter. 827 if (parser.parseRegionArgumentList(ivsInfo)) 828 return failure(); 829 830 // Parse the body region. 831 Region *body = result.addRegion(); 832 auto &builder = parser.getBuilder(); 833 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 834 return parser.parseRegion(*body, ivsInfo, argTypes); 835 } 836 837 void PolyForOp::getAsmBlockArgumentNames(Region ®ion, 838 OpAsmSetValueNameFn setNameFn) { 839 auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names"); 840 if (!arrayAttr) 841 return; 842 auto args = getRegion().front().getArguments(); 843 auto e = std::min(arrayAttr.size(), args.size()); 844 for (unsigned i = 0; i < e; ++i) { 845 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 846 setNameFn(args[i], strAttr.getValue()); 847 } 848 } 849 850 //===----------------------------------------------------------------------===// 851 // Test removing op with inner ops. 852 //===----------------------------------------------------------------------===// 853 854 namespace { 855 struct TestRemoveOpWithInnerOps 856 : public OpRewritePattern<TestOpWithRegionPattern> { 857 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 858 859 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } 860 861 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 862 PatternRewriter &rewriter) const override { 863 rewriter.eraseOp(op); 864 return success(); 865 } 866 }; 867 } // namespace 868 869 void TestOpWithRegionPattern::getCanonicalizationPatterns( 870 RewritePatternSet &results, MLIRContext *context) { 871 results.add<TestRemoveOpWithInnerOps>(context); 872 } 873 874 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 875 return getOperand(); 876 } 877 878 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 879 return getValue(); 880 } 881 882 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 883 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 884 for (Value input : this->getOperands()) { 885 results.push_back(input); 886 } 887 return success(); 888 } 889 890 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 891 assert(operands.size() == 1); 892 if (operands.front()) { 893 (*this)->setAttr("attr", operands.front()); 894 return getResult(); 895 } 896 return {}; 897 } 898 899 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 900 return getOperand(); 901 } 902 903 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 904 MLIRContext *, Optional<Location> location, ValueRange operands, 905 DictionaryAttr attributes, RegionRange regions, 906 SmallVectorImpl<Type> &inferredReturnTypes) { 907 if (operands[0].getType() != operands[1].getType()) { 908 return emitOptionalError(location, "operand type mismatch ", 909 operands[0].getType(), " vs ", 910 operands[1].getType()); 911 } 912 inferredReturnTypes.assign({operands[0].getType()}); 913 return success(); 914 } 915 916 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 917 MLIRContext *context, Optional<Location> location, ValueShapeRange operands, 918 DictionaryAttr attributes, RegionRange regions, 919 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 920 // Create return type consisting of the last element of the first operand. 921 auto operandType = operands.front().getType(); 922 auto sval = operandType.dyn_cast<ShapedType>(); 923 if (!sval) { 924 return emitOptionalError(location, "only shaped type operands allowed"); 925 } 926 int64_t dim = 927 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 928 auto type = IntegerType::get(context, 17); 929 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 930 return success(); 931 } 932 933 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 934 OpBuilder &builder, ValueRange operands, 935 llvm::SmallVectorImpl<Value> &shapes) { 936 shapes = SmallVector<Value, 1>{ 937 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; 938 return success(); 939 } 940 941 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( 942 OpBuilder &builder, ValueRange operands, 943 llvm::SmallVectorImpl<Value> &shapes) { 944 Location loc = getLoc(); 945 shapes.reserve(operands.size()); 946 for (Value operand : llvm::reverse(operands)) { 947 auto rank = operand.getType().cast<RankedTensorType>().getRank(); 948 auto currShape = llvm::to_vector<4>( 949 llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value { 950 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 951 })); 952 shapes.push_back(builder.create<tensor::FromElementsOp>( 953 getLoc(), RankedTensorType::get({rank}, builder.getIndexType()), 954 currShape)); 955 } 956 return success(); 957 } 958 959 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( 960 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { 961 Location loc = getLoc(); 962 shapes.reserve(getNumOperands()); 963 for (Value operand : llvm::reverse(getOperands())) { 964 auto currShape = llvm::to_vector<4>(llvm::map_range( 965 llvm::seq<int64_t>( 966 0, operand.getType().cast<RankedTensorType>().getRank()), 967 [&](int64_t dim) -> Value { 968 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 969 })); 970 shapes.emplace_back(std::move(currShape)); 971 } 972 return success(); 973 } 974 975 //===----------------------------------------------------------------------===// 976 // Test SideEffect interfaces 977 //===----------------------------------------------------------------------===// 978 979 namespace { 980 /// A test resource for side effects. 981 struct TestResource : public SideEffects::Resource::Base<TestResource> { 982 StringRef getName() final { return "<Test>"; } 983 }; 984 } // namespace 985 986 static void testSideEffectOpGetEffect( 987 Operation *op, 988 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 989 &effects) { 990 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); 991 if (!effectsAttr) 992 return; 993 994 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 995 } 996 997 void SideEffectOp::getEffects( 998 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 999 // Check for an effects attribute on the op instance. 1000 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 1001 if (!effectsAttr) 1002 return; 1003 1004 // If there is one, it is an array of dictionary attributes that hold 1005 // information on the effects of this operation. 1006 for (Attribute element : effectsAttr) { 1007 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 1008 1009 // Get the specific memory effect. 1010 MemoryEffects::Effect *effect = 1011 StringSwitch<MemoryEffects::Effect *>( 1012 effectElement.get("effect").cast<StringAttr>().getValue()) 1013 .Case("allocate", MemoryEffects::Allocate::get()) 1014 .Case("free", MemoryEffects::Free::get()) 1015 .Case("read", MemoryEffects::Read::get()) 1016 .Case("write", MemoryEffects::Write::get()); 1017 1018 // Check for a non-default resource to use. 1019 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 1020 if (effectElement.get("test_resource")) 1021 resource = TestResource::get(); 1022 1023 // Check for a result to affect. 1024 if (effectElement.get("on_result")) 1025 effects.emplace_back(effect, getResult(), resource); 1026 else if (Attribute ref = effectElement.get("on_reference")) 1027 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 1028 else 1029 effects.emplace_back(effect, resource); 1030 } 1031 } 1032 1033 void SideEffectOp::getEffects( 1034 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 1035 testSideEffectOpGetEffect(getOperation(), effects); 1036 } 1037 1038 //===----------------------------------------------------------------------===// 1039 // StringAttrPrettyNameOp 1040 //===----------------------------------------------------------------------===// 1041 1042 // This op has fancy handling of its SSA result name. 1043 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 1044 OperationState &result) { 1045 // Add the result types. 1046 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 1047 result.addTypes(parser.getBuilder().getIntegerType(32)); 1048 1049 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1050 return failure(); 1051 1052 // If the attribute dictionary contains no 'names' attribute, infer it from 1053 // the SSA name (if specified). 1054 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 1055 return attr.getName() == "names"; 1056 }); 1057 1058 // If there was no name specified, check to see if there was a useful name 1059 // specified in the asm file. 1060 if (hadNames || parser.getNumResults() == 0) 1061 return success(); 1062 1063 SmallVector<StringRef, 4> names; 1064 auto *context = result.getContext(); 1065 1066 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 1067 auto resultName = parser.getResultName(i); 1068 StringRef nameStr; 1069 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 1070 nameStr = resultName.first; 1071 1072 names.push_back(nameStr); 1073 } 1074 1075 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 1076 result.attributes.push_back({StringAttr::get(context, "names"), namesAttr}); 1077 return success(); 1078 } 1079 1080 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 1081 // Note that we only need to print the "name" attribute if the asmprinter 1082 // result name disagrees with it. This can happen in strange cases, e.g. 1083 // when there are conflicts. 1084 bool namesDisagree = op.getNames().size() != op.getNumResults(); 1085 1086 SmallString<32> resultNameStr; 1087 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 1088 resultNameStr.clear(); 1089 llvm::raw_svector_ostream tmpStream(resultNameStr); 1090 p.printOperand(op.getResult(i), tmpStream); 1091 1092 auto expectedName = op.getNames()[i].dyn_cast<StringAttr>(); 1093 if (!expectedName || 1094 tmpStream.str().drop_front() != expectedName.getValue()) { 1095 namesDisagree = true; 1096 } 1097 } 1098 1099 if (namesDisagree) 1100 p.printOptionalAttrDictWithKeyword(op->getAttrs()); 1101 else 1102 p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"}); 1103 } 1104 1105 // We set the SSA name in the asm syntax to the contents of the name 1106 // attribute. 1107 void StringAttrPrettyNameOp::getAsmResultNames( 1108 function_ref<void(Value, StringRef)> setNameFn) { 1109 1110 auto value = getNames(); 1111 for (size_t i = 0, e = value.size(); i != e; ++i) 1112 if (auto str = value[i].dyn_cast<StringAttr>()) 1113 if (!str.getValue().empty()) 1114 setNameFn(getResult(i), str.getValue()); 1115 } 1116 1117 //===----------------------------------------------------------------------===// 1118 // RegionIfOp 1119 //===----------------------------------------------------------------------===// 1120 1121 static void print(OpAsmPrinter &p, RegionIfOp op) { 1122 p << " "; 1123 p.printOperands(op.getOperands()); 1124 p << ": " << op.getOperandTypes(); 1125 p.printArrowTypeList(op.getResultTypes()); 1126 p << " then"; 1127 p.printRegion(op.getThenRegion(), 1128 /*printEntryBlockArgs=*/true, 1129 /*printBlockTerminators=*/true); 1130 p << " else"; 1131 p.printRegion(op.getElseRegion(), 1132 /*printEntryBlockArgs=*/true, 1133 /*printBlockTerminators=*/true); 1134 p << " join"; 1135 p.printRegion(op.getJoinRegion(), 1136 /*printEntryBlockArgs=*/true, 1137 /*printBlockTerminators=*/true); 1138 } 1139 1140 static ParseResult parseRegionIfOp(OpAsmParser &parser, 1141 OperationState &result) { 1142 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 1143 SmallVector<Type, 2> operandTypes; 1144 1145 result.regions.reserve(3); 1146 Region *thenRegion = result.addRegion(); 1147 Region *elseRegion = result.addRegion(); 1148 Region *joinRegion = result.addRegion(); 1149 1150 // Parse operand, type and arrow type lists. 1151 if (parser.parseOperandList(operandInfos) || 1152 parser.parseColonTypeList(operandTypes) || 1153 parser.parseArrowTypeList(result.types)) 1154 return failure(); 1155 1156 // Parse all attached regions. 1157 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 1158 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 1159 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 1160 return failure(); 1161 1162 return parser.resolveOperands(operandInfos, operandTypes, 1163 parser.getCurrentLocation(), result.operands); 1164 } 1165 1166 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 1167 assert(index < 2 && "invalid region index"); 1168 return getOperands(); 1169 } 1170 1171 void RegionIfOp::getSuccessorRegions( 1172 Optional<unsigned> index, ArrayRef<Attribute> operands, 1173 SmallVectorImpl<RegionSuccessor> ®ions) { 1174 // We always branch to the join region. 1175 if (index.hasValue()) { 1176 if (index.getValue() < 2) 1177 regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); 1178 else 1179 regions.push_back(RegionSuccessor(getResults())); 1180 return; 1181 } 1182 1183 // The then and else regions are the entry regions of this op. 1184 regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs())); 1185 regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs())); 1186 } 1187 1188 //===----------------------------------------------------------------------===// 1189 // SingleNoTerminatorCustomAsmOp 1190 //===----------------------------------------------------------------------===// 1191 1192 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser, 1193 OperationState &state) { 1194 Region *body = state.addRegion(); 1195 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) 1196 return failure(); 1197 return success(); 1198 } 1199 1200 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) { 1201 printer.printRegion( 1202 op.getRegion(), /*printEntryBlockArgs=*/false, 1203 // This op has a single block without terminators. But explicitly mark 1204 // as not printing block terminators for testing. 1205 /*printBlockTerminators=*/false); 1206 } 1207 1208 #include "TestOpEnums.cpp.inc" 1209 #include "TestOpInterfaces.cpp.inc" 1210 #include "TestOpStructs.cpp.inc" 1211 #include "TestTypeInterfaces.cpp.inc" 1212 1213 #define GET_OP_CLASSES 1214 #include "TestOps.cpp.inc" 1215