//===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "TestDialect.h"
#include "TestAttributes.h"
#include "TestInterfaces.h"
#include "TestTypes.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/StringSwitch.h"

// Include this before the using namespace lines below to
// test that we don't have namespace dependencies.
#include "TestOpsDialect.cpp.inc"

using namespace mlir;
using namespace test;

/// Makes the test dialect available through the given registry.
void test::registerTestDialect(DialectRegistry &registry) {
  registry.insert<TestDialect>();
}

//===----------------------------------------------------------------------===//
// TestDialect Interfaces
//===----------------------------------------------------------------------===//

namespace {

/// Testing the correctness of some traits.
static_assert(
    llvm::is_detected<OpTrait::has_implicit_terminator_t,
                      SingleBlockImplicitTerminatorOp>::value,
    "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
static_assert(OpTrait::hasSingleBlockImplicitTerminator<
                  SingleBlockImplicitTerminatorOp>::value,
              "hasSingleBlockImplicitTerminator does not match "
              "SingleBlockImplicitTerminatorOp");

// Test support for interacting with the AsmPrinter.
struct TestOpAsmInterface : public OpAsmDialectInterface {
  using OpAsmDialectInterface::OpAsmDialectInterface;

  /// Assigns a fixed alias to a handful of magic `alias_test:*` string
  /// attributes; every other attribute gets no alias.
  AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
    StringAttr strAttr = attr.dyn_cast<StringAttr>();
    if (!strAttr)
      return AliasResult::NoAlias;

    // Check the contents of the string attribute to see what the test alias
    // should be named.
    Optional<StringRef> aliasName =
        StringSwitch<Optional<StringRef>>(strAttr.getValue())
            .Case("alias_test:dot_in_name", StringRef("test.alias"))
            .Case("alias_test:trailing_digit", StringRef("test_alias0"))
            .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
            .Case("alias_test:sanitize_conflict_a",
                  StringRef("test_alias_conflict0"))
            .Case("alias_test:sanitize_conflict_b",
                  StringRef("test_alias_conflict0_"))
            .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
            .Default(llvm::None);
    if (!aliasName)
      return AliasResult::NoAlias;

    os << *aliasName;
    return AliasResult::FinalAlias;
  }

  /// Type aliases: a non-empty tuple whose elements are all SimpleAType
  /// becomes `test_tuple`; an unsigned 8-bit TestIntegerType becomes
  /// `test_ui8`. Anything else gets no alias.
  AliasResult getAlias(Type type, raw_ostream &os) const final {
    if (auto tupleType = type.dyn_cast<TupleType>()) {
      if (tupleType.size() > 0 &&
          llvm::all_of(tupleType.getTypes(), [](Type elemType) {
            return elemType.isa<SimpleAType>();
          })) {
        os << "test_tuple";
        return AliasResult::FinalAlias;
      }
    }
    if (auto intType = type.dyn_cast<TestIntegerType>()) {
      if (intType.getSignedness() ==
              TestIntegerType::SignednessSemantics::Unsigned &&
          intType.getWidth() == 8) {
        os << "test_ui8";
        return AliasResult::FinalAlias;
      }
    }
    return AliasResult::NoAlias;
  }

  /// Suggests the SSA name "result" for the result of AsmDialectInterfaceOp.
  void getAsmResultNames(Operation *op,
                         OpAsmSetValueNameFn setNameFn) const final {
    if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
      setNameFn(asmOp, "result");
  }
};

struct TestDialectFoldInterface : public DialectFoldInterface {
  using DialectFoldInterface::DialectFoldInterface;

  /// Registered hook to check if the given region, which is attached to an
  /// operation that is *not* isolated from above, should be used when
  /// materializing constants.
  bool shouldMaterializeInto(Region *region) const final {
    // If this is a one region operation, then insert into it.
    return isa<OneRegionOp>(region->getParentOp());
  }
};

/// This class defines the interface for handling inlining with test dialect
/// operations.
struct TestInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  bool isLegalToInline(Operation *call, Operation *callable,
                       bool wouldBeCloned) const final {
    // Don't allow inlining calls that are marked `noinline`.
    return !call->hasAttr("noinline");
  }
  bool isLegalToInline(Region *, Region *, bool,
                       BlockAndValueMapping &) const final {
    // Inlining into test dialect regions is legal.
    return true;
  }
  // Any individual operation may be inlined.
  bool isLegalToInline(Operation *, Region *, bool,
                       BlockAndValueMapping &) const final {
    return true;
  }

  bool shouldAnalyzeRecursively(Operation *op) const final {
    // Analyze recursively unless this is a functional region operation,
    // which forms a separate functional scope.
    return !isa<FunctionalRegionOp>(op);
  }

  //===--------------------------------------------------------------------===//
  // Transformation Hooks
  //===--------------------------------------------------------------------===//

  /// Handle the given inlined terminator by replacing it with a new operation
  /// as necessary.
  void handleTerminator(Operation *op,
                        ArrayRef<Value> valuesToRepl) const final {
    // Only handle "test.return" here.
    auto returnOp = dyn_cast<TestReturnOp>(op);
    if (!returnOp)
      return;

    // Replace the values directly with the return operands.
    assert(returnOp.getNumOperands() == valuesToRepl.size());
    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
      valuesToRepl[it.index()].replaceAllUsesWith(it.value());
  }

  /// Attempt to materialize a conversion for a type mismatch between a call
  /// from this dialect, and a callable region. This method should generate an
  /// operation that takes 'input' as the only operand, and produces a single
  /// result of 'resultType'. If a conversion can not be generated, nullptr
  /// should be returned.
  Operation *materializeCallConversion(OpBuilder &builder, Value input,
                                       Type resultType,
                                       Location conversionLoc) const final {
    // Only allow conversion for i16/i32 types.
    if (!(resultType.isSignlessInteger(16) ||
          resultType.isSignlessInteger(32)) ||
        !(input.getType().isSignlessInteger(16) ||
          input.getType().isSignlessInteger(32)))
      return nullptr;
    return builder.create<TestCastOp>(conversionLoc, resultType, input);
  }

  /// After a ConversionCallOp has been inlined, tags every operation in the
  /// inlined blocks with a unit `inlined_conversion` attribute. Other calls
  /// are left untouched.
  void processInlinedCallBlocks(
      Operation *call,
      iterator_range<Region::iterator> inlinedBlocks) const final {
    if (!isa<ConversionCallOp>(call))
      return;

    // Set the attribute on all ops in the inlined blocks.
    for (Block &block : inlinedBlocks) {
      block.walk([&](Operation *op) {
        op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
      });
    }
  }
};

/// Exposes the test dialect's reduction patterns to the reducer by forwarding
/// to populateTestReductionPatterns (declared elsewhere).
struct TestReductionPatternInterface : public DialectReductionPatternInterface {
public:
  TestReductionPatternInterface(Dialect *dialect)
      : DialectReductionPatternInterface(dialect) {}

  void populateReductionPatterns(RewritePatternSet &patterns) const final {
    populateTestReductionPatterns(patterns);
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//

// Forward declaration: this has internal linkage, so its definition appears
// later in this translation unit.
static void testSideEffectOpGetEffect(
    Operation *op,
    SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);

// This is the implementation of a dialect fallback for `TestEffectOpInterface`.
struct TestOpEffectInterfaceFallback
    : public TestEffectOpInterface::FallbackModel<
          TestOpEffectInterfaceFallback> {
  // The fallback is only handed out for "test.unregistered_side_effect_op"
  // (see TestDialect::getRegisteredInterfaceForOp), so any other op reaching
  // here indicates a dispatch bug and trips the assert in debug builds.
  static bool classof(Operation *op) {
    bool isSupportedOp =
        op->getName().getStringRef() == "test.unregistered_side_effect_op";
    assert(isSupportedOp && "Unexpected dispatch");
    return isSupportedOp;
  }

  // Delegates effect collection to the shared side-effect helper.
  void
  getEffects(Operation *op,
             SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
                 &effects) const {
    testSideEffectOpGetEffect(op, effects);
  }
};

/// Registers the dialect's attributes, types, ODS-generated operations and
/// dialect interfaces, allows unknown operations, and allocates the
/// effect-interface fallback model.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface, TestReductionPatternInterface>();
  allowUnknownOperations();

  // Instantiate our fallback op interface that we'll use on specific
  // unregistered op.
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
// Releases the fallback model allocated in initialize(); it is cast back to
// the concrete type that was allocated before being deleted.
TestDialect::~TestDialect() {
  delete static_cast<TestOpEffectInterfaceFallback *>(
      fallbackEffectOpInterfaces);
}

/// Materializes folded constant values as TestOpConstant operations.
Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  return builder.create<TestOpConstant>(loc, type, value);
}

/// FormatInferType2Op always produces a single i16 result.
::mlir::LogicalResult FormatInferType2Op::inferReturnTypes(
    ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location,
    ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
    ::mlir::RegionRange regions,
    ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
  inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
  return ::mlir::success();
}

/// Returns the fallback TestEffectOpInterface model for the unregistered
/// "test.unregistered_side_effect_op"; returns nullptr for every other
/// (interface, op) pair.
void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
                                               OperationName opName) {
  if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
      typeID == TypeID::get<TestEffectOpInterface>())
    return fallbackEffectOpInterfaces;
  return nullptr;
}

/// Rejects any use of the `test.invalid_attr` attribute on an operation.
LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute namedAttr) {
  if (namedAttr.getName() == "test.invalid_attr")
    return op->emitError() << "invalid to use 'test.invalid_attr'";
  return success();
}

/// Rejects any use of the `test.invalid_attr` attribute on a region argument.
LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute namedAttr) {
  if (namedAttr.getName() == "test.invalid_attr")
    return op->emitError() << "invalid to use 'test.invalid_attr'";
  return success();
}

/// Rejects any use of the `test.invalid_attr` attribute on a region result.
LogicalResult
TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
                                         unsigned resultIndex,
                                         NamedAttribute namedAttr) {
  if (namedAttr.getName() == "test.invalid_attr")
    return op->emitError() << "invalid to use 'test.invalid_attr'";
  return success();
}

Optional<Dialect::ParseOpHook> 310 TestDialect::getParseOperationHook(StringRef opName) const { 311 if (opName == "test.dialect_custom_printer") { 312 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 313 return parser.parseKeyword("custom_format"); 314 }}; 315 } 316 if (opName == "test.dialect_custom_format_fallback") { 317 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 318 return parser.parseKeyword("custom_format_fallback"); 319 }}; 320 } 321 return None; 322 } 323 324 llvm::unique_function<void(Operation *, OpAsmPrinter &)> 325 TestDialect::getOperationPrinter(Operation *op) const { 326 StringRef opName = op->getName().getStringRef(); 327 if (opName == "test.dialect_custom_printer") { 328 return [](Operation *op, OpAsmPrinter &printer) { 329 printer.getStream() << " custom_format"; 330 }; 331 } 332 if (opName == "test.dialect_custom_format_fallback") { 333 return [](Operation *op, OpAsmPrinter &printer) { 334 printer.getStream() << " custom_format_fallback"; 335 }; 336 } 337 return {}; 338 } 339 340 //===----------------------------------------------------------------------===// 341 // TestBranchOp 342 //===----------------------------------------------------------------------===// 343 344 Optional<MutableOperandRange> 345 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 346 assert(index == 0 && "invalid successor index"); 347 return getTargetOperandsMutable(); 348 } 349 350 //===----------------------------------------------------------------------===// 351 // TestDialectCanonicalizerOp 352 //===----------------------------------------------------------------------===// 353 354 static LogicalResult 355 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, 356 PatternRewriter &rewriter) { 357 rewriter.replaceOpWithNewOp<arith::ConstantOp>( 358 op, rewriter.getI32IntegerAttr(42)); 359 return success(); 360 } 361 362 void TestDialect::getCanonicalizationPatterns( 363 RewritePatternSet &results) const { 364 
results.add(&dialectCanonicalizationPattern); 365 } 366 367 //===----------------------------------------------------------------------===// 368 // TestFoldToCallOp 369 //===----------------------------------------------------------------------===// 370 371 namespace { 372 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 373 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 374 375 LogicalResult matchAndRewrite(FoldToCallOp op, 376 PatternRewriter &rewriter) const override { 377 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.getCalleeAttr(), 378 ValueRange()); 379 return success(); 380 } 381 }; 382 } // namespace 383 384 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 385 MLIRContext *context) { 386 results.add<FoldToCallOpPattern>(context); 387 } 388 389 //===----------------------------------------------------------------------===// 390 // Test Format* operations 391 //===----------------------------------------------------------------------===// 392 393 //===----------------------------------------------------------------------===// 394 // Parsing 395 396 static ParseResult parseCustomDirectiveOperands( 397 OpAsmParser &parser, OpAsmParser::OperandType &operand, 398 Optional<OpAsmParser::OperandType> &optOperand, 399 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 400 if (parser.parseOperand(operand)) 401 return failure(); 402 if (succeeded(parser.parseOptionalComma())) { 403 optOperand.emplace(); 404 if (parser.parseOperand(*optOperand)) 405 return failure(); 406 } 407 if (parser.parseArrow() || parser.parseLParen() || 408 parser.parseOperandList(varOperands) || parser.parseRParen()) 409 return failure(); 410 return success(); 411 } 412 static ParseResult 413 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 414 Type &optOperandType, 415 SmallVectorImpl<Type> &varOperandTypes) { 416 if (parser.parseColon()) 417 return failure(); 418 419 if (parser.parseType(operandType)) 
420 return failure(); 421 if (succeeded(parser.parseOptionalComma())) { 422 if (parser.parseType(optOperandType)) 423 return failure(); 424 } 425 if (parser.parseArrow() || parser.parseLParen() || 426 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 427 return failure(); 428 return success(); 429 } 430 static ParseResult 431 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 432 Type optOperandType, 433 const SmallVectorImpl<Type> &varOperandTypes) { 434 if (parser.parseKeyword("type_refs_capture")) 435 return failure(); 436 437 Type operandType2, optOperandType2; 438 SmallVector<Type, 1> varOperandTypes2; 439 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 440 varOperandTypes2)) 441 return failure(); 442 443 if (operandType != operandType2 || optOperandType != optOperandType2 || 444 varOperandTypes != varOperandTypes2) 445 return failure(); 446 447 return success(); 448 } 449 static ParseResult parseCustomDirectiveOperandsAndTypes( 450 OpAsmParser &parser, OpAsmParser::OperandType &operand, 451 Optional<OpAsmParser::OperandType> &optOperand, 452 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 453 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 454 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 455 parseCustomDirectiveResults(parser, operandType, optOperandType, 456 varOperandTypes)) 457 return failure(); 458 return success(); 459 } 460 static ParseResult parseCustomDirectiveRegions( 461 OpAsmParser &parser, Region ®ion, 462 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 463 if (parser.parseRegion(region)) 464 return failure(); 465 if (failed(parser.parseOptionalComma())) 466 return success(); 467 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 468 if (parser.parseRegion(*varRegion)) 469 return failure(); 470 varRegions.emplace_back(std::move(varRegion)); 471 return success(); 472 } 473 static ParseResult 474 
parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 475 SmallVectorImpl<Block *> &varSuccessors) { 476 if (parser.parseSuccessor(successor)) 477 return failure(); 478 if (failed(parser.parseOptionalComma())) 479 return success(); 480 Block *varSuccessor; 481 if (parser.parseSuccessor(varSuccessor)) 482 return failure(); 483 varSuccessors.append(2, varSuccessor); 484 return success(); 485 } 486 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 487 IntegerAttr &attr, 488 IntegerAttr &optAttr) { 489 if (parser.parseAttribute(attr)) 490 return failure(); 491 if (succeeded(parser.parseOptionalComma())) { 492 if (parser.parseAttribute(optAttr)) 493 return failure(); 494 } 495 return success(); 496 } 497 498 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 499 NamedAttrList &attrs) { 500 return parser.parseOptionalAttrDict(attrs); 501 } 502 static ParseResult parseCustomDirectiveOptionalOperandRef( 503 OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) { 504 int64_t operandCount = 0; 505 if (parser.parseInteger(operandCount)) 506 return failure(); 507 bool expectedOptionalOperand = operandCount == 0; 508 return success(expectedOptionalOperand != optOperand.hasValue()); 509 } 510 511 //===----------------------------------------------------------------------===// 512 // Printing 513 514 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 515 Value operand, Value optOperand, 516 OperandRange varOperands) { 517 printer << operand; 518 if (optOperand) 519 printer << ", " << optOperand; 520 printer << " -> (" << varOperands << ")"; 521 } 522 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 523 Type operandType, Type optOperandType, 524 TypeRange varOperandTypes) { 525 printer << " : " << operandType; 526 if (optOperandType) 527 printer << ", " << optOperandType; 528 printer << " -> (" << varOperandTypes << ")"; 529 } 530 static void 
printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 531 Operation *op, Type operandType, 532 Type optOperandType, 533 TypeRange varOperandTypes) { 534 printer << " type_refs_capture "; 535 printCustomDirectiveResults(printer, op, operandType, optOperandType, 536 varOperandTypes); 537 } 538 static void printCustomDirectiveOperandsAndTypes( 539 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 540 OperandRange varOperands, Type operandType, Type optOperandType, 541 TypeRange varOperandTypes) { 542 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 543 printCustomDirectiveResults(printer, op, operandType, optOperandType, 544 varOperandTypes); 545 } 546 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 547 Region ®ion, 548 MutableArrayRef<Region> varRegions) { 549 printer.printRegion(region); 550 if (!varRegions.empty()) { 551 printer << ", "; 552 for (Region ®ion : varRegions) 553 printer.printRegion(region); 554 } 555 } 556 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 557 Block *successor, 558 SuccessorRange varSuccessors) { 559 printer << successor; 560 if (!varSuccessors.empty()) 561 printer << ", " << varSuccessors.front(); 562 } 563 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 564 Attribute attribute, 565 Attribute optAttribute) { 566 printer << attribute; 567 if (optAttribute) 568 printer << ", " << optAttribute; 569 } 570 571 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 572 DictionaryAttr attrs) { 573 printer.printOptionalAttrDict(attrs.getValue()); 574 } 575 576 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 577 Operation *op, 578 Value optOperand) { 579 printer << (optOperand ? "1" : "0"); 580 } 581 582 //===----------------------------------------------------------------------===// 583 // Test IsolatedRegionOp - parse passthrough region arguments. 
584 //===----------------------------------------------------------------------===// 585 586 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 587 OperationState &result) { 588 OpAsmParser::OperandType argInfo; 589 Type argType = parser.getBuilder().getIndexType(); 590 591 // Parse the input operand. 592 if (parser.parseOperand(argInfo) || 593 parser.resolveOperand(argInfo, argType, result.operands)) 594 return failure(); 595 596 // Parse the body region, and reuse the operand info as the argument info. 597 Region *body = result.addRegion(); 598 return parser.parseRegion(*body, argInfo, argType, /*argLocations=*/{}, 599 /*enableNameShadowing=*/true); 600 } 601 602 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 603 p << "test.isolated_region "; 604 p.printOperand(op.getOperand()); 605 p.shadowRegionArgs(op.getRegion(), op.getOperand()); 606 p << ' '; 607 p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false); 608 } 609 610 //===----------------------------------------------------------------------===// 611 // Test SSACFGRegionOp 612 //===----------------------------------------------------------------------===// 613 614 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 615 return RegionKind::SSACFG; 616 } 617 618 //===----------------------------------------------------------------------===// 619 // Test GraphRegionOp 620 //===----------------------------------------------------------------------===// 621 622 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 623 OperationState &result) { 624 // Parse the body region, and reuse the operand info as the argument info. 
625 Region *body = result.addRegion(); 626 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 627 } 628 629 static void print(OpAsmPrinter &p, GraphRegionOp op) { 630 p << "test.graph_region "; 631 p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false); 632 } 633 634 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 635 return RegionKind::Graph; 636 } 637 638 //===----------------------------------------------------------------------===// 639 // Test AffineScopeOp 640 //===----------------------------------------------------------------------===// 641 642 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 643 OperationState &result) { 644 // Parse the body region, and reuse the operand info as the argument info. 645 Region *body = result.addRegion(); 646 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 647 } 648 649 static void print(OpAsmPrinter &p, AffineScopeOp op) { 650 p << "test.affine_scope "; 651 p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false); 652 } 653 654 //===----------------------------------------------------------------------===// 655 // Test parser. 
656 //===----------------------------------------------------------------------===// 657 658 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, 659 OperationState &result) { 660 if (parser.parseOptionalColon()) 661 return success(); 662 uint64_t numResults; 663 if (parser.parseInteger(numResults)) 664 return failure(); 665 666 IndexType type = parser.getBuilder().getIndexType(); 667 for (unsigned i = 0; i < numResults; ++i) 668 result.addTypes(type); 669 return success(); 670 } 671 672 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { 673 if (unsigned numResults = op->getNumResults()) 674 p << " : " << numResults; 675 } 676 677 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, 678 OperationState &result) { 679 StringRef keyword; 680 if (parser.parseKeyword(&keyword)) 681 return failure(); 682 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 683 return success(); 684 } 685 686 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { 687 p << " " << op.getKeyword(); 688 } 689 690 //===----------------------------------------------------------------------===// 691 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 692 693 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 694 OperationState &result) { 695 if (parser.parseKeyword("wraps")) 696 return failure(); 697 698 // Parse the wrapped op in a region 699 Region &body = *result.addRegion(); 700 body.push_back(new Block); 701 Block &block = body.back(); 702 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin()); 703 if (!wrappedOp) 704 return failure(); 705 706 // Create a return terminator in the inner region, pass as operand to the 707 // terminator the returned values from the wrapped operation. 
708 SmallVector<Value, 8> returnOperands(wrappedOp->getResults()); 709 OpBuilder builder(parser.getContext()); 710 builder.setInsertionPointToEnd(&block); 711 builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands); 712 713 // Get the results type for the wrapping op from the terminator operands. 714 Operation &returnOp = body.back().back(); 715 result.types.append(returnOp.operand_type_begin(), 716 returnOp.operand_type_end()); 717 718 // Use the location of the wrapped op for the "test.wrapping_region" op. 719 result.location = wrappedOp->getLoc(); 720 721 return success(); 722 } 723 724 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 725 p << " wraps "; 726 p.printGenericOp(&op.getRegion().front().front()); 727 } 728 729 //===----------------------------------------------------------------------===// 730 // Test PrettyPrintedRegionOp - exercising the following parser APIs 731 // parseGenericOperationAfterOpName 732 // parseCustomOperationName 733 //===----------------------------------------------------------------------===// 734 735 static ParseResult parsePrettyPrintedRegionOp(OpAsmParser &parser, 736 OperationState &result) { 737 738 llvm::SMLoc loc = parser.getCurrentLocation(); 739 Location currLocation = parser.getEncodedSourceLoc(loc); 740 741 // Parse the operands. 742 SmallVector<OpAsmParser::OperandType, 2> operands; 743 if (parser.parseOperandList(operands)) 744 return failure(); 745 746 // Check if we are parsing the pretty-printed version 747 // test.pretty_printed_region start <inner-op> end : <functional-type> 748 // Else fallback to parsing the "non pretty-printed" version. 
749 if (!succeeded(parser.parseOptionalKeyword("start"))) 750 return parser.parseGenericOperationAfterOpName( 751 result, llvm::makeArrayRef(operands)); 752 753 FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName(); 754 if (failed(parseOpNameInfo)) 755 return failure(); 756 757 StringRef innerOpName = parseOpNameInfo->getStringRef(); 758 759 FunctionType opFntype; 760 Optional<Location> explicitLoc; 761 if (parser.parseKeyword("end") || parser.parseColon() || 762 parser.parseType(opFntype) || 763 parser.parseOptionalLocationSpecifier(explicitLoc)) 764 return failure(); 765 766 // If location of the op is explicitly provided, then use it; Else use 767 // the parser's current location. 768 Location opLoc = explicitLoc.getValueOr(currLocation); 769 770 // Derive the SSA-values for op's operands. 771 if (parser.resolveOperands(operands, opFntype.getInputs(), loc, 772 result.operands)) 773 return failure(); 774 775 // Add a region for op. 776 Region ®ion = *result.addRegion(); 777 778 // Create a basic-block inside op's region. 779 Block &block = region.emplaceBlock(); 780 781 // Create and insert an "inner-op" operation in the block. 782 // Just for testing purposes, we can assume that inner op is a binary op with 783 // result and operand types all same as the test-op's first operand. 784 Type innerOpType = opFntype.getInput(0); 785 Value lhs = block.addArgument(innerOpType, opLoc); 786 Value rhs = block.addArgument(innerOpType, opLoc); 787 788 OpBuilder builder(parser.getBuilder().getContext()); 789 builder.setInsertionPointToStart(&block); 790 791 OperationState innerOpState(opLoc, innerOpName); 792 innerOpState.operands.push_back(lhs); 793 innerOpState.operands.push_back(rhs); 794 innerOpState.addTypes(innerOpType); 795 796 Operation *innerOp = builder.createOperation(innerOpState); 797 798 // Insert a return statement in the block returning the inner-op's result. 
799 builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults()); 800 801 // Populate the op operation-state with result-type and location. 802 result.addTypes(opFntype.getResults()); 803 result.location = innerOp->getLoc(); 804 805 return success(); 806 } 807 808 static void print(OpAsmPrinter &p, PrettyPrintedRegionOp op) { 809 p << ' '; 810 p.printOperands(op.getOperands()); 811 812 Operation &innerOp = op.getRegion().front().front(); 813 // Assuming that region has a single non-terminator inner-op, if the inner-op 814 // meets some criteria (which in this case is a simple one based on the name 815 // of inner-op), then we can print the entire region in a succinct way. 816 // Here we assume that the prototype of "special.op" can be trivially derived 817 // while parsing it back. 818 if (innerOp.getName().getStringRef().equals("special.op")) { 819 p << " start special.op end"; 820 } else { 821 p << " ("; 822 p.printRegion(op.getRegion()); 823 p << ")"; 824 } 825 826 p << " : "; 827 p.printFunctionalType(op); 828 } 829 830 //===----------------------------------------------------------------------===// 831 // Test PolyForOp - parse list of region arguments. 832 //===----------------------------------------------------------------------===// 833 834 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 835 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 836 // Parse list of region arguments without a delimiter. 837 if (parser.parseRegionArgumentList(ivsInfo)) 838 return failure(); 839 840 // Parse the body region. 
841 Region *body = result.addRegion(); 842 auto &builder = parser.getBuilder(); 843 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 844 return parser.parseRegion(*body, ivsInfo, argTypes); 845 } 846 847 void PolyForOp::getAsmBlockArgumentNames(Region ®ion, 848 OpAsmSetValueNameFn setNameFn) { 849 auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names"); 850 if (!arrayAttr) 851 return; 852 auto args = getRegion().front().getArguments(); 853 auto e = std::min(arrayAttr.size(), args.size()); 854 for (unsigned i = 0; i < e; ++i) { 855 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 856 setNameFn(args[i], strAttr.getValue()); 857 } 858 } 859 860 //===----------------------------------------------------------------------===// 861 // Test removing op with inner ops. 862 //===----------------------------------------------------------------------===// 863 864 namespace { 865 struct TestRemoveOpWithInnerOps 866 : public OpRewritePattern<TestOpWithRegionPattern> { 867 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 868 869 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } 870 871 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 872 PatternRewriter &rewriter) const override { 873 rewriter.eraseOp(op); 874 return success(); 875 } 876 }; 877 } // namespace 878 879 void TestOpWithRegionPattern::getCanonicalizationPatterns( 880 RewritePatternSet &results, MLIRContext *context) { 881 results.add<TestRemoveOpWithInnerOps>(context); 882 } 883 884 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 885 return getOperand(); 886 } 887 888 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 889 return getValue(); 890 } 891 892 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 893 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 894 for (Value input : this->getOperands()) { 895 results.push_back(input); 896 } 897 return success(); 898 } 899 
900 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 901 assert(operands.size() == 1); 902 if (operands.front()) { 903 (*this)->setAttr("attr", operands.front()); 904 return getResult(); 905 } 906 return {}; 907 } 908 909 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 910 return getOperand(); 911 } 912 913 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 914 MLIRContext *, Optional<Location> location, ValueRange operands, 915 DictionaryAttr attributes, RegionRange regions, 916 SmallVectorImpl<Type> &inferredReturnTypes) { 917 if (operands[0].getType() != operands[1].getType()) { 918 return emitOptionalError(location, "operand type mismatch ", 919 operands[0].getType(), " vs ", 920 operands[1].getType()); 921 } 922 inferredReturnTypes.assign({operands[0].getType()}); 923 return success(); 924 } 925 926 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 927 MLIRContext *context, Optional<Location> location, ValueShapeRange operands, 928 DictionaryAttr attributes, RegionRange regions, 929 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 930 // Create return type consisting of the last element of the first operand. 931 auto operandType = operands.front().getType(); 932 auto sval = operandType.dyn_cast<ShapedType>(); 933 if (!sval) { 934 return emitOptionalError(location, "only shaped type operands allowed"); 935 } 936 int64_t dim = 937 sval.hasRank() ? 
sval.getShape().front() : ShapedType::kDynamicSize; 938 auto type = IntegerType::get(context, 17); 939 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 940 return success(); 941 } 942 943 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 944 OpBuilder &builder, ValueRange operands, 945 llvm::SmallVectorImpl<Value> &shapes) { 946 shapes = SmallVector<Value, 1>{ 947 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; 948 return success(); 949 } 950 951 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( 952 OpBuilder &builder, ValueRange operands, 953 llvm::SmallVectorImpl<Value> &shapes) { 954 Location loc = getLoc(); 955 shapes.reserve(operands.size()); 956 for (Value operand : llvm::reverse(operands)) { 957 auto rank = operand.getType().cast<RankedTensorType>().getRank(); 958 auto currShape = llvm::to_vector<4>( 959 llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value { 960 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 961 })); 962 shapes.push_back(builder.create<tensor::FromElementsOp>( 963 getLoc(), RankedTensorType::get({rank}, builder.getIndexType()), 964 currShape)); 965 } 966 return success(); 967 } 968 969 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( 970 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { 971 Location loc = getLoc(); 972 shapes.reserve(getNumOperands()); 973 for (Value operand : llvm::reverse(getOperands())) { 974 auto currShape = llvm::to_vector<4>(llvm::map_range( 975 llvm::seq<int64_t>( 976 0, operand.getType().cast<RankedTensorType>().getRank()), 977 [&](int64_t dim) -> Value { 978 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 979 })); 980 shapes.emplace_back(std::move(currShape)); 981 } 982 return success(); 983 } 984 985 //===----------------------------------------------------------------------===// 986 // Test SideEffect interfaces 987 
//===----------------------------------------------------------------------===// 988 989 namespace { 990 /// A test resource for side effects. 991 struct TestResource : public SideEffects::Resource::Base<TestResource> { 992 StringRef getName() final { return "<Test>"; } 993 }; 994 } // namespace 995 996 static void testSideEffectOpGetEffect( 997 Operation *op, 998 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 999 &effects) { 1000 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); 1001 if (!effectsAttr) 1002 return; 1003 1004 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 1005 } 1006 1007 void SideEffectOp::getEffects( 1008 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1009 // Check for an effects attribute on the op instance. 1010 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 1011 if (!effectsAttr) 1012 return; 1013 1014 // If there is one, it is an array of dictionary attributes that hold 1015 // information on the effects of this operation. 1016 for (Attribute element : effectsAttr) { 1017 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 1018 1019 // Get the specific memory effect. 1020 MemoryEffects::Effect *effect = 1021 StringSwitch<MemoryEffects::Effect *>( 1022 effectElement.get("effect").cast<StringAttr>().getValue()) 1023 .Case("allocate", MemoryEffects::Allocate::get()) 1024 .Case("free", MemoryEffects::Free::get()) 1025 .Case("read", MemoryEffects::Read::get()) 1026 .Case("write", MemoryEffects::Write::get()); 1027 1028 // Check for a non-default resource to use. 1029 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 1030 if (effectElement.get("test_resource")) 1031 resource = TestResource::get(); 1032 1033 // Check for a result to affect. 
1034 if (effectElement.get("on_result")) 1035 effects.emplace_back(effect, getResult(), resource); 1036 else if (Attribute ref = effectElement.get("on_reference")) 1037 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 1038 else 1039 effects.emplace_back(effect, resource); 1040 } 1041 } 1042 1043 void SideEffectOp::getEffects( 1044 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 1045 testSideEffectOpGetEffect(getOperation(), effects); 1046 } 1047 1048 //===----------------------------------------------------------------------===// 1049 // StringAttrPrettyNameOp 1050 //===----------------------------------------------------------------------===// 1051 1052 // This op has fancy handling of its SSA result name. 1053 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 1054 OperationState &result) { 1055 // Add the result types. 1056 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 1057 result.addTypes(parser.getBuilder().getIntegerType(32)); 1058 1059 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1060 return failure(); 1061 1062 // If the attribute dictionary contains no 'names' attribute, infer it from 1063 // the SSA name (if specified). 1064 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 1065 return attr.getName() == "names"; 1066 }); 1067 1068 // If there was no name specified, check to see if there was a useful name 1069 // specified in the asm file. 
1070 if (hadNames || parser.getNumResults() == 0) 1071 return success(); 1072 1073 SmallVector<StringRef, 4> names; 1074 auto *context = result.getContext(); 1075 1076 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 1077 auto resultName = parser.getResultName(i); 1078 StringRef nameStr; 1079 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 1080 nameStr = resultName.first; 1081 1082 names.push_back(nameStr); 1083 } 1084 1085 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 1086 result.attributes.push_back({StringAttr::get(context, "names"), namesAttr}); 1087 return success(); 1088 } 1089 1090 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 1091 // Note that we only need to print the "name" attribute if the asmprinter 1092 // result name disagrees with it. This can happen in strange cases, e.g. 1093 // when there are conflicts. 1094 bool namesDisagree = op.getNames().size() != op.getNumResults(); 1095 1096 SmallString<32> resultNameStr; 1097 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 1098 resultNameStr.clear(); 1099 llvm::raw_svector_ostream tmpStream(resultNameStr); 1100 p.printOperand(op.getResult(i), tmpStream); 1101 1102 auto expectedName = op.getNames()[i].dyn_cast<StringAttr>(); 1103 if (!expectedName || 1104 tmpStream.str().drop_front() != expectedName.getValue()) { 1105 namesDisagree = true; 1106 } 1107 } 1108 1109 if (namesDisagree) 1110 p.printOptionalAttrDictWithKeyword(op->getAttrs()); 1111 else 1112 p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"}); 1113 } 1114 1115 // We set the SSA name in the asm syntax to the contents of the name 1116 // attribute. 
1117 void StringAttrPrettyNameOp::getAsmResultNames( 1118 function_ref<void(Value, StringRef)> setNameFn) { 1119 1120 auto value = getNames(); 1121 for (size_t i = 0, e = value.size(); i != e; ++i) 1122 if (auto str = value[i].dyn_cast<StringAttr>()) 1123 if (!str.getValue().empty()) 1124 setNameFn(getResult(i), str.getValue()); 1125 } 1126 1127 //===----------------------------------------------------------------------===// 1128 // RegionIfOp 1129 //===----------------------------------------------------------------------===// 1130 1131 static void print(OpAsmPrinter &p, RegionIfOp op) { 1132 p << " "; 1133 p.printOperands(op.getOperands()); 1134 p << ": " << op.getOperandTypes(); 1135 p.printArrowTypeList(op.getResultTypes()); 1136 p << " then"; 1137 p.printRegion(op.getThenRegion(), 1138 /*printEntryBlockArgs=*/true, 1139 /*printBlockTerminators=*/true); 1140 p << " else"; 1141 p.printRegion(op.getElseRegion(), 1142 /*printEntryBlockArgs=*/true, 1143 /*printBlockTerminators=*/true); 1144 p << " join"; 1145 p.printRegion(op.getJoinRegion(), 1146 /*printEntryBlockArgs=*/true, 1147 /*printBlockTerminators=*/true); 1148 } 1149 1150 static ParseResult parseRegionIfOp(OpAsmParser &parser, 1151 OperationState &result) { 1152 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 1153 SmallVector<Type, 2> operandTypes; 1154 1155 result.regions.reserve(3); 1156 Region *thenRegion = result.addRegion(); 1157 Region *elseRegion = result.addRegion(); 1158 Region *joinRegion = result.addRegion(); 1159 1160 // Parse operand, type and arrow type lists. 1161 if (parser.parseOperandList(operandInfos) || 1162 parser.parseColonTypeList(operandTypes) || 1163 parser.parseArrowTypeList(result.types)) 1164 return failure(); 1165 1166 // Parse all attached regions. 
1167 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 1168 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 1169 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 1170 return failure(); 1171 1172 return parser.resolveOperands(operandInfos, operandTypes, 1173 parser.getCurrentLocation(), result.operands); 1174 } 1175 1176 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 1177 assert(index < 2 && "invalid region index"); 1178 return getOperands(); 1179 } 1180 1181 void RegionIfOp::getSuccessorRegions( 1182 Optional<unsigned> index, ArrayRef<Attribute> operands, 1183 SmallVectorImpl<RegionSuccessor> ®ions) { 1184 // We always branch to the join region. 1185 if (index.hasValue()) { 1186 if (index.getValue() < 2) 1187 regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); 1188 else 1189 regions.push_back(RegionSuccessor(getResults())); 1190 return; 1191 } 1192 1193 // The then and else regions are the entry regions of this op. 1194 regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs())); 1195 regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs())); 1196 } 1197 1198 //===----------------------------------------------------------------------===// 1199 // SingleNoTerminatorCustomAsmOp 1200 //===----------------------------------------------------------------------===// 1201 1202 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser, 1203 OperationState &state) { 1204 Region *body = state.addRegion(); 1205 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) 1206 return failure(); 1207 return success(); 1208 } 1209 1210 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) { 1211 printer.printRegion( 1212 op.getRegion(), /*printEntryBlockArgs=*/false, 1213 // This op has a single block without terminators. But explicitly mark 1214 // as not printing block terminators for testing. 
1215 /*printBlockTerminators=*/false); 1216 } 1217 1218 #include "TestOpEnums.cpp.inc" 1219 #include "TestOpInterfaces.cpp.inc" 1220 #include "TestOpStructs.cpp.inc" 1221 #include "TestTypeInterfaces.cpp.inc" 1222 1223 #define GET_OP_CLASSES 1224 #include "TestOps.cpp.inc" 1225