1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestInterfaces.h" 12 #include "TestTypes.h" 13 #include "mlir/Dialect/DLTI/DLTI.h" 14 #include "mlir/Dialect/MemRef/IR/MemRef.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/DialectImplementation.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/IR/TypeUtilities.h" 21 #include "mlir/Reducer/ReductionPatternInterface.h" 22 #include "mlir/Transforms/FoldUtils.h" 23 #include "mlir/Transforms/InliningUtils.h" 24 #include "llvm/ADT/StringSwitch.h" 25 26 using namespace mlir; 27 using namespace mlir::test; 28 29 #include "TestOpsDialect.cpp.inc" 30 31 void mlir::test::registerTestDialect(DialectRegistry ®istry) { 32 registry.insert<TestDialect>(); 33 } 34 35 //===----------------------------------------------------------------------===// 36 // TestDialect Interfaces 37 //===----------------------------------------------------------------------===// 38 39 namespace { 40 41 /// Testing the correctness of some traits. 42 static_assert( 43 llvm::is_detected<OpTrait::has_implicit_terminator_t, 44 SingleBlockImplicitTerminatorOp>::value, 45 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); 46 static_assert(OpTrait::hasSingleBlockImplicitTerminator< 47 SingleBlockImplicitTerminatorOp>::value, 48 "hasSingleBlockImplicitTerminator does not match " 49 "SingleBlockImplicitTerminatorOp"); 50 51 // Test support for interacting with the AsmPrinter. 52 struct TestOpAsmInterface : public OpAsmDialectInterface { 53 using OpAsmDialectInterface::OpAsmDialectInterface; 54 55 LogicalResult getAlias(Attribute attr, raw_ostream &os) const final { 56 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 57 if (!strAttr) 58 return failure(); 59 60 // Check the contents of the string attribute to see what the test alias 61 // should be named. 62 Optional<StringRef> aliasName = 63 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 64 .Case("alias_test:dot_in_name", StringRef("test.alias")) 65 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 66 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 67 .Case("alias_test:sanitize_conflict_a", 68 StringRef("test_alias_conflict0")) 69 .Case("alias_test:sanitize_conflict_b", 70 StringRef("test_alias_conflict0_")) 71 .Default(llvm::None); 72 if (!aliasName) 73 return failure(); 74 75 os << *aliasName; 76 return success(); 77 } 78 79 void getAsmResultNames(Operation *op, 80 OpAsmSetValueNameFn setNameFn) const final { 81 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 82 setNameFn(asmOp, "result"); 83 } 84 85 void getAsmBlockArgumentNames(Block *block, 86 OpAsmSetValueNameFn setNameFn) const final { 87 auto op = block->getParentOp(); 88 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 89 if (!arrayAttr) 90 return; 91 auto args = block->getArguments(); 92 auto e = std::min(arrayAttr.size(), args.size()); 93 for (unsigned i = 0; i < e; ++i) { 94 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 95 setNameFn(args[i], strAttr.getValue()); 96 } 97 } 98 }; 99 100 struct TestDialectFoldInterface : public DialectFoldInterface { 101 using DialectFoldInterface::DialectFoldInterface; 102 103 /// Registered hook to check if the given region, which is attached to an 104 /// operation that is *not* isolated from above, should be used when 105 /// materializing constants. 106 bool shouldMaterializeInto(Region *region) const final { 107 // If this is a one region operation, then insert into it. 108 return isa<OneRegionOp>(region->getParentOp()); 109 } 110 }; 111 112 /// This class defines the interface for handling inlining with standard 113 /// operations. 114 struct TestInlinerInterface : public DialectInlinerInterface { 115 using DialectInlinerInterface::DialectInlinerInterface; 116 117 //===--------------------------------------------------------------------===// 118 // Analysis Hooks 119 //===--------------------------------------------------------------------===// 120 121 bool isLegalToInline(Operation *call, Operation *callable, 122 bool wouldBeCloned) const final { 123 // Don't allow inlining calls that are marked `noinline`. 124 return !call->hasAttr("noinline"); 125 } 126 bool isLegalToInline(Region *, Region *, bool, 127 BlockAndValueMapping &) const final { 128 // Inlining into test dialect regions is legal. 129 return true; 130 } 131 bool isLegalToInline(Operation *, Region *, bool, 132 BlockAndValueMapping &) const final { 133 return true; 134 } 135 136 bool shouldAnalyzeRecursively(Operation *op) const final { 137 // Analyze recursively if this is not a functional region operation, it 138 // froms a separate functional scope. 139 return !isa<FunctionalRegionOp>(op); 140 } 141 142 //===--------------------------------------------------------------------===// 143 // Transformation Hooks 144 //===--------------------------------------------------------------------===// 145 146 /// Handle the given inlined terminator by replacing it with a new operation 147 /// as necessary. 148 void handleTerminator(Operation *op, 149 ArrayRef<Value> valuesToRepl) const final { 150 // Only handle "test.return" here. 151 auto returnOp = dyn_cast<TestReturnOp>(op); 152 if (!returnOp) 153 return; 154 155 // Replace the values directly with the return operands. 156 assert(returnOp.getNumOperands() == valuesToRepl.size()); 157 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 158 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 159 } 160 161 /// Attempt to materialize a conversion for a type mismatch between a call 162 /// from this dialect, and a callable region. This method should generate an 163 /// operation that takes 'input' as the only operand, and produces a single 164 /// result of 'resultType'. If a conversion can not be generated, nullptr 165 /// should be returned. 166 Operation *materializeCallConversion(OpBuilder &builder, Value input, 167 Type resultType, 168 Location conversionLoc) const final { 169 // Only allow conversion for i16/i32 types. 170 if (!(resultType.isSignlessInteger(16) || 171 resultType.isSignlessInteger(32)) || 172 !(input.getType().isSignlessInteger(16) || 173 input.getType().isSignlessInteger(32))) 174 return nullptr; 175 return builder.create<TestCastOp>(conversionLoc, resultType, input); 176 } 177 178 void processInlinedCallBlocks( 179 Operation *call, 180 iterator_range<Region::iterator> inlinedBlocks) const final { 181 if (!isa<ConversionCallOp>(call)) 182 return; 183 184 // Set attributed on all ops in the inlined blocks. 185 for (Block &block : inlinedBlocks) { 186 block.walk([&](Operation *op) { 187 op->setAttr("inlined_conversion", UnitAttr::get(call->getContext())); 188 }); 189 } 190 } 191 }; 192 193 struct TestReductionPatternInterface : public DialectReductionPatternInterface { 194 public: 195 TestReductionPatternInterface(Dialect *dialect) 196 : DialectReductionPatternInterface(dialect) {} 197 198 virtual void 199 populateReductionPatterns(RewritePatternSet &patterns) const final { 200 populateTestReductionPatterns(patterns); 201 } 202 }; 203 204 } // end anonymous namespace 205 206 //===----------------------------------------------------------------------===// 207 // TestDialect 208 //===----------------------------------------------------------------------===// 209 210 static void testSideEffectOpGetEffect( 211 Operation *op, 212 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects); 213 214 // This is the implementation of a dialect fallback for `TestEffectOpInterface`. 215 struct TestOpEffectInterfaceFallback 216 : public TestEffectOpInterface::FallbackModel< 217 TestOpEffectInterfaceFallback> { 218 static bool classof(Operation *op) { 219 bool isSupportedOp = 220 op->getName().getStringRef() == "test.unregistered_side_effect_op"; 221 assert(isSupportedOp && "Unexpected dispatch"); 222 return isSupportedOp; 223 } 224 225 void 226 getEffects(Operation *op, 227 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 228 &effects) const { 229 testSideEffectOpGetEffect(op, effects); 230 } 231 }; 232 233 void TestDialect::initialize() { 234 registerAttributes(); 235 registerTypes(); 236 addOperations< 237 #define GET_OP_LIST 238 #include "TestOps.cpp.inc" 239 >(); 240 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 241 TestInlinerInterface, TestReductionPatternInterface>(); 242 allowUnknownOperations(); 243 244 // Instantiate our fallback op interface that we'll use on specific 245 // unregistered op. 246 fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback; 247 } 248 TestDialect::~TestDialect() { 249 delete static_cast<TestOpEffectInterfaceFallback *>( 250 fallbackEffectOpInterfaces); 251 } 252 253 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 254 Type type, Location loc) { 255 return builder.create<TestOpConstant>(loc, type, value); 256 } 257 258 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, 259 OperationName opName) { 260 if (opName.getIdentifier() == "test.unregistered_side_effect_op" && 261 typeID == TypeID::get<TestEffectOpInterface>()) 262 return fallbackEffectOpInterfaces; 263 return nullptr; 264 } 265 266 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 267 NamedAttribute namedAttr) { 268 if (namedAttr.first == "test.invalid_attr") 269 return op->emitError() << "invalid to use 'test.invalid_attr'"; 270 return success(); 271 } 272 273 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 274 unsigned regionIndex, 275 unsigned argIndex, 276 NamedAttribute namedAttr) { 277 if (namedAttr.first == "test.invalid_attr") 278 return op->emitError() << "invalid to use 'test.invalid_attr'"; 279 return success(); 280 } 281 282 LogicalResult 283 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 284 unsigned resultIndex, 285 NamedAttribute namedAttr) { 286 if (namedAttr.first == "test.invalid_attr") 287 return op->emitError() << "invalid to use 'test.invalid_attr'"; 288 return success(); 289 } 290 291 Optional<Dialect::ParseOpHook> 292 TestDialect::getParseOperationHook(StringRef opName) const { 293 if (opName == "test.dialect_custom_printer") { 294 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 295 return parser.parseKeyword("custom_format"); 296 }}; 297 } 298 return None; 299 } 300 301 LogicalResult TestDialect::printOperation(Operation *op, 302 OpAsmPrinter &printer) const { 303 StringRef opName = op->getName().getStringRef(); 304 if (opName == "test.dialect_custom_printer") { 305 printer.getStream() << opName << " custom_format"; 306 return success(); 307 } 308 return failure(); 309 } 310 311 //===----------------------------------------------------------------------===// 312 // TestBranchOp 313 //===----------------------------------------------------------------------===// 314 315 Optional<MutableOperandRange> 316 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 317 assert(index == 0 && "invalid successor index"); 318 return targetOperandsMutable(); 319 } 320 321 //===----------------------------------------------------------------------===// 322 // TestDialectCanonicalizerOp 323 //===----------------------------------------------------------------------===// 324 325 static LogicalResult 326 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, 327 PatternRewriter &rewriter) { 328 rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(), 329 rewriter.getI32IntegerAttr(42)); 330 return success(); 331 } 332 333 void TestDialect::getCanonicalizationPatterns( 334 RewritePatternSet &results) const { 335 results.add(&dialectCanonicalizationPattern); 336 } 337 338 //===----------------------------------------------------------------------===// 339 // TestFoldToCallOp 340 //===----------------------------------------------------------------------===// 341 342 namespace { 343 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 344 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 345 346 LogicalResult matchAndRewrite(FoldToCallOp op, 347 PatternRewriter &rewriter) const override { 348 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(), 349 ValueRange()); 350 return success(); 351 } 352 }; 353 } // end anonymous namespace 354 355 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 356 MLIRContext *context) { 357 results.add<FoldToCallOpPattern>(context); 358 } 359 360 //===----------------------------------------------------------------------===// 361 // Test Format* operations 362 //===----------------------------------------------------------------------===// 363 364 //===----------------------------------------------------------------------===// 365 // Parsing 366 367 static ParseResult parseCustomDirectiveOperands( 368 OpAsmParser &parser, OpAsmParser::OperandType &operand, 369 Optional<OpAsmParser::OperandType> &optOperand, 370 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 371 if (parser.parseOperand(operand)) 372 return failure(); 373 if (succeeded(parser.parseOptionalComma())) { 374 optOperand.emplace(); 375 if (parser.parseOperand(*optOperand)) 376 return failure(); 377 } 378 if (parser.parseArrow() || parser.parseLParen() || 379 parser.parseOperandList(varOperands) || parser.parseRParen()) 380 return failure(); 381 return success(); 382 } 383 static ParseResult 384 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 385 Type &optOperandType, 386 SmallVectorImpl<Type> &varOperandTypes) { 387 if (parser.parseColon()) 388 return failure(); 389 390 if (parser.parseType(operandType)) 391 return failure(); 392 if (succeeded(parser.parseOptionalComma())) { 393 if (parser.parseType(optOperandType)) 394 return failure(); 395 } 396 if (parser.parseArrow() || parser.parseLParen() || 397 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 398 return failure(); 399 return success(); 400 } 401 static ParseResult 402 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 403 Type optOperandType, 404 const SmallVectorImpl<Type> &varOperandTypes) { 405 if (parser.parseKeyword("type_refs_capture")) 406 return failure(); 407 408 Type operandType2, optOperandType2; 409 SmallVector<Type, 1> varOperandTypes2; 410 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 411 varOperandTypes2)) 412 return failure(); 413 414 if (operandType != operandType2 || optOperandType != optOperandType2 || 415 varOperandTypes != varOperandTypes2) 416 return failure(); 417 418 return success(); 419 } 420 static ParseResult parseCustomDirectiveOperandsAndTypes( 421 OpAsmParser &parser, OpAsmParser::OperandType &operand, 422 Optional<OpAsmParser::OperandType> &optOperand, 423 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 424 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 425 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 426 parseCustomDirectiveResults(parser, operandType, optOperandType, 427 varOperandTypes)) 428 return failure(); 429 return success(); 430 } 431 static ParseResult parseCustomDirectiveRegions( 432 OpAsmParser &parser, Region ®ion, 433 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 434 if (parser.parseRegion(region)) 435 return failure(); 436 if (failed(parser.parseOptionalComma())) 437 return success(); 438 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 439 if (parser.parseRegion(*varRegion)) 440 return failure(); 441 varRegions.emplace_back(std::move(varRegion)); 442 return success(); 443 } 444 static ParseResult 445 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 446 SmallVectorImpl<Block *> &varSuccessors) { 447 if (parser.parseSuccessor(successor)) 448 return failure(); 449 if (failed(parser.parseOptionalComma())) 450 return success(); 451 Block *varSuccessor; 452 if (parser.parseSuccessor(varSuccessor)) 453 return failure(); 454 varSuccessors.append(2, varSuccessor); 455 return success(); 456 } 457 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 458 IntegerAttr &attr, 459 IntegerAttr &optAttr) { 460 if (parser.parseAttribute(attr)) 461 return failure(); 462 if (succeeded(parser.parseOptionalComma())) { 463 if (parser.parseAttribute(optAttr)) 464 return failure(); 465 } 466 return success(); 467 } 468 469 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 470 NamedAttrList &attrs) { 471 return parser.parseOptionalAttrDict(attrs); 472 } 473 static ParseResult parseCustomDirectiveOptionalOperandRef( 474 OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) { 475 int64_t operandCount = 0; 476 if (parser.parseInteger(operandCount)) 477 return failure(); 478 bool expectedOptionalOperand = operandCount == 0; 479 return success(expectedOptionalOperand != optOperand.hasValue()); 480 } 481 482 //===----------------------------------------------------------------------===// 483 // Printing 484 485 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 486 Value operand, Value optOperand, 487 OperandRange varOperands) { 488 printer << operand; 489 if (optOperand) 490 printer << ", " << optOperand; 491 printer << " -> (" << varOperands << ")"; 492 } 493 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 494 Type operandType, Type optOperandType, 495 TypeRange varOperandTypes) { 496 printer << " : " << operandType; 497 if (optOperandType) 498 printer << ", " << optOperandType; 499 printer << " -> (" << varOperandTypes << ")"; 500 } 501 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 502 Operation *op, Type operandType, 503 Type optOperandType, 504 TypeRange varOperandTypes) { 505 printer << " type_refs_capture "; 506 printCustomDirectiveResults(printer, op, operandType, optOperandType, 507 varOperandTypes); 508 } 509 static void printCustomDirectiveOperandsAndTypes( 510 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 511 OperandRange varOperands, Type operandType, Type optOperandType, 512 TypeRange varOperandTypes) { 513 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 514 printCustomDirectiveResults(printer, op, operandType, optOperandType, 515 varOperandTypes); 516 } 517 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 518 Region ®ion, 519 MutableArrayRef<Region> varRegions) { 520 printer.printRegion(region); 521 if (!varRegions.empty()) { 522 printer << ", "; 523 for (Region ®ion : varRegions) 524 printer.printRegion(region); 525 } 526 } 527 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 528 Block *successor, 529 SuccessorRange varSuccessors) { 530 printer << successor; 531 if (!varSuccessors.empty()) 532 printer << ", " << varSuccessors.front(); 533 } 534 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 535 Attribute attribute, 536 Attribute optAttribute) { 537 printer << attribute; 538 if (optAttribute) 539 printer << ", " << optAttribute; 540 } 541 542 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 543 DictionaryAttr attrs) { 544 printer.printOptionalAttrDict(attrs.getValue()); 545 } 546 547 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 548 Operation *op, 549 Value optOperand) { 550 printer << (optOperand ? "1" : "0"); 551 } 552 553 //===----------------------------------------------------------------------===// 554 // Test IsolatedRegionOp - parse passthrough region arguments. 555 //===----------------------------------------------------------------------===// 556 557 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 558 OperationState &result) { 559 OpAsmParser::OperandType argInfo; 560 Type argType = parser.getBuilder().getIndexType(); 561 562 // Parse the input operand. 563 if (parser.parseOperand(argInfo) || 564 parser.resolveOperand(argInfo, argType, result.operands)) 565 return failure(); 566 567 // Parse the body region, and reuse the operand info as the argument info. 568 Region *body = result.addRegion(); 569 return parser.parseRegion(*body, argInfo, argType, 570 /*enableNameShadowing=*/true); 571 } 572 573 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 574 p << "test.isolated_region "; 575 p.printOperand(op.getOperand()); 576 p.shadowRegionArgs(op.region(), op.getOperand()); 577 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 578 } 579 580 //===----------------------------------------------------------------------===// 581 // Test SSACFGRegionOp 582 //===----------------------------------------------------------------------===// 583 584 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 585 return RegionKind::SSACFG; 586 } 587 588 //===----------------------------------------------------------------------===// 589 // Test GraphRegionOp 590 //===----------------------------------------------------------------------===// 591 592 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 593 OperationState &result) { 594 // Parse the body region, and reuse the operand info as the argument info. 595 Region *body = result.addRegion(); 596 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 597 } 598 599 static void print(OpAsmPrinter &p, GraphRegionOp op) { 600 p << "test.graph_region "; 601 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 602 } 603 604 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 605 return RegionKind::Graph; 606 } 607 608 //===----------------------------------------------------------------------===// 609 // Test AffineScopeOp 610 //===----------------------------------------------------------------------===// 611 612 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 613 OperationState &result) { 614 // Parse the body region, and reuse the operand info as the argument info. 615 Region *body = result.addRegion(); 616 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 617 } 618 619 static void print(OpAsmPrinter &p, AffineScopeOp op) { 620 p << "test.affine_scope "; 621 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 622 } 623 624 //===----------------------------------------------------------------------===// 625 // Test parser. 626 //===----------------------------------------------------------------------===// 627 628 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, 629 OperationState &result) { 630 if (parser.parseOptionalColon()) 631 return success(); 632 uint64_t numResults; 633 if (parser.parseInteger(numResults)) 634 return failure(); 635 636 IndexType type = parser.getBuilder().getIndexType(); 637 for (unsigned i = 0; i < numResults; ++i) 638 result.addTypes(type); 639 return success(); 640 } 641 642 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { 643 p << ParseIntegerLiteralOp::getOperationName(); 644 if (unsigned numResults = op->getNumResults()) 645 p << " : " << numResults; 646 } 647 648 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, 649 OperationState &result) { 650 StringRef keyword; 651 if (parser.parseKeyword(&keyword)) 652 return failure(); 653 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 654 return success(); 655 } 656 657 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { 658 p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword(); 659 } 660 661 //===----------------------------------------------------------------------===// 662 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 663 664 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 665 OperationState &result) { 666 if (parser.parseKeyword("wraps")) 667 return failure(); 668 669 // Parse the wrapped op in a region 670 Region &body = *result.addRegion(); 671 body.push_back(new Block); 672 Block &block = body.back(); 673 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 674 if (!wrapped_op) 675 return failure(); 676 677 // Create a return terminator in the inner region, pass as operand to the 678 // terminator the returned values from the wrapped operation. 679 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 680 OpBuilder builder(parser.getBuilder().getContext()); 681 builder.setInsertionPointToEnd(&block); 682 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 683 684 // Get the results type for the wrapping op from the terminator operands. 685 Operation &return_op = body.back().back(); 686 result.types.append(return_op.operand_type_begin(), 687 return_op.operand_type_end()); 688 689 // Use the location of the wrapped op for the "test.wrapping_region" op. 690 result.location = wrapped_op->getLoc(); 691 692 return success(); 693 } 694 695 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 696 p << op.getOperationName() << " wraps "; 697 p.printGenericOp(&op.region().front().front()); 698 } 699 700 //===----------------------------------------------------------------------===// 701 // Test PolyForOp - parse list of region arguments. 702 //===----------------------------------------------------------------------===// 703 704 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 705 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 706 // Parse list of region arguments without a delimiter. 707 if (parser.parseRegionArgumentList(ivsInfo)) 708 return failure(); 709 710 // Parse the body region. 711 Region *body = result.addRegion(); 712 auto &builder = parser.getBuilder(); 713 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 714 return parser.parseRegion(*body, ivsInfo, argTypes); 715 } 716 717 //===----------------------------------------------------------------------===// 718 // Test removing op with inner ops. 719 //===----------------------------------------------------------------------===// 720 721 namespace { 722 struct TestRemoveOpWithInnerOps 723 : public OpRewritePattern<TestOpWithRegionPattern> { 724 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 725 726 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } 727 728 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 729 PatternRewriter &rewriter) const override { 730 rewriter.eraseOp(op); 731 return success(); 732 } 733 }; 734 } // end anonymous namespace 735 736 void TestOpWithRegionPattern::getCanonicalizationPatterns( 737 RewritePatternSet &results, MLIRContext *context) { 738 results.add<TestRemoveOpWithInnerOps>(context); 739 } 740 741 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 742 return operand(); 743 } 744 745 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 746 return getValue(); 747 } 748 749 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 750 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 751 for (Value input : this->operands()) { 752 results.push_back(input); 753 } 754 return success(); 755 } 756 757 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 758 assert(operands.size() == 1); 759 if (operands.front()) { 760 (*this)->setAttr("attr", operands.front()); 761 return getResult(); 762 } 763 return {}; 764 } 765 766 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 767 return getOperand(); 768 } 769 770 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 771 MLIRContext *, Optional<Location> location, ValueRange operands, 772 DictionaryAttr attributes, RegionRange regions, 773 SmallVectorImpl<Type> &inferredReturnTypes) { 774 if (operands[0].getType() != operands[1].getType()) { 775 return emitOptionalError(location, "operand type mismatch ", 776 operands[0].getType(), " vs ", 777 operands[1].getType()); 778 } 779 inferredReturnTypes.assign({operands[0].getType()}); 780 return success(); 781 } 782 783 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 784 MLIRContext *context, Optional<Location> location, ValueRange operands, 785 DictionaryAttr attributes, RegionRange regions, 786 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 787 // Create return type consisting of the last element of the first operand. 788 auto operandType = *operands.getTypes().begin(); 789 auto sval = operandType.dyn_cast<ShapedType>(); 790 if (!sval) { 791 return emitOptionalError(location, "only shaped type operands allowed"); 792 } 793 int64_t dim = 794 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 795 auto type = IntegerType::get(context, 17); 796 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 797 return success(); 798 } 799 800 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 801 OpBuilder &builder, ValueRange operands, 802 llvm::SmallVectorImpl<Value> &shapes) { 803 shapes = SmallVector<Value, 1>{ 804 builder.createOrFold<memref::DimOp>(getLoc(), operands.front(), 0)}; 805 return success(); 806 } 807 808 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( 809 OpBuilder &builder, ValueRange operands, 810 llvm::SmallVectorImpl<Value> &shapes) { 811 Location loc = getLoc(); 812 shapes.reserve(operands.size()); 813 for (Value operand : llvm::reverse(operands)) { 814 auto currShape = llvm::to_vector<4>(llvm::map_range( 815 llvm::seq<int64_t>( 816 0, operand.getType().cast<RankedTensorType>().getRank()), 817 [&](int64_t dim) -> Value { 818 return builder.createOrFold<memref::DimOp>(loc, operand, dim); 819 })); 820 shapes.push_back(builder.create<tensor::FromElementsOp>( 821 getLoc(), builder.getIndexType(), currShape)); 822 } 823 return success(); 824 } 825 826 LogicalResult 827 OpWithResultShapePerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim( 828 OpBuilder &builder, 829 llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) { 830 Location loc = getLoc(); 831 shapes.reserve(getNumOperands()); 832 for (Value operand : llvm::reverse(getOperands())) { 833 auto currShape = llvm::to_vector<4>(llvm::map_range( 834 llvm::seq<int64_t>( 835 0, operand.getType().cast<RankedTensorType>().getRank()), 836 [&](int64_t dim) -> Value { 837 return builder.createOrFold<memref::DimOp>(loc, operand, dim); 838 })); 839 shapes.emplace_back(std::move(currShape)); 840 } 841 return success(); 842 } 843 844 LogicalResult OpWithResultShapeAndPerDimInterfaceOp::reifyReturnTypeShapes( 845 OpBuilder &builder, ValueRange operands, 846 llvm::SmallVectorImpl<Value> &shapes) { 847 Location loc = getLoc(); 848 shapes.reserve(operands.size()); 849 for (Value operand : llvm::reverse(operands)) { 850 auto currShape = llvm::to_vector<4>(llvm::map_range( 851 llvm::seq<int64_t>( 852 0, operand.getType().cast<RankedTensorType>().getRank()), 853 [&](int64_t dim) -> Value { 854 return builder.createOrFold<memref::DimOp>(loc, operand, dim); 855 })); 856 shapes.push_back(builder.create<tensor::FromElementsOp>( 857 getLoc(), builder.getIndexType(), currShape)); 858 } 859 return success(); 860 } 861 862 LogicalResult 863 OpWithResultShapeAndPerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim( 864 OpBuilder &builder, 865 llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) { 866 Location loc = getLoc(); 867 shapes.reserve(getNumOperands()); 868 for (Value operand : llvm::reverse(getOperands())) { 869 auto currShape = llvm::to_vector<4>(llvm::map_range( 870 llvm::seq<int64_t>( 871 0, operand.getType().cast<RankedTensorType>().getRank()), 872 [&](int64_t dim) -> Value { 873 return builder.createOrFold<memref::DimOp>(loc, operand, dim); 874 })); 875 shapes.emplace_back(std::move(currShape)); 876 } 877 return success(); 878 } 879 880 //===----------------------------------------------------------------------===// 881 // Test SideEffect interfaces 882 //===----------------------------------------------------------------------===// 883 884 namespace { 885 /// A test resource for side effects. 886 struct TestResource : public SideEffects::Resource::Base<TestResource> { 887 StringRef getName() final { return "<Test>"; } 888 }; 889 } // end anonymous namespace 890 891 static void testSideEffectOpGetEffect( 892 Operation *op, 893 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 894 &effects) { 895 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); 896 if (!effectsAttr) 897 return; 898 899 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 900 } 901 902 void SideEffectOp::getEffects( 903 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 904 // Check for an effects attribute on the op instance. 905 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 906 if (!effectsAttr) 907 return; 908 909 // If there is one, it is an array of dictionary attributes that hold 910 // information on the effects of this operation. 911 for (Attribute element : effectsAttr) { 912 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 913 914 // Get the specific memory effect. 915 MemoryEffects::Effect *effect = 916 StringSwitch<MemoryEffects::Effect *>( 917 effectElement.get("effect").cast<StringAttr>().getValue()) 918 .Case("allocate", MemoryEffects::Allocate::get()) 919 .Case("free", MemoryEffects::Free::get()) 920 .Case("read", MemoryEffects::Read::get()) 921 .Case("write", MemoryEffects::Write::get()); 922 923 // Check for a non-default resource to use. 924 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 925 if (effectElement.get("test_resource")) 926 resource = TestResource::get(); 927 928 // Check for a result to affect. 929 if (effectElement.get("on_result")) 930 effects.emplace_back(effect, getResult(), resource); 931 else if (Attribute ref = effectElement.get("on_reference")) 932 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 933 else 934 effects.emplace_back(effect, resource); 935 } 936 } 937 938 void SideEffectOp::getEffects( 939 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 940 testSideEffectOpGetEffect(getOperation(), effects); 941 } 942 943 //===----------------------------------------------------------------------===// 944 // StringAttrPrettyNameOp 945 //===----------------------------------------------------------------------===// 946 947 // This op has fancy handling of its SSA result name. 948 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 949 OperationState &result) { 950 // Add the result types. 951 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 952 result.addTypes(parser.getBuilder().getIntegerType(32)); 953 954 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 955 return failure(); 956 957 // If the attribute dictionary contains no 'names' attribute, infer it from 958 // the SSA name (if specified). 959 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 960 return attr.first == "names"; 961 }); 962 963 // If there was no name specified, check to see if there was a useful name 964 // specified in the asm file. 965 if (hadNames || parser.getNumResults() == 0) 966 return success(); 967 968 SmallVector<StringRef, 4> names; 969 auto *context = result.getContext(); 970 971 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 972 auto resultName = parser.getResultName(i); 973 StringRef nameStr; 974 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 975 nameStr = resultName.first; 976 977 names.push_back(nameStr); 978 } 979 980 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 981 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 982 return success(); 983 } 984 985 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 986 p << "test.string_attr_pretty_name"; 987 988 // Note that we only need to print the "name" attribute if the asmprinter 989 // result name disagrees with it. This can happen in strange cases, e.g. 990 // when there are conflicts. 991 bool namesDisagree = op.names().size() != op.getNumResults(); 992 993 SmallString<32> resultNameStr; 994 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 995 resultNameStr.clear(); 996 llvm::raw_svector_ostream tmpStream(resultNameStr); 997 p.printOperand(op.getResult(i), tmpStream); 998 999 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 1000 if (!expectedName || 1001 tmpStream.str().drop_front() != expectedName.getValue()) { 1002 namesDisagree = true; 1003 } 1004 } 1005 1006 if (namesDisagree) 1007 p.printOptionalAttrDictWithKeyword(op->getAttrs()); 1008 else 1009 p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"}); 1010 } 1011 1012 // We set the SSA name in the asm syntax to the contents of the name 1013 // attribute. 1014 void StringAttrPrettyNameOp::getAsmResultNames( 1015 function_ref<void(Value, StringRef)> setNameFn) { 1016 1017 auto value = names(); 1018 for (size_t i = 0, e = value.size(); i != e; ++i) 1019 if (auto str = value[i].dyn_cast<StringAttr>()) 1020 if (!str.getValue().empty()) 1021 setNameFn(getResult(i), str.getValue()); 1022 } 1023 1024 //===----------------------------------------------------------------------===// 1025 // RegionIfOp 1026 //===----------------------------------------------------------------------===// 1027 1028 static void print(OpAsmPrinter &p, RegionIfOp op) { 1029 p << RegionIfOp::getOperationName() << " "; 1030 p.printOperands(op.getOperands()); 1031 p << ": " << op.getOperandTypes(); 1032 p.printArrowTypeList(op.getResultTypes()); 1033 p << " then"; 1034 p.printRegion(op.thenRegion(), 1035 /*printEntryBlockArgs=*/true, 1036 /*printBlockTerminators=*/true); 1037 p << " else"; 1038 p.printRegion(op.elseRegion(), 1039 /*printEntryBlockArgs=*/true, 1040 /*printBlockTerminators=*/true); 1041 p << " join"; 1042 p.printRegion(op.joinRegion(), 1043 /*printEntryBlockArgs=*/true, 1044 /*printBlockTerminators=*/true); 1045 } 1046 1047 static ParseResult parseRegionIfOp(OpAsmParser &parser, 1048 OperationState &result) { 1049 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 1050 SmallVector<Type, 2> operandTypes; 1051 1052 result.regions.reserve(3); 1053 Region *thenRegion = result.addRegion(); 1054 Region *elseRegion = result.addRegion(); 1055 Region *joinRegion = result.addRegion(); 1056 1057 // Parse operand, type and arrow type lists. 1058 if (parser.parseOperandList(operandInfos) || 1059 parser.parseColonTypeList(operandTypes) || 1060 parser.parseArrowTypeList(result.types)) 1061 return failure(); 1062 1063 // Parse all attached regions. 1064 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 1065 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 1066 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 1067 return failure(); 1068 1069 return parser.resolveOperands(operandInfos, operandTypes, 1070 parser.getCurrentLocation(), result.operands); 1071 } 1072 1073 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 1074 assert(index < 2 && "invalid region index"); 1075 return getOperands(); 1076 } 1077 1078 void RegionIfOp::getSuccessorRegions( 1079 Optional<unsigned> index, ArrayRef<Attribute> operands, 1080 SmallVectorImpl<RegionSuccessor> ®ions) { 1081 // We always branch to the join region. 1082 if (index.hasValue()) { 1083 if (index.getValue() < 2) 1084 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 1085 else 1086 regions.push_back(RegionSuccessor(getResults())); 1087 return; 1088 } 1089 1090 // The then and else regions are the entry regions of this op. 1091 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 1092 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 1093 } 1094 1095 //===----------------------------------------------------------------------===// 1096 // SingleNoTerminatorCustomAsmOp 1097 //===----------------------------------------------------------------------===// 1098 1099 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser, 1100 OperationState &state) { 1101 Region *body = state.addRegion(); 1102 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) 1103 return failure(); 1104 return success(); 1105 } 1106 1107 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) { 1108 printer << op.getOperationName(); 1109 printer.printRegion( 1110 op.getRegion(), /*printEntryBlockArgs=*/false, 1111 // This op has a single block without terminators. But explicitly mark 1112 // as not printing block terminators for testing. 1113 /*printBlockTerminators=*/false); 1114 } 1115 1116 #include "TestOpEnums.cpp.inc" 1117 #include "TestOpInterfaces.cpp.inc" 1118 #include "TestOpStructs.cpp.inc" 1119 #include "TestTypeInterfaces.cpp.inc" 1120 1121 #define GET_OP_CLASSES 1122 #include "TestOps.cpp.inc" 1123