1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestInterfaces.h" 12 #include "TestTypes.h" 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/DLTI/DLTI.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/DialectImplementation.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/IR/TypeUtilities.h" 21 #include "mlir/Reducer/ReductionPatternInterface.h" 22 #include "mlir/Transforms/FoldUtils.h" 23 #include "mlir/Transforms/InliningUtils.h" 24 #include "llvm/ADT/StringSwitch.h" 25 26 // Include this before the using namespace lines below to 27 // test that we don't have namespace dependencies. 28 #include "TestOpsDialect.cpp.inc" 29 30 using namespace mlir; 31 using namespace test; 32 33 void test::registerTestDialect(DialectRegistry ®istry) { 34 registry.insert<TestDialect>(); 35 } 36 37 //===----------------------------------------------------------------------===// 38 // TestDialect Interfaces 39 //===----------------------------------------------------------------------===// 40 41 namespace { 42 43 /// Testing the correctness of some traits. 44 static_assert( 45 llvm::is_detected<OpTrait::has_implicit_terminator_t, 46 SingleBlockImplicitTerminatorOp>::value, 47 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); 48 static_assert(OpTrait::hasSingleBlockImplicitTerminator< 49 SingleBlockImplicitTerminatorOp>::value, 50 "hasSingleBlockImplicitTerminator does not match " 51 "SingleBlockImplicitTerminatorOp"); 52 53 // Test support for interacting with the AsmPrinter. 54 struct TestOpAsmInterface : public OpAsmDialectInterface { 55 using OpAsmDialectInterface::OpAsmDialectInterface; 56 57 AliasResult getAlias(Attribute attr, raw_ostream &os) const final { 58 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 59 if (!strAttr) 60 return AliasResult::NoAlias; 61 62 // Check the contents of the string attribute to see what the test alias 63 // should be named. 64 Optional<StringRef> aliasName = 65 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 66 .Case("alias_test:dot_in_name", StringRef("test.alias")) 67 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 68 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 69 .Case("alias_test:sanitize_conflict_a", 70 StringRef("test_alias_conflict0")) 71 .Case("alias_test:sanitize_conflict_b", 72 StringRef("test_alias_conflict0_")) 73 .Case("alias_test:tensor_encoding", StringRef("test_encoding")) 74 .Default(llvm::None); 75 if (!aliasName) 76 return AliasResult::NoAlias; 77 78 os << *aliasName; 79 return AliasResult::FinalAlias; 80 } 81 82 AliasResult getAlias(Type type, raw_ostream &os) const final { 83 if (auto tupleType = type.dyn_cast<TupleType>()) { 84 if (tupleType.size() > 0 && 85 llvm::all_of(tupleType.getTypes(), [](Type elemType) { 86 return elemType.isa<SimpleAType>(); 87 })) { 88 os << "test_tuple"; 89 return AliasResult::FinalAlias; 90 } 91 } 92 if (auto intType = type.dyn_cast<TestIntegerType>()) { 93 if (intType.getSignedness() == 94 TestIntegerType::SignednessSemantics::Unsigned && 95 intType.getWidth() == 8) { 96 os << "test_ui8"; 97 return AliasResult::FinalAlias; 98 } 99 } 100 return AliasResult::NoAlias; 101 } 102 103 void getAsmResultNames(Operation *op, 104 OpAsmSetValueNameFn setNameFn) const final { 105 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 106 setNameFn(asmOp, "result"); 107 } 108 }; 109 110 struct TestDialectFoldInterface : public DialectFoldInterface { 111 using DialectFoldInterface::DialectFoldInterface; 112 113 /// Registered hook to check if the given region, which is attached to an 114 /// operation that is *not* isolated from above, should be used when 115 /// materializing constants. 116 bool shouldMaterializeInto(Region *region) const final { 117 // If this is a one region operation, then insert into it. 118 return isa<OneRegionOp>(region->getParentOp()); 119 } 120 }; 121 122 /// This class defines the interface for handling inlining with standard 123 /// operations. 124 struct TestInlinerInterface : public DialectInlinerInterface { 125 using DialectInlinerInterface::DialectInlinerInterface; 126 127 //===--------------------------------------------------------------------===// 128 // Analysis Hooks 129 //===--------------------------------------------------------------------===// 130 131 bool isLegalToInline(Operation *call, Operation *callable, 132 bool wouldBeCloned) const final { 133 // Don't allow inlining calls that are marked `noinline`. 134 return !call->hasAttr("noinline"); 135 } 136 bool isLegalToInline(Region *, Region *, bool, 137 BlockAndValueMapping &) const final { 138 // Inlining into test dialect regions is legal. 139 return true; 140 } 141 bool isLegalToInline(Operation *, Region *, bool, 142 BlockAndValueMapping &) const final { 143 return true; 144 } 145 146 bool shouldAnalyzeRecursively(Operation *op) const final { 147 // Analyze recursively if this is not a functional region operation, it 148 // froms a separate functional scope. 149 return !isa<FunctionalRegionOp>(op); 150 } 151 152 //===--------------------------------------------------------------------===// 153 // Transformation Hooks 154 //===--------------------------------------------------------------------===// 155 156 /// Handle the given inlined terminator by replacing it with a new operation 157 /// as necessary. 158 void handleTerminator(Operation *op, 159 ArrayRef<Value> valuesToRepl) const final { 160 // Only handle "test.return" here. 161 auto returnOp = dyn_cast<TestReturnOp>(op); 162 if (!returnOp) 163 return; 164 165 // Replace the values directly with the return operands. 166 assert(returnOp.getNumOperands() == valuesToRepl.size()); 167 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 168 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 169 } 170 171 /// Attempt to materialize a conversion for a type mismatch between a call 172 /// from this dialect, and a callable region. This method should generate an 173 /// operation that takes 'input' as the only operand, and produces a single 174 /// result of 'resultType'. If a conversion can not be generated, nullptr 175 /// should be returned. 176 Operation *materializeCallConversion(OpBuilder &builder, Value input, 177 Type resultType, 178 Location conversionLoc) const final { 179 // Only allow conversion for i16/i32 types. 180 if (!(resultType.isSignlessInteger(16) || 181 resultType.isSignlessInteger(32)) || 182 !(input.getType().isSignlessInteger(16) || 183 input.getType().isSignlessInteger(32))) 184 return nullptr; 185 return builder.create<TestCastOp>(conversionLoc, resultType, input); 186 } 187 188 void processInlinedCallBlocks( 189 Operation *call, 190 iterator_range<Region::iterator> inlinedBlocks) const final { 191 if (!isa<ConversionCallOp>(call)) 192 return; 193 194 // Set attributed on all ops in the inlined blocks. 195 for (Block &block : inlinedBlocks) { 196 block.walk([&](Operation *op) { 197 op->setAttr("inlined_conversion", UnitAttr::get(call->getContext())); 198 }); 199 } 200 } 201 }; 202 203 struct TestReductionPatternInterface : public DialectReductionPatternInterface { 204 public: 205 TestReductionPatternInterface(Dialect *dialect) 206 : DialectReductionPatternInterface(dialect) {} 207 208 void populateReductionPatterns(RewritePatternSet &patterns) const final { 209 populateTestReductionPatterns(patterns); 210 } 211 }; 212 213 } // namespace 214 215 //===----------------------------------------------------------------------===// 216 // TestDialect 217 //===----------------------------------------------------------------------===// 218 219 static void testSideEffectOpGetEffect( 220 Operation *op, 221 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects); 222 223 // This is the implementation of a dialect fallback for `TestEffectOpInterface`. 224 struct TestOpEffectInterfaceFallback 225 : public TestEffectOpInterface::FallbackModel< 226 TestOpEffectInterfaceFallback> { 227 static bool classof(Operation *op) { 228 bool isSupportedOp = 229 op->getName().getStringRef() == "test.unregistered_side_effect_op"; 230 assert(isSupportedOp && "Unexpected dispatch"); 231 return isSupportedOp; 232 } 233 234 void 235 getEffects(Operation *op, 236 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 237 &effects) const { 238 testSideEffectOpGetEffect(op, effects); 239 } 240 }; 241 242 void TestDialect::initialize() { 243 registerAttributes(); 244 registerTypes(); 245 addOperations< 246 #define GET_OP_LIST 247 #include "TestOps.cpp.inc" 248 >(); 249 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 250 TestInlinerInterface, TestReductionPatternInterface>(); 251 allowUnknownOperations(); 252 253 // Instantiate our fallback op interface that we'll use on specific 254 // unregistered op. 255 fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback; 256 } 257 TestDialect::~TestDialect() { 258 delete static_cast<TestOpEffectInterfaceFallback *>( 259 fallbackEffectOpInterfaces); 260 } 261 262 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 263 Type type, Location loc) { 264 return builder.create<TestOpConstant>(loc, type, value); 265 } 266 267 ::mlir::LogicalResult FormatInferType2Op::inferReturnTypes( 268 ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, 269 ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, 270 ::mlir::RegionRange regions, 271 ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { 272 inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)}); 273 return ::mlir::success(); 274 } 275 276 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, 277 OperationName opName) { 278 if (opName.getIdentifier() == "test.unregistered_side_effect_op" && 279 typeID == TypeID::get<TestEffectOpInterface>()) 280 return fallbackEffectOpInterfaces; 281 return nullptr; 282 } 283 284 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 285 NamedAttribute namedAttr) { 286 if (namedAttr.getName() == "test.invalid_attr") 287 return op->emitError() << "invalid to use 'test.invalid_attr'"; 288 return success(); 289 } 290 291 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 292 unsigned regionIndex, 293 unsigned argIndex, 294 NamedAttribute namedAttr) { 295 if (namedAttr.getName() == "test.invalid_attr") 296 return op->emitError() << "invalid to use 'test.invalid_attr'"; 297 return success(); 298 } 299 300 LogicalResult 301 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 302 unsigned resultIndex, 303 NamedAttribute namedAttr) { 304 if (namedAttr.getName() == "test.invalid_attr") 305 return op->emitError() << "invalid to use 'test.invalid_attr'"; 306 return success(); 307 } 308 309 Optional<Dialect::ParseOpHook> 310 TestDialect::getParseOperationHook(StringRef opName) const { 311 if (opName == "test.dialect_custom_printer") { 312 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 313 return parser.parseKeyword("custom_format"); 314 }}; 315 } 316 if (opName == "test.dialect_custom_format_fallback") { 317 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 318 return parser.parseKeyword("custom_format_fallback"); 319 }}; 320 } 321 return None; 322 } 323 324 llvm::unique_function<void(Operation *, OpAsmPrinter &)> 325 TestDialect::getOperationPrinter(Operation *op) const { 326 StringRef opName = op->getName().getStringRef(); 327 if (opName == "test.dialect_custom_printer") { 328 return [](Operation *op, OpAsmPrinter &printer) { 329 printer.getStream() << " custom_format"; 330 }; 331 } 332 if (opName == "test.dialect_custom_format_fallback") { 333 return [](Operation *op, OpAsmPrinter &printer) { 334 printer.getStream() << " custom_format_fallback"; 335 }; 336 } 337 return {}; 338 } 339 340 //===----------------------------------------------------------------------===// 341 // TestBranchOp 342 //===----------------------------------------------------------------------===// 343 344 Optional<MutableOperandRange> 345 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 346 assert(index == 0 && "invalid successor index"); 347 return getTargetOperandsMutable(); 348 } 349 350 //===----------------------------------------------------------------------===// 351 // TestDialectCanonicalizerOp 352 //===----------------------------------------------------------------------===// 353 354 static LogicalResult 355 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, 356 PatternRewriter &rewriter) { 357 rewriter.replaceOpWithNewOp<arith::ConstantOp>( 358 op, rewriter.getI32IntegerAttr(42)); 359 return success(); 360 } 361 362 void TestDialect::getCanonicalizationPatterns( 363 RewritePatternSet &results) const { 364 results.add(&dialectCanonicalizationPattern); 365 } 366 367 //===----------------------------------------------------------------------===// 368 // TestFoldToCallOp 369 //===----------------------------------------------------------------------===// 370 371 namespace { 372 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 373 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 374 375 LogicalResult matchAndRewrite(FoldToCallOp op, 376 PatternRewriter &rewriter) const override { 377 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.getCalleeAttr(), 378 ValueRange()); 379 return success(); 380 } 381 }; 382 } // namespace 383 384 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 385 MLIRContext *context) { 386 results.add<FoldToCallOpPattern>(context); 387 } 388 389 //===----------------------------------------------------------------------===// 390 // Test Format* operations 391 //===----------------------------------------------------------------------===// 392 393 //===----------------------------------------------------------------------===// 394 // Parsing 395 396 static ParseResult parseCustomDirectiveOperands( 397 OpAsmParser &parser, OpAsmParser::OperandType &operand, 398 Optional<OpAsmParser::OperandType> &optOperand, 399 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 400 if (parser.parseOperand(operand)) 401 return failure(); 402 if (succeeded(parser.parseOptionalComma())) { 403 optOperand.emplace(); 404 if (parser.parseOperand(*optOperand)) 405 return failure(); 406 } 407 if (parser.parseArrow() || parser.parseLParen() || 408 parser.parseOperandList(varOperands) || parser.parseRParen()) 409 return failure(); 410 return success(); 411 } 412 static ParseResult 413 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 414 Type &optOperandType, 415 SmallVectorImpl<Type> &varOperandTypes) { 416 if (parser.parseColon()) 417 return failure(); 418 419 if (parser.parseType(operandType)) 420 return failure(); 421 if (succeeded(parser.parseOptionalComma())) { 422 if (parser.parseType(optOperandType)) 423 return failure(); 424 } 425 if (parser.parseArrow() || parser.parseLParen() || 426 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 427 return failure(); 428 return success(); 429 } 430 static ParseResult 431 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 432 Type optOperandType, 433 const SmallVectorImpl<Type> &varOperandTypes) { 434 if (parser.parseKeyword("type_refs_capture")) 435 return failure(); 436 437 Type operandType2, optOperandType2; 438 SmallVector<Type, 1> varOperandTypes2; 439 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 440 varOperandTypes2)) 441 return failure(); 442 443 if (operandType != operandType2 || optOperandType != optOperandType2 || 444 varOperandTypes != varOperandTypes2) 445 return failure(); 446 447 return success(); 448 } 449 static ParseResult parseCustomDirectiveOperandsAndTypes( 450 OpAsmParser &parser, OpAsmParser::OperandType &operand, 451 Optional<OpAsmParser::OperandType> &optOperand, 452 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 453 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 454 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 455 parseCustomDirectiveResults(parser, operandType, optOperandType, 456 varOperandTypes)) 457 return failure(); 458 return success(); 459 } 460 static ParseResult parseCustomDirectiveRegions( 461 OpAsmParser &parser, Region ®ion, 462 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 463 if (parser.parseRegion(region)) 464 return failure(); 465 if (failed(parser.parseOptionalComma())) 466 return success(); 467 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 468 if (parser.parseRegion(*varRegion)) 469 return failure(); 470 varRegions.emplace_back(std::move(varRegion)); 471 return success(); 472 } 473 static ParseResult 474 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 475 SmallVectorImpl<Block *> &varSuccessors) { 476 if (parser.parseSuccessor(successor)) 477 return failure(); 478 if (failed(parser.parseOptionalComma())) 479 return success(); 480 Block *varSuccessor; 481 if (parser.parseSuccessor(varSuccessor)) 482 return failure(); 483 varSuccessors.append(2, varSuccessor); 484 return success(); 485 } 486 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 487 IntegerAttr &attr, 488 IntegerAttr &optAttr) { 489 if (parser.parseAttribute(attr)) 490 return failure(); 491 if (succeeded(parser.parseOptionalComma())) { 492 if (parser.parseAttribute(optAttr)) 493 return failure(); 494 } 495 return success(); 496 } 497 498 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 499 NamedAttrList &attrs) { 500 return parser.parseOptionalAttrDict(attrs); 501 } 502 static ParseResult parseCustomDirectiveOptionalOperandRef( 503 OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) { 504 int64_t operandCount = 0; 505 if (parser.parseInteger(operandCount)) 506 return failure(); 507 bool expectedOptionalOperand = operandCount == 0; 508 return success(expectedOptionalOperand != optOperand.hasValue()); 509 } 510 511 //===----------------------------------------------------------------------===// 512 // Printing 513 514 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 515 Value operand, Value optOperand, 516 OperandRange varOperands) { 517 printer << operand; 518 if (optOperand) 519 printer << ", " << optOperand; 520 printer << " -> (" << varOperands << ")"; 521 } 522 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 523 Type operandType, Type optOperandType, 524 TypeRange varOperandTypes) { 525 printer << " : " << operandType; 526 if (optOperandType) 527 printer << ", " << optOperandType; 528 printer << " -> (" << varOperandTypes << ")"; 529 } 530 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 531 Operation *op, Type operandType, 532 Type optOperandType, 533 TypeRange varOperandTypes) { 534 printer << " type_refs_capture "; 535 printCustomDirectiveResults(printer, op, operandType, optOperandType, 536 varOperandTypes); 537 } 538 static void printCustomDirectiveOperandsAndTypes( 539 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 540 OperandRange varOperands, Type operandType, Type optOperandType, 541 TypeRange varOperandTypes) { 542 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 543 printCustomDirectiveResults(printer, op, operandType, optOperandType, 544 varOperandTypes); 545 } 546 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 547 Region ®ion, 548 MutableArrayRef<Region> varRegions) { 549 printer.printRegion(region); 550 if (!varRegions.empty()) { 551 printer << ", "; 552 for (Region ®ion : varRegions) 553 printer.printRegion(region); 554 } 555 } 556 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 557 Block *successor, 558 SuccessorRange varSuccessors) { 559 printer << successor; 560 if (!varSuccessors.empty()) 561 printer << ", " << varSuccessors.front(); 562 } 563 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 564 Attribute attribute, 565 Attribute optAttribute) { 566 printer << attribute; 567 if (optAttribute) 568 printer << ", " << optAttribute; 569 } 570 571 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 572 DictionaryAttr attrs) { 573 printer.printOptionalAttrDict(attrs.getValue()); 574 } 575 576 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 577 Operation *op, 578 Value optOperand) { 579 printer << (optOperand ? "1" : "0"); 580 } 581 582 //===----------------------------------------------------------------------===// 583 // Test IsolatedRegionOp - parse passthrough region arguments. 584 //===----------------------------------------------------------------------===// 585 586 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 587 OperationState &result) { 588 OpAsmParser::OperandType argInfo; 589 Type argType = parser.getBuilder().getIndexType(); 590 591 // Parse the input operand. 592 if (parser.parseOperand(argInfo) || 593 parser.resolveOperand(argInfo, argType, result.operands)) 594 return failure(); 595 596 // Parse the body region, and reuse the operand info as the argument info. 597 Region *body = result.addRegion(); 598 return parser.parseRegion(*body, argInfo, argType, 599 /*enableNameShadowing=*/true); 600 } 601 602 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 603 p << "test.isolated_region "; 604 p.printOperand(op.getOperand()); 605 p.shadowRegionArgs(op.getRegion(), op.getOperand()); 606 p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false); 607 } 608 609 //===----------------------------------------------------------------------===// 610 // Test SSACFGRegionOp 611 //===----------------------------------------------------------------------===// 612 613 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 614 return RegionKind::SSACFG; 615 } 616 617 //===----------------------------------------------------------------------===// 618 // Test GraphRegionOp 619 //===----------------------------------------------------------------------===// 620 621 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 622 OperationState &result) { 623 // Parse the body region, and reuse the operand info as the argument info. 624 Region *body = result.addRegion(); 625 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 626 } 627 628 static void print(OpAsmPrinter &p, GraphRegionOp op) { 629 p << "test.graph_region "; 630 p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false); 631 } 632 633 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 634 return RegionKind::Graph; 635 } 636 637 //===----------------------------------------------------------------------===// 638 // Test AffineScopeOp 639 //===----------------------------------------------------------------------===// 640 641 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 642 OperationState &result) { 643 // Parse the body region, and reuse the operand info as the argument info. 644 Region *body = result.addRegion(); 645 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 646 } 647 648 static void print(OpAsmPrinter &p, AffineScopeOp op) { 649 p << "test.affine_scope "; 650 p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false); 651 } 652 653 //===----------------------------------------------------------------------===// 654 // Test parser. 655 //===----------------------------------------------------------------------===// 656 657 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, 658 OperationState &result) { 659 if (parser.parseOptionalColon()) 660 return success(); 661 uint64_t numResults; 662 if (parser.parseInteger(numResults)) 663 return failure(); 664 665 IndexType type = parser.getBuilder().getIndexType(); 666 for (unsigned i = 0; i < numResults; ++i) 667 result.addTypes(type); 668 return success(); 669 } 670 671 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { 672 if (unsigned numResults = op->getNumResults()) 673 p << " : " << numResults; 674 } 675 676 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, 677 OperationState &result) { 678 StringRef keyword; 679 if (parser.parseKeyword(&keyword)) 680 return failure(); 681 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 682 return success(); 683 } 684 685 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { 686 p << " " << op.getKeyword(); 687 } 688 689 //===----------------------------------------------------------------------===// 690 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 691 692 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 693 OperationState &result) { 694 if (parser.parseKeyword("wraps")) 695 return failure(); 696 697 // Parse the wrapped op in a region 698 Region &body = *result.addRegion(); 699 body.push_back(new Block); 700 Block &block = body.back(); 701 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin()); 702 if (!wrappedOp) 703 return failure(); 704 705 // Create a return terminator in the inner region, pass as operand to the 706 // terminator the returned values from the wrapped operation. 707 SmallVector<Value, 8> returnOperands(wrappedOp->getResults()); 708 OpBuilder builder(parser.getContext()); 709 builder.setInsertionPointToEnd(&block); 710 builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands); 711 712 // Get the results type for the wrapping op from the terminator operands. 713 Operation &returnOp = body.back().back(); 714 result.types.append(returnOp.operand_type_begin(), 715 returnOp.operand_type_end()); 716 717 // Use the location of the wrapped op for the "test.wrapping_region" op. 718 result.location = wrappedOp->getLoc(); 719 720 return success(); 721 } 722 723 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 724 p << " wraps "; 725 p.printGenericOp(&op.getRegion().front().front()); 726 } 727 728 //===----------------------------------------------------------------------===// 729 // Test PrettyPrintedRegionOp - exercising the following parser APIs 730 // parseGenericOperationAfterOpName 731 // parseCustomOperationName 732 //===----------------------------------------------------------------------===// 733 734 static ParseResult parsePrettyPrintedRegionOp(OpAsmParser &parser, 735 OperationState &result) { 736 737 llvm::SMLoc loc = parser.getCurrentLocation(); 738 Location currLocation = parser.getEncodedSourceLoc(loc); 739 740 // Parse the operands. 741 SmallVector<OpAsmParser::OperandType, 2> operands; 742 if (parser.parseOperandList(operands)) 743 return failure(); 744 745 // Check if we are parsing the pretty-printed version 746 // test.pretty_printed_region start <inner-op> end : <functional-type> 747 // Else fallback to parsing the "non pretty-printed" version. 748 if (!succeeded(parser.parseOptionalKeyword("start"))) 749 return parser.parseGenericOperationAfterOpName( 750 result, llvm::makeArrayRef(operands)); 751 752 FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName(); 753 if (failed(parseOpNameInfo)) 754 return failure(); 755 756 StringRef innerOpName = parseOpNameInfo->getStringRef(); 757 758 FunctionType opFntype; 759 Optional<Location> explicitLoc; 760 if (parser.parseKeyword("end") || parser.parseColon() || 761 parser.parseType(opFntype) || 762 parser.parseOptionalLocationSpecifier(explicitLoc)) 763 return failure(); 764 765 // If location of the op is explicitly provided, then use it; Else use 766 // the parser's current location. 767 Location opLoc = explicitLoc.getValueOr(currLocation); 768 769 // Derive the SSA-values for op's operands. 770 if (parser.resolveOperands(operands, opFntype.getInputs(), loc, 771 result.operands)) 772 return failure(); 773 774 // Add a region for op. 775 Region ®ion = *result.addRegion(); 776 777 // Create a basic-block inside op's region. 778 Block &block = region.emplaceBlock(); 779 780 // Create and insert an "inner-op" operation in the block. 781 // Just for testing purposes, we can assume that inner op is a binary op with 782 // result and operand types all same as the test-op's first operand. 783 Type innerOpType = opFntype.getInput(0); 784 Value lhs = block.addArgument(innerOpType, opLoc); 785 Value rhs = block.addArgument(innerOpType, opLoc); 786 787 OpBuilder builder(parser.getBuilder().getContext()); 788 builder.setInsertionPointToStart(&block); 789 790 OperationState innerOpState(opLoc, innerOpName); 791 innerOpState.operands.push_back(lhs); 792 innerOpState.operands.push_back(rhs); 793 innerOpState.addTypes(innerOpType); 794 795 Operation *innerOp = builder.createOperation(innerOpState); 796 797 // Insert a return statement in the block returning the inner-op's result. 798 builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults()); 799 800 // Populate the op operation-state with result-type and location. 801 result.addTypes(opFntype.getResults()); 802 result.location = innerOp->getLoc(); 803 804 return success(); 805 } 806 807 static void print(OpAsmPrinter &p, PrettyPrintedRegionOp op) { 808 p << ' '; 809 p.printOperands(op.getOperands()); 810 811 Operation &innerOp = op.getRegion().front().front(); 812 // Assuming that region has a single non-terminator inner-op, if the inner-op 813 // meets some criteria (which in this case is a simple one based on the name 814 // of inner-op), then we can print the entire region in a succinct way. 815 // Here we assume that the prototype of "special.op" can be trivially derived 816 // while parsing it back. 817 if (innerOp.getName().getStringRef().equals("special.op")) { 818 p << " start special.op end"; 819 } else { 820 p << " ("; 821 p.printRegion(op.getRegion()); 822 p << ")"; 823 } 824 825 p << " : "; 826 p.printFunctionalType(op); 827 } 828 829 //===----------------------------------------------------------------------===// 830 // Test PolyForOp - parse list of region arguments. 831 //===----------------------------------------------------------------------===// 832 833 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 834 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 835 // Parse list of region arguments without a delimiter. 836 if (parser.parseRegionArgumentList(ivsInfo)) 837 return failure(); 838 839 // Parse the body region. 840 Region *body = result.addRegion(); 841 auto &builder = parser.getBuilder(); 842 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 843 return parser.parseRegion(*body, ivsInfo, argTypes); 844 } 845 846 void PolyForOp::getAsmBlockArgumentNames(Region ®ion, 847 OpAsmSetValueNameFn setNameFn) { 848 auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names"); 849 if (!arrayAttr) 850 return; 851 auto args = getRegion().front().getArguments(); 852 auto e = std::min(arrayAttr.size(), args.size()); 853 for (unsigned i = 0; i < e; ++i) { 854 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 855 setNameFn(args[i], strAttr.getValue()); 856 } 857 } 858 859 //===----------------------------------------------------------------------===// 860 // Test removing op with inner ops. 861 //===----------------------------------------------------------------------===// 862 863 namespace { 864 struct TestRemoveOpWithInnerOps 865 : public OpRewritePattern<TestOpWithRegionPattern> { 866 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 867 868 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } 869 870 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 871 PatternRewriter &rewriter) const override { 872 rewriter.eraseOp(op); 873 return success(); 874 } 875 }; 876 } // namespace 877 878 void TestOpWithRegionPattern::getCanonicalizationPatterns( 879 RewritePatternSet &results, MLIRContext *context) { 880 results.add<TestRemoveOpWithInnerOps>(context); 881 } 882 883 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 884 return getOperand(); 885 } 886 887 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 888 return getValue(); 889 } 890 891 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 892 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 893 for (Value input : this->getOperands()) { 894 results.push_back(input); 895 } 896 return success(); 897 } 898 899 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 900 assert(operands.size() == 1); 901 if (operands.front()) { 902 (*this)->setAttr("attr", operands.front()); 903 return getResult(); 904 } 905 return {}; 906 } 907 908 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 909 return getOperand(); 910 } 911 912 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 913 MLIRContext *, Optional<Location> location, ValueRange operands, 914 DictionaryAttr attributes, RegionRange regions, 915 SmallVectorImpl<Type> &inferredReturnTypes) { 916 if (operands[0].getType() != operands[1].getType()) { 917 return emitOptionalError(location, "operand type mismatch ", 918 operands[0].getType(), " vs ", 919 operands[1].getType()); 920 } 921 inferredReturnTypes.assign({operands[0].getType()}); 922 return success(); 923 } 924 925 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 926 MLIRContext *context, Optional<Location> location, ValueShapeRange operands, 927 DictionaryAttr attributes, RegionRange regions, 928 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 929 // Create return type consisting of the last element of the first operand. 930 auto operandType = operands.front().getType(); 931 auto sval = operandType.dyn_cast<ShapedType>(); 932 if (!sval) { 933 return emitOptionalError(location, "only shaped type operands allowed"); 934 } 935 int64_t dim = 936 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 937 auto type = IntegerType::get(context, 17); 938 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 939 return success(); 940 } 941 942 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 943 OpBuilder &builder, ValueRange operands, 944 llvm::SmallVectorImpl<Value> &shapes) { 945 shapes = SmallVector<Value, 1>{ 946 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; 947 return success(); 948 } 949 950 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( 951 OpBuilder &builder, ValueRange operands, 952 llvm::SmallVectorImpl<Value> &shapes) { 953 Location loc = getLoc(); 954 shapes.reserve(operands.size()); 955 for (Value operand : llvm::reverse(operands)) { 956 auto rank = operand.getType().cast<RankedTensorType>().getRank(); 957 auto currShape = llvm::to_vector<4>( 958 llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value { 959 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 960 })); 961 shapes.push_back(builder.create<tensor::FromElementsOp>( 962 getLoc(), RankedTensorType::get({rank}, builder.getIndexType()), 963 currShape)); 964 } 965 return success(); 966 } 967 968 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( 969 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { 970 Location loc = getLoc(); 971 shapes.reserve(getNumOperands()); 972 for (Value operand : llvm::reverse(getOperands())) { 973 auto currShape = llvm::to_vector<4>(llvm::map_range( 974 llvm::seq<int64_t>( 975 0, operand.getType().cast<RankedTensorType>().getRank()), 976 [&](int64_t dim) -> Value { 977 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 978 })); 979 shapes.emplace_back(std::move(currShape)); 980 } 981 return success(); 982 } 983 984 //===----------------------------------------------------------------------===// 985 // Test SideEffect interfaces 986 //===----------------------------------------------------------------------===// 987 988 namespace { 989 /// A test resource for side effects. 990 struct TestResource : public SideEffects::Resource::Base<TestResource> { 991 StringRef getName() final { return "<Test>"; } 992 }; 993 } // namespace 994 995 static void testSideEffectOpGetEffect( 996 Operation *op, 997 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 998 &effects) { 999 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); 1000 if (!effectsAttr) 1001 return; 1002 1003 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 1004 } 1005 1006 void SideEffectOp::getEffects( 1007 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1008 // Check for an effects attribute on the op instance. 1009 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 1010 if (!effectsAttr) 1011 return; 1012 1013 // If there is one, it is an array of dictionary attributes that hold 1014 // information on the effects of this operation. 1015 for (Attribute element : effectsAttr) { 1016 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 1017 1018 // Get the specific memory effect. 1019 MemoryEffects::Effect *effect = 1020 StringSwitch<MemoryEffects::Effect *>( 1021 effectElement.get("effect").cast<StringAttr>().getValue()) 1022 .Case("allocate", MemoryEffects::Allocate::get()) 1023 .Case("free", MemoryEffects::Free::get()) 1024 .Case("read", MemoryEffects::Read::get()) 1025 .Case("write", MemoryEffects::Write::get()); 1026 1027 // Check for a non-default resource to use. 1028 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 1029 if (effectElement.get("test_resource")) 1030 resource = TestResource::get(); 1031 1032 // Check for a result to affect. 1033 if (effectElement.get("on_result")) 1034 effects.emplace_back(effect, getResult(), resource); 1035 else if (Attribute ref = effectElement.get("on_reference")) 1036 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 1037 else 1038 effects.emplace_back(effect, resource); 1039 } 1040 } 1041 1042 void SideEffectOp::getEffects( 1043 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 1044 testSideEffectOpGetEffect(getOperation(), effects); 1045 } 1046 1047 //===----------------------------------------------------------------------===// 1048 // StringAttrPrettyNameOp 1049 //===----------------------------------------------------------------------===// 1050 1051 // This op has fancy handling of its SSA result name. 1052 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 1053 OperationState &result) { 1054 // Add the result types. 1055 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 1056 result.addTypes(parser.getBuilder().getIntegerType(32)); 1057 1058 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1059 return failure(); 1060 1061 // If the attribute dictionary contains no 'names' attribute, infer it from 1062 // the SSA name (if specified). 1063 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 1064 return attr.getName() == "names"; 1065 }); 1066 1067 // If there was no name specified, check to see if there was a useful name 1068 // specified in the asm file. 1069 if (hadNames || parser.getNumResults() == 0) 1070 return success(); 1071 1072 SmallVector<StringRef, 4> names; 1073 auto *context = result.getContext(); 1074 1075 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 1076 auto resultName = parser.getResultName(i); 1077 StringRef nameStr; 1078 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 1079 nameStr = resultName.first; 1080 1081 names.push_back(nameStr); 1082 } 1083 1084 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 1085 result.attributes.push_back({StringAttr::get(context, "names"), namesAttr}); 1086 return success(); 1087 } 1088 1089 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 1090 // Note that we only need to print the "name" attribute if the asmprinter 1091 // result name disagrees with it. This can happen in strange cases, e.g. 1092 // when there are conflicts. 1093 bool namesDisagree = op.getNames().size() != op.getNumResults(); 1094 1095 SmallString<32> resultNameStr; 1096 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 1097 resultNameStr.clear(); 1098 llvm::raw_svector_ostream tmpStream(resultNameStr); 1099 p.printOperand(op.getResult(i), tmpStream); 1100 1101 auto expectedName = op.getNames()[i].dyn_cast<StringAttr>(); 1102 if (!expectedName || 1103 tmpStream.str().drop_front() != expectedName.getValue()) { 1104 namesDisagree = true; 1105 } 1106 } 1107 1108 if (namesDisagree) 1109 p.printOptionalAttrDictWithKeyword(op->getAttrs()); 1110 else 1111 p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"}); 1112 } 1113 1114 // We set the SSA name in the asm syntax to the contents of the name 1115 // attribute. 1116 void StringAttrPrettyNameOp::getAsmResultNames( 1117 function_ref<void(Value, StringRef)> setNameFn) { 1118 1119 auto value = getNames(); 1120 for (size_t i = 0, e = value.size(); i != e; ++i) 1121 if (auto str = value[i].dyn_cast<StringAttr>()) 1122 if (!str.getValue().empty()) 1123 setNameFn(getResult(i), str.getValue()); 1124 } 1125 1126 //===----------------------------------------------------------------------===// 1127 // RegionIfOp 1128 //===----------------------------------------------------------------------===// 1129 1130 static void print(OpAsmPrinter &p, RegionIfOp op) { 1131 p << " "; 1132 p.printOperands(op.getOperands()); 1133 p << ": " << op.getOperandTypes(); 1134 p.printArrowTypeList(op.getResultTypes()); 1135 p << " then"; 1136 p.printRegion(op.getThenRegion(), 1137 /*printEntryBlockArgs=*/true, 1138 /*printBlockTerminators=*/true); 1139 p << " else"; 1140 p.printRegion(op.getElseRegion(), 1141 /*printEntryBlockArgs=*/true, 1142 /*printBlockTerminators=*/true); 1143 p << " join"; 1144 p.printRegion(op.getJoinRegion(), 1145 /*printEntryBlockArgs=*/true, 1146 /*printBlockTerminators=*/true); 1147 } 1148 1149 static ParseResult parseRegionIfOp(OpAsmParser &parser, 1150 OperationState &result) { 1151 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 1152 SmallVector<Type, 2> operandTypes; 1153 1154 result.regions.reserve(3); 1155 Region *thenRegion = result.addRegion(); 1156 Region *elseRegion = result.addRegion(); 1157 Region *joinRegion = result.addRegion(); 1158 1159 // Parse operand, type and arrow type lists. 1160 if (parser.parseOperandList(operandInfos) || 1161 parser.parseColonTypeList(operandTypes) || 1162 parser.parseArrowTypeList(result.types)) 1163 return failure(); 1164 1165 // Parse all attached regions. 1166 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 1167 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 1168 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 1169 return failure(); 1170 1171 return parser.resolveOperands(operandInfos, operandTypes, 1172 parser.getCurrentLocation(), result.operands); 1173 } 1174 1175 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 1176 assert(index < 2 && "invalid region index"); 1177 return getOperands(); 1178 } 1179 1180 void RegionIfOp::getSuccessorRegions( 1181 Optional<unsigned> index, ArrayRef<Attribute> operands, 1182 SmallVectorImpl<RegionSuccessor> ®ions) { 1183 // We always branch to the join region. 1184 if (index.hasValue()) { 1185 if (index.getValue() < 2) 1186 regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); 1187 else 1188 regions.push_back(RegionSuccessor(getResults())); 1189 return; 1190 } 1191 1192 // The then and else regions are the entry regions of this op. 1193 regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs())); 1194 regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs())); 1195 } 1196 1197 //===----------------------------------------------------------------------===// 1198 // SingleNoTerminatorCustomAsmOp 1199 //===----------------------------------------------------------------------===// 1200 1201 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser, 1202 OperationState &state) { 1203 Region *body = state.addRegion(); 1204 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) 1205 return failure(); 1206 return success(); 1207 } 1208 1209 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) { 1210 printer.printRegion( 1211 op.getRegion(), /*printEntryBlockArgs=*/false, 1212 // This op has a single block without terminators. But explicitly mark 1213 // as not printing block terminators for testing. 1214 /*printBlockTerminators=*/false); 1215 } 1216 1217 #include "TestOpEnums.cpp.inc" 1218 #include "TestOpInterfaces.cpp.inc" 1219 #include "TestOpStructs.cpp.inc" 1220 #include "TestTypeInterfaces.cpp.inc" 1221 1222 #define GET_OP_CLASSES 1223 #include "TestOps.cpp.inc" 1224