1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestInterfaces.h" 12 #include "TestTypes.h" 13 #include "mlir/Dialect/DLTI/DLTI.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/DialectImplementation.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Reducer/ReductionPatternInterface.h" 21 #include "mlir/Transforms/FoldUtils.h" 22 #include "mlir/Transforms/InliningUtils.h" 23 #include "llvm/ADT/StringSwitch.h" 24 25 // Include this before the using namespace lines below to 26 // test that we don't have namespace dependencies. 27 #include "TestOpsDialect.cpp.inc" 28 29 using namespace mlir; 30 using namespace test; 31 32 void test::registerTestDialect(DialectRegistry ®istry) { 33 registry.insert<TestDialect>(); 34 } 35 36 //===----------------------------------------------------------------------===// 37 // TestDialect Interfaces 38 //===----------------------------------------------------------------------===// 39 40 namespace { 41 42 /// Testing the correctness of some traits. 43 static_assert( 44 llvm::is_detected<OpTrait::has_implicit_terminator_t, 45 SingleBlockImplicitTerminatorOp>::value, 46 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); 47 static_assert(OpTrait::hasSingleBlockImplicitTerminator< 48 SingleBlockImplicitTerminatorOp>::value, 49 "hasSingleBlockImplicitTerminator does not match " 50 "SingleBlockImplicitTerminatorOp"); 51 52 // Test support for interacting with the AsmPrinter. 53 struct TestOpAsmInterface : public OpAsmDialectInterface { 54 using OpAsmDialectInterface::OpAsmDialectInterface; 55 56 AliasResult getAlias(Attribute attr, raw_ostream &os) const final { 57 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 58 if (!strAttr) 59 return AliasResult::NoAlias; 60 61 // Check the contents of the string attribute to see what the test alias 62 // should be named. 63 Optional<StringRef> aliasName = 64 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 65 .Case("alias_test:dot_in_name", StringRef("test.alias")) 66 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 67 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 68 .Case("alias_test:sanitize_conflict_a", 69 StringRef("test_alias_conflict0")) 70 .Case("alias_test:sanitize_conflict_b", 71 StringRef("test_alias_conflict0_")) 72 .Case("alias_test:tensor_encoding", StringRef("test_encoding")) 73 .Default(llvm::None); 74 if (!aliasName) 75 return AliasResult::NoAlias; 76 77 os << *aliasName; 78 return AliasResult::FinalAlias; 79 } 80 81 AliasResult getAlias(Type type, raw_ostream &os) const final { 82 if (auto tupleType = type.dyn_cast<TupleType>()) { 83 if (tupleType.size() > 0 && 84 llvm::all_of(tupleType.getTypes(), [](Type elemType) { 85 return elemType.isa<SimpleAType>(); 86 })) { 87 os << "test_tuple"; 88 return AliasResult::FinalAlias; 89 } 90 } 91 if (auto intType = type.dyn_cast<TestIntegerType>()) { 92 if (intType.getSignedness() == 93 TestIntegerType::SignednessSemantics::Unsigned && 94 intType.getWidth() == 8) { 95 os << "test_ui8"; 96 return AliasResult::FinalAlias; 97 } 98 } 99 return AliasResult::NoAlias; 100 } 101 102 void getAsmResultNames(Operation *op, 103 OpAsmSetValueNameFn setNameFn) const final { 104 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 105 setNameFn(asmOp, "result"); 106 } 107 108 void getAsmBlockArgumentNames(Block *block, 109 OpAsmSetValueNameFn setNameFn) const final { 110 auto op = block->getParentOp(); 111 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 112 if (!arrayAttr) 113 return; 114 auto args = block->getArguments(); 115 auto e = std::min(arrayAttr.size(), args.size()); 116 for (unsigned i = 0; i < e; ++i) { 117 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 118 setNameFn(args[i], strAttr.getValue()); 119 } 120 } 121 }; 122 123 struct TestDialectFoldInterface : public DialectFoldInterface { 124 using DialectFoldInterface::DialectFoldInterface; 125 126 /// Registered hook to check if the given region, which is attached to an 127 /// operation that is *not* isolated from above, should be used when 128 /// materializing constants. 129 bool shouldMaterializeInto(Region *region) const final { 130 // If this is a one region operation, then insert into it. 131 return isa<OneRegionOp>(region->getParentOp()); 132 } 133 }; 134 135 /// This class defines the interface for handling inlining with standard 136 /// operations. 137 struct TestInlinerInterface : public DialectInlinerInterface { 138 using DialectInlinerInterface::DialectInlinerInterface; 139 140 //===--------------------------------------------------------------------===// 141 // Analysis Hooks 142 //===--------------------------------------------------------------------===// 143 144 bool isLegalToInline(Operation *call, Operation *callable, 145 bool wouldBeCloned) const final { 146 // Don't allow inlining calls that are marked `noinline`. 147 return !call->hasAttr("noinline"); 148 } 149 bool isLegalToInline(Region *, Region *, bool, 150 BlockAndValueMapping &) const final { 151 // Inlining into test dialect regions is legal. 152 return true; 153 } 154 bool isLegalToInline(Operation *, Region *, bool, 155 BlockAndValueMapping &) const final { 156 return true; 157 } 158 159 bool shouldAnalyzeRecursively(Operation *op) const final { 160 // Analyze recursively if this is not a functional region operation, it 161 // froms a separate functional scope. 162 return !isa<FunctionalRegionOp>(op); 163 } 164 165 //===--------------------------------------------------------------------===// 166 // Transformation Hooks 167 //===--------------------------------------------------------------------===// 168 169 /// Handle the given inlined terminator by replacing it with a new operation 170 /// as necessary. 171 void handleTerminator(Operation *op, 172 ArrayRef<Value> valuesToRepl) const final { 173 // Only handle "test.return" here. 174 auto returnOp = dyn_cast<TestReturnOp>(op); 175 if (!returnOp) 176 return; 177 178 // Replace the values directly with the return operands. 179 assert(returnOp.getNumOperands() == valuesToRepl.size()); 180 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 181 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 182 } 183 184 /// Attempt to materialize a conversion for a type mismatch between a call 185 /// from this dialect, and a callable region. This method should generate an 186 /// operation that takes 'input' as the only operand, and produces a single 187 /// result of 'resultType'. If a conversion can not be generated, nullptr 188 /// should be returned. 189 Operation *materializeCallConversion(OpBuilder &builder, Value input, 190 Type resultType, 191 Location conversionLoc) const final { 192 // Only allow conversion for i16/i32 types. 193 if (!(resultType.isSignlessInteger(16) || 194 resultType.isSignlessInteger(32)) || 195 !(input.getType().isSignlessInteger(16) || 196 input.getType().isSignlessInteger(32))) 197 return nullptr; 198 return builder.create<TestCastOp>(conversionLoc, resultType, input); 199 } 200 201 void processInlinedCallBlocks( 202 Operation *call, 203 iterator_range<Region::iterator> inlinedBlocks) const final { 204 if (!isa<ConversionCallOp>(call)) 205 return; 206 207 // Set attributed on all ops in the inlined blocks. 208 for (Block &block : inlinedBlocks) { 209 block.walk([&](Operation *op) { 210 op->setAttr("inlined_conversion", UnitAttr::get(call->getContext())); 211 }); 212 } 213 } 214 }; 215 216 struct TestReductionPatternInterface : public DialectReductionPatternInterface { 217 public: 218 TestReductionPatternInterface(Dialect *dialect) 219 : DialectReductionPatternInterface(dialect) {} 220 221 void populateReductionPatterns(RewritePatternSet &patterns) const final { 222 populateTestReductionPatterns(patterns); 223 } 224 }; 225 226 } // end anonymous namespace 227 228 //===----------------------------------------------------------------------===// 229 // TestDialect 230 //===----------------------------------------------------------------------===// 231 232 static void testSideEffectOpGetEffect( 233 Operation *op, 234 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects); 235 236 // This is the implementation of a dialect fallback for `TestEffectOpInterface`. 237 struct TestOpEffectInterfaceFallback 238 : public TestEffectOpInterface::FallbackModel< 239 TestOpEffectInterfaceFallback> { 240 static bool classof(Operation *op) { 241 bool isSupportedOp = 242 op->getName().getStringRef() == "test.unregistered_side_effect_op"; 243 assert(isSupportedOp && "Unexpected dispatch"); 244 return isSupportedOp; 245 } 246 247 void 248 getEffects(Operation *op, 249 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 250 &effects) const { 251 testSideEffectOpGetEffect(op, effects); 252 } 253 }; 254 255 void TestDialect::initialize() { 256 registerAttributes(); 257 registerTypes(); 258 addOperations< 259 #define GET_OP_LIST 260 #include "TestOps.cpp.inc" 261 >(); 262 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 263 TestInlinerInterface, TestReductionPatternInterface>(); 264 allowUnknownOperations(); 265 266 // Instantiate our fallback op interface that we'll use on specific 267 // unregistered op. 268 fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback; 269 } 270 TestDialect::~TestDialect() { 271 delete static_cast<TestOpEffectInterfaceFallback *>( 272 fallbackEffectOpInterfaces); 273 } 274 275 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 276 Type type, Location loc) { 277 return builder.create<TestOpConstant>(loc, type, value); 278 } 279 280 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, 281 OperationName opName) { 282 if (opName.getIdentifier() == "test.unregistered_side_effect_op" && 283 typeID == TypeID::get<TestEffectOpInterface>()) 284 return fallbackEffectOpInterfaces; 285 return nullptr; 286 } 287 288 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 289 NamedAttribute namedAttr) { 290 if (namedAttr.first == "test.invalid_attr") 291 return op->emitError() << "invalid to use 'test.invalid_attr'"; 292 return success(); 293 } 294 295 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 296 unsigned regionIndex, 297 unsigned argIndex, 298 NamedAttribute namedAttr) { 299 if (namedAttr.first == "test.invalid_attr") 300 return op->emitError() << "invalid to use 'test.invalid_attr'"; 301 return success(); 302 } 303 304 LogicalResult 305 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 306 unsigned resultIndex, 307 NamedAttribute namedAttr) { 308 if (namedAttr.first == "test.invalid_attr") 309 return op->emitError() << "invalid to use 'test.invalid_attr'"; 310 return success(); 311 } 312 313 Optional<Dialect::ParseOpHook> 314 TestDialect::getParseOperationHook(StringRef opName) const { 315 if (opName == "test.dialect_custom_printer") { 316 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 317 return parser.parseKeyword("custom_format"); 318 }}; 319 } 320 return None; 321 } 322 323 llvm::unique_function<void(Operation *, OpAsmPrinter &)> 324 TestDialect::getOperationPrinter(Operation *op) const { 325 StringRef opName = op->getName().getStringRef(); 326 if (opName == "test.dialect_custom_printer") { 327 return [](Operation *op, OpAsmPrinter &printer) { 328 printer.getStream() << " custom_format"; 329 }; 330 } 331 return {}; 332 } 333 334 //===----------------------------------------------------------------------===// 335 // TestBranchOp 336 //===----------------------------------------------------------------------===// 337 338 Optional<MutableOperandRange> 339 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 340 assert(index == 0 && "invalid successor index"); 341 return targetOperandsMutable(); 342 } 343 344 //===----------------------------------------------------------------------===// 345 // TestDialectCanonicalizerOp 346 //===----------------------------------------------------------------------===// 347 348 static LogicalResult 349 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, 350 PatternRewriter &rewriter) { 351 rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(), 352 rewriter.getI32IntegerAttr(42)); 353 return success(); 354 } 355 356 void TestDialect::getCanonicalizationPatterns( 357 RewritePatternSet &results) const { 358 results.add(&dialectCanonicalizationPattern); 359 } 360 361 //===----------------------------------------------------------------------===// 362 // TestFoldToCallOp 363 //===----------------------------------------------------------------------===// 364 365 namespace { 366 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 367 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 368 369 LogicalResult matchAndRewrite(FoldToCallOp op, 370 PatternRewriter &rewriter) const override { 371 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(), 372 ValueRange()); 373 return success(); 374 } 375 }; 376 } // end anonymous namespace 377 378 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 379 MLIRContext *context) { 380 results.add<FoldToCallOpPattern>(context); 381 } 382 383 //===----------------------------------------------------------------------===// 384 // Test Format* operations 385 //===----------------------------------------------------------------------===// 386 387 //===----------------------------------------------------------------------===// 388 // Parsing 389 390 static ParseResult parseCustomDirectiveOperands( 391 OpAsmParser &parser, OpAsmParser::OperandType &operand, 392 Optional<OpAsmParser::OperandType> &optOperand, 393 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 394 if (parser.parseOperand(operand)) 395 return failure(); 396 if (succeeded(parser.parseOptionalComma())) { 397 optOperand.emplace(); 398 if (parser.parseOperand(*optOperand)) 399 return failure(); 400 } 401 if (parser.parseArrow() || parser.parseLParen() || 402 parser.parseOperandList(varOperands) || parser.parseRParen()) 403 return failure(); 404 return success(); 405 } 406 static ParseResult 407 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 408 Type &optOperandType, 409 SmallVectorImpl<Type> &varOperandTypes) { 410 if (parser.parseColon()) 411 return failure(); 412 413 if (parser.parseType(operandType)) 414 return failure(); 415 if (succeeded(parser.parseOptionalComma())) { 416 if (parser.parseType(optOperandType)) 417 return failure(); 418 } 419 if (parser.parseArrow() || parser.parseLParen() || 420 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 421 return failure(); 422 return success(); 423 } 424 static ParseResult 425 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 426 Type optOperandType, 427 const SmallVectorImpl<Type> &varOperandTypes) { 428 if (parser.parseKeyword("type_refs_capture")) 429 return failure(); 430 431 Type operandType2, optOperandType2; 432 SmallVector<Type, 1> varOperandTypes2; 433 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 434 varOperandTypes2)) 435 return failure(); 436 437 if (operandType != operandType2 || optOperandType != optOperandType2 || 438 varOperandTypes != varOperandTypes2) 439 return failure(); 440 441 return success(); 442 } 443 static ParseResult parseCustomDirectiveOperandsAndTypes( 444 OpAsmParser &parser, OpAsmParser::OperandType &operand, 445 Optional<OpAsmParser::OperandType> &optOperand, 446 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 447 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 448 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 449 parseCustomDirectiveResults(parser, operandType, optOperandType, 450 varOperandTypes)) 451 return failure(); 452 return success(); 453 } 454 static ParseResult parseCustomDirectiveRegions( 455 OpAsmParser &parser, Region ®ion, 456 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 457 if (parser.parseRegion(region)) 458 return failure(); 459 if (failed(parser.parseOptionalComma())) 460 return success(); 461 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 462 if (parser.parseRegion(*varRegion)) 463 return failure(); 464 varRegions.emplace_back(std::move(varRegion)); 465 return success(); 466 } 467 static ParseResult 468 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 469 SmallVectorImpl<Block *> &varSuccessors) { 470 if (parser.parseSuccessor(successor)) 471 return failure(); 472 if (failed(parser.parseOptionalComma())) 473 return success(); 474 Block *varSuccessor; 475 if (parser.parseSuccessor(varSuccessor)) 476 return failure(); 477 varSuccessors.append(2, varSuccessor); 478 return success(); 479 } 480 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 481 IntegerAttr &attr, 482 IntegerAttr &optAttr) { 483 if (parser.parseAttribute(attr)) 484 return failure(); 485 if (succeeded(parser.parseOptionalComma())) { 486 if (parser.parseAttribute(optAttr)) 487 return failure(); 488 } 489 return success(); 490 } 491 492 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 493 NamedAttrList &attrs) { 494 return parser.parseOptionalAttrDict(attrs); 495 } 496 static ParseResult parseCustomDirectiveOptionalOperandRef( 497 OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) { 498 int64_t operandCount = 0; 499 if (parser.parseInteger(operandCount)) 500 return failure(); 501 bool expectedOptionalOperand = operandCount == 0; 502 return success(expectedOptionalOperand != optOperand.hasValue()); 503 } 504 505 //===----------------------------------------------------------------------===// 506 // Printing 507 508 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 509 Value operand, Value optOperand, 510 OperandRange varOperands) { 511 printer << operand; 512 if (optOperand) 513 printer << ", " << optOperand; 514 printer << " -> (" << varOperands << ")"; 515 } 516 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 517 Type operandType, Type optOperandType, 518 TypeRange varOperandTypes) { 519 printer << " : " << operandType; 520 if (optOperandType) 521 printer << ", " << optOperandType; 522 printer << " -> (" << varOperandTypes << ")"; 523 } 524 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 525 Operation *op, Type operandType, 526 Type optOperandType, 527 TypeRange varOperandTypes) { 528 printer << " type_refs_capture "; 529 printCustomDirectiveResults(printer, op, operandType, optOperandType, 530 varOperandTypes); 531 } 532 static void printCustomDirectiveOperandsAndTypes( 533 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 534 OperandRange varOperands, Type operandType, Type optOperandType, 535 TypeRange varOperandTypes) { 536 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 537 printCustomDirectiveResults(printer, op, operandType, optOperandType, 538 varOperandTypes); 539 } 540 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 541 Region ®ion, 542 MutableArrayRef<Region> varRegions) { 543 printer.printRegion(region); 544 if (!varRegions.empty()) { 545 printer << ", "; 546 for (Region ®ion : varRegions) 547 printer.printRegion(region); 548 } 549 } 550 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 551 Block *successor, 552 SuccessorRange varSuccessors) { 553 printer << successor; 554 if (!varSuccessors.empty()) 555 printer << ", " << varSuccessors.front(); 556 } 557 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 558 Attribute attribute, 559 Attribute optAttribute) { 560 printer << attribute; 561 if (optAttribute) 562 printer << ", " << optAttribute; 563 } 564 565 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 566 DictionaryAttr attrs) { 567 printer.printOptionalAttrDict(attrs.getValue()); 568 } 569 570 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 571 Operation *op, 572 Value optOperand) { 573 printer << (optOperand ? "1" : "0"); 574 } 575 576 //===----------------------------------------------------------------------===// 577 // Test IsolatedRegionOp - parse passthrough region arguments. 578 //===----------------------------------------------------------------------===// 579 580 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 581 OperationState &result) { 582 OpAsmParser::OperandType argInfo; 583 Type argType = parser.getBuilder().getIndexType(); 584 585 // Parse the input operand. 586 if (parser.parseOperand(argInfo) || 587 parser.resolveOperand(argInfo, argType, result.operands)) 588 return failure(); 589 590 // Parse the body region, and reuse the operand info as the argument info. 591 Region *body = result.addRegion(); 592 return parser.parseRegion(*body, argInfo, argType, 593 /*enableNameShadowing=*/true); 594 } 595 596 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 597 p << "test.isolated_region "; 598 p.printOperand(op.getOperand()); 599 p.shadowRegionArgs(op.region(), op.getOperand()); 600 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 601 } 602 603 //===----------------------------------------------------------------------===// 604 // Test SSACFGRegionOp 605 //===----------------------------------------------------------------------===// 606 607 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 608 return RegionKind::SSACFG; 609 } 610 611 //===----------------------------------------------------------------------===// 612 // Test GraphRegionOp 613 //===----------------------------------------------------------------------===// 614 615 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 616 OperationState &result) { 617 // Parse the body region, and reuse the operand info as the argument info. 618 Region *body = result.addRegion(); 619 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 620 } 621 622 static void print(OpAsmPrinter &p, GraphRegionOp op) { 623 p << "test.graph_region "; 624 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 625 } 626 627 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 628 return RegionKind::Graph; 629 } 630 631 //===----------------------------------------------------------------------===// 632 // Test AffineScopeOp 633 //===----------------------------------------------------------------------===// 634 635 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 636 OperationState &result) { 637 // Parse the body region, and reuse the operand info as the argument info. 638 Region *body = result.addRegion(); 639 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 640 } 641 642 static void print(OpAsmPrinter &p, AffineScopeOp op) { 643 p << "test.affine_scope "; 644 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 645 } 646 647 //===----------------------------------------------------------------------===// 648 // Test parser. 649 //===----------------------------------------------------------------------===// 650 651 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, 652 OperationState &result) { 653 if (parser.parseOptionalColon()) 654 return success(); 655 uint64_t numResults; 656 if (parser.parseInteger(numResults)) 657 return failure(); 658 659 IndexType type = parser.getBuilder().getIndexType(); 660 for (unsigned i = 0; i < numResults; ++i) 661 result.addTypes(type); 662 return success(); 663 } 664 665 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { 666 if (unsigned numResults = op->getNumResults()) 667 p << " : " << numResults; 668 } 669 670 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, 671 OperationState &result) { 672 StringRef keyword; 673 if (parser.parseKeyword(&keyword)) 674 return failure(); 675 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 676 return success(); 677 } 678 679 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { 680 p << " " << op.keyword(); 681 } 682 683 //===----------------------------------------------------------------------===// 684 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 685 686 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 687 OperationState &result) { 688 if (parser.parseKeyword("wraps")) 689 return failure(); 690 691 // Parse the wrapped op in a region 692 Region &body = *result.addRegion(); 693 body.push_back(new Block); 694 Block &block = body.back(); 695 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 696 if (!wrapped_op) 697 return failure(); 698 699 // Create a return terminator in the inner region, pass as operand to the 700 // terminator the returned values from the wrapped operation. 701 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 702 OpBuilder builder(parser.getContext()); 703 builder.setInsertionPointToEnd(&block); 704 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 705 706 // Get the results type for the wrapping op from the terminator operands. 707 Operation &return_op = body.back().back(); 708 result.types.append(return_op.operand_type_begin(), 709 return_op.operand_type_end()); 710 711 // Use the location of the wrapped op for the "test.wrapping_region" op. 712 result.location = wrapped_op->getLoc(); 713 714 return success(); 715 } 716 717 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 718 p << " wraps "; 719 p.printGenericOp(&op.region().front().front()); 720 } 721 722 //===----------------------------------------------------------------------===// 723 // Test PolyForOp - parse list of region arguments. 724 //===----------------------------------------------------------------------===// 725 726 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 727 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 728 // Parse list of region arguments without a delimiter. 729 if (parser.parseRegionArgumentList(ivsInfo)) 730 return failure(); 731 732 // Parse the body region. 733 Region *body = result.addRegion(); 734 auto &builder = parser.getBuilder(); 735 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 736 return parser.parseRegion(*body, ivsInfo, argTypes); 737 } 738 739 //===----------------------------------------------------------------------===// 740 // Test removing op with inner ops. 741 //===----------------------------------------------------------------------===// 742 743 namespace { 744 struct TestRemoveOpWithInnerOps 745 : public OpRewritePattern<TestOpWithRegionPattern> { 746 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 747 748 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } 749 750 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 751 PatternRewriter &rewriter) const override { 752 rewriter.eraseOp(op); 753 return success(); 754 } 755 }; 756 } // end anonymous namespace 757 758 void TestOpWithRegionPattern::getCanonicalizationPatterns( 759 RewritePatternSet &results, MLIRContext *context) { 760 results.add<TestRemoveOpWithInnerOps>(context); 761 } 762 763 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 764 return operand(); 765 } 766 767 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 768 return getValue(); 769 } 770 771 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 772 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 773 for (Value input : this->operands()) { 774 results.push_back(input); 775 } 776 return success(); 777 } 778 779 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 780 assert(operands.size() == 1); 781 if (operands.front()) { 782 (*this)->setAttr("attr", operands.front()); 783 return getResult(); 784 } 785 return {}; 786 } 787 788 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 789 return getOperand(); 790 } 791 792 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 793 MLIRContext *, Optional<Location> location, ValueRange operands, 794 DictionaryAttr attributes, RegionRange regions, 795 SmallVectorImpl<Type> &inferredReturnTypes) { 796 if (operands[0].getType() != operands[1].getType()) { 797 return emitOptionalError(location, "operand type mismatch ", 798 operands[0].getType(), " vs ", 799 operands[1].getType()); 800 } 801 inferredReturnTypes.assign({operands[0].getType()}); 802 return success(); 803 } 804 805 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 806 MLIRContext *context, Optional<Location> location, ValueShapeRange operands, 807 DictionaryAttr attributes, RegionRange regions, 808 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 809 // Create return type consisting of the last element of the first operand. 810 auto operandType = operands.front().getType(); 811 auto sval = operandType.dyn_cast<ShapedType>(); 812 if (!sval) { 813 return emitOptionalError(location, "only shaped type operands allowed"); 814 } 815 int64_t dim = 816 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 817 auto type = IntegerType::get(context, 17); 818 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 819 return success(); 820 } 821 822 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 823 OpBuilder &builder, ValueRange operands, 824 llvm::SmallVectorImpl<Value> &shapes) { 825 shapes = SmallVector<Value, 1>{ 826 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; 827 return success(); 828 } 829 830 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( 831 OpBuilder &builder, ValueRange operands, 832 llvm::SmallVectorImpl<Value> &shapes) { 833 Location loc = getLoc(); 834 shapes.reserve(operands.size()); 835 for (Value operand : llvm::reverse(operands)) { 836 auto currShape = llvm::to_vector<4>(llvm::map_range( 837 llvm::seq<int64_t>( 838 0, operand.getType().cast<RankedTensorType>().getRank()), 839 [&](int64_t dim) -> Value { 840 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 841 })); 842 shapes.push_back(builder.create<tensor::FromElementsOp>( 843 getLoc(), builder.getIndexType(), currShape)); 844 } 845 return success(); 846 } 847 848 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( 849 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { 850 Location loc = getLoc(); 851 shapes.reserve(getNumOperands()); 852 for (Value operand : llvm::reverse(getOperands())) { 853 auto currShape = llvm::to_vector<4>(llvm::map_range( 854 llvm::seq<int64_t>( 855 0, operand.getType().cast<RankedTensorType>().getRank()), 856 [&](int64_t dim) -> Value { 857 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 858 })); 859 shapes.emplace_back(std::move(currShape)); 860 } 861 return success(); 862 } 863 864 //===----------------------------------------------------------------------===// 865 // Test SideEffect interfaces 866 //===----------------------------------------------------------------------===// 867 868 namespace { 869 /// A test resource for side effects. 870 struct TestResource : public SideEffects::Resource::Base<TestResource> { 871 StringRef getName() final { return "<Test>"; } 872 }; 873 } // end anonymous namespace 874 875 static void testSideEffectOpGetEffect( 876 Operation *op, 877 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 878 &effects) { 879 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); 880 if (!effectsAttr) 881 return; 882 883 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 884 } 885 886 void SideEffectOp::getEffects( 887 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 888 // Check for an effects attribute on the op instance. 889 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 890 if (!effectsAttr) 891 return; 892 893 // If there is one, it is an array of dictionary attributes that hold 894 // information on the effects of this operation. 895 for (Attribute element : effectsAttr) { 896 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 897 898 // Get the specific memory effect. 899 MemoryEffects::Effect *effect = 900 StringSwitch<MemoryEffects::Effect *>( 901 effectElement.get("effect").cast<StringAttr>().getValue()) 902 .Case("allocate", MemoryEffects::Allocate::get()) 903 .Case("free", MemoryEffects::Free::get()) 904 .Case("read", MemoryEffects::Read::get()) 905 .Case("write", MemoryEffects::Write::get()); 906 907 // Check for a non-default resource to use. 908 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 909 if (effectElement.get("test_resource")) 910 resource = TestResource::get(); 911 912 // Check for a result to affect. 913 if (effectElement.get("on_result")) 914 effects.emplace_back(effect, getResult(), resource); 915 else if (Attribute ref = effectElement.get("on_reference")) 916 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 917 else 918 effects.emplace_back(effect, resource); 919 } 920 } 921 922 void SideEffectOp::getEffects( 923 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 924 testSideEffectOpGetEffect(getOperation(), effects); 925 } 926 927 //===----------------------------------------------------------------------===// 928 // StringAttrPrettyNameOp 929 //===----------------------------------------------------------------------===// 930 931 // This op has fancy handling of its SSA result name. 932 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 933 OperationState &result) { 934 // Add the result types. 935 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 936 result.addTypes(parser.getBuilder().getIntegerType(32)); 937 938 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 939 return failure(); 940 941 // If the attribute dictionary contains no 'names' attribute, infer it from 942 // the SSA name (if specified). 943 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 944 return attr.first == "names"; 945 }); 946 947 // If there was no name specified, check to see if there was a useful name 948 // specified in the asm file. 949 if (hadNames || parser.getNumResults() == 0) 950 return success(); 951 952 SmallVector<StringRef, 4> names; 953 auto *context = result.getContext(); 954 955 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 956 auto resultName = parser.getResultName(i); 957 StringRef nameStr; 958 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 959 nameStr = resultName.first; 960 961 names.push_back(nameStr); 962 } 963 964 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 965 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 966 return success(); 967 } 968 969 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 970 // Note that we only need to print the "name" attribute if the asmprinter 971 // result name disagrees with it. This can happen in strange cases, e.g. 972 // when there are conflicts. 973 bool namesDisagree = op.names().size() != op.getNumResults(); 974 975 SmallString<32> resultNameStr; 976 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 977 resultNameStr.clear(); 978 llvm::raw_svector_ostream tmpStream(resultNameStr); 979 p.printOperand(op.getResult(i), tmpStream); 980 981 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 982 if (!expectedName || 983 tmpStream.str().drop_front() != expectedName.getValue()) { 984 namesDisagree = true; 985 } 986 } 987 988 if (namesDisagree) 989 p.printOptionalAttrDictWithKeyword(op->getAttrs()); 990 else 991 p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"}); 992 } 993 994 // We set the SSA name in the asm syntax to the contents of the name 995 // attribute. 996 void StringAttrPrettyNameOp::getAsmResultNames( 997 function_ref<void(Value, StringRef)> setNameFn) { 998 999 auto value = names(); 1000 for (size_t i = 0, e = value.size(); i != e; ++i) 1001 if (auto str = value[i].dyn_cast<StringAttr>()) 1002 if (!str.getValue().empty()) 1003 setNameFn(getResult(i), str.getValue()); 1004 } 1005 1006 //===----------------------------------------------------------------------===// 1007 // RegionIfOp 1008 //===----------------------------------------------------------------------===// 1009 1010 static void print(OpAsmPrinter &p, RegionIfOp op) { 1011 p << " "; 1012 p.printOperands(op.getOperands()); 1013 p << ": " << op.getOperandTypes(); 1014 p.printArrowTypeList(op.getResultTypes()); 1015 p << " then"; 1016 p.printRegion(op.thenRegion(), 1017 /*printEntryBlockArgs=*/true, 1018 /*printBlockTerminators=*/true); 1019 p << " else"; 1020 p.printRegion(op.elseRegion(), 1021 /*printEntryBlockArgs=*/true, 1022 /*printBlockTerminators=*/true); 1023 p << " join"; 1024 p.printRegion(op.joinRegion(), 1025 /*printEntryBlockArgs=*/true, 1026 /*printBlockTerminators=*/true); 1027 } 1028 1029 static ParseResult parseRegionIfOp(OpAsmParser &parser, 1030 OperationState &result) { 1031 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 1032 SmallVector<Type, 2> operandTypes; 1033 1034 result.regions.reserve(3); 1035 Region *thenRegion = result.addRegion(); 1036 Region *elseRegion = result.addRegion(); 1037 Region *joinRegion = result.addRegion(); 1038 1039 // Parse operand, type and arrow type lists. 1040 if (parser.parseOperandList(operandInfos) || 1041 parser.parseColonTypeList(operandTypes) || 1042 parser.parseArrowTypeList(result.types)) 1043 return failure(); 1044 1045 // Parse all attached regions. 1046 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 1047 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 1048 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 1049 return failure(); 1050 1051 return parser.resolveOperands(operandInfos, operandTypes, 1052 parser.getCurrentLocation(), result.operands); 1053 } 1054 1055 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 1056 assert(index < 2 && "invalid region index"); 1057 return getOperands(); 1058 } 1059 1060 void RegionIfOp::getSuccessorRegions( 1061 Optional<unsigned> index, ArrayRef<Attribute> operands, 1062 SmallVectorImpl<RegionSuccessor> ®ions) { 1063 // We always branch to the join region. 1064 if (index.hasValue()) { 1065 if (index.getValue() < 2) 1066 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 1067 else 1068 regions.push_back(RegionSuccessor(getResults())); 1069 return; 1070 } 1071 1072 // The then and else regions are the entry regions of this op. 1073 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 1074 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 1075 } 1076 1077 //===----------------------------------------------------------------------===// 1078 // SingleNoTerminatorCustomAsmOp 1079 //===----------------------------------------------------------------------===// 1080 1081 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser, 1082 OperationState &state) { 1083 Region *body = state.addRegion(); 1084 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) 1085 return failure(); 1086 return success(); 1087 } 1088 1089 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) { 1090 printer.printRegion( 1091 op.getRegion(), /*printEntryBlockArgs=*/false, 1092 // This op has a single block without terminators. But explicitly mark 1093 // as not printing block terminators for testing. 1094 /*printBlockTerminators=*/false); 1095 } 1096 1097 #include "TestOpEnums.cpp.inc" 1098 #include "TestOpInterfaces.cpp.inc" 1099 #include "TestOpStructs.cpp.inc" 1100 #include "TestTypeInterfaces.cpp.inc" 1101 1102 #define GET_OP_CLASSES 1103 #include "TestOps.cpp.inc" 1104