1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestInterfaces.h" 12 #include "TestTypes.h" 13 #include "mlir/Dialect/DLTI/DLTI.h" 14 #include "mlir/Dialect/MemRef/IR/MemRef.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/DialectImplementation.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Reducer/ReductionPatternInterface.h" 21 #include "mlir/Transforms/FoldUtils.h" 22 #include "mlir/Transforms/InliningUtils.h" 23 #include "llvm/ADT/StringSwitch.h" 24 25 using namespace mlir; 26 using namespace mlir::test; 27 28 void mlir::test::registerTestDialect(DialectRegistry ®istry) { 29 registry.insert<TestDialect>(); 30 } 31 32 //===----------------------------------------------------------------------===// 33 // TestDialect Interfaces 34 //===----------------------------------------------------------------------===// 35 36 namespace { 37 38 /// Testing the correctness of some traits. 39 static_assert( 40 llvm::is_detected<OpTrait::has_implicit_terminator_t, 41 SingleBlockImplicitTerminatorOp>::value, 42 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); 43 static_assert(OpTrait::hasSingleBlockImplicitTerminator< 44 SingleBlockImplicitTerminatorOp>::value, 45 "hasSingleBlockImplicitTerminator does not match " 46 "SingleBlockImplicitTerminatorOp"); 47 48 // Test support for interacting with the AsmPrinter. 49 struct TestOpAsmInterface : public OpAsmDialectInterface { 50 using OpAsmDialectInterface::OpAsmDialectInterface; 51 52 LogicalResult getAlias(Attribute attr, raw_ostream &os) const final { 53 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 54 if (!strAttr) 55 return failure(); 56 57 // Check the contents of the string attribute to see what the test alias 58 // should be named. 59 Optional<StringRef> aliasName = 60 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 61 .Case("alias_test:dot_in_name", StringRef("test.alias")) 62 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 63 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 64 .Case("alias_test:sanitize_conflict_a", 65 StringRef("test_alias_conflict0")) 66 .Case("alias_test:sanitize_conflict_b", 67 StringRef("test_alias_conflict0_")) 68 .Default(llvm::None); 69 if (!aliasName) 70 return failure(); 71 72 os << *aliasName; 73 return success(); 74 } 75 76 void getAsmResultNames(Operation *op, 77 OpAsmSetValueNameFn setNameFn) const final { 78 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 79 setNameFn(asmOp, "result"); 80 } 81 82 void getAsmBlockArgumentNames(Block *block, 83 OpAsmSetValueNameFn setNameFn) const final { 84 auto op = block->getParentOp(); 85 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 86 if (!arrayAttr) 87 return; 88 auto args = block->getArguments(); 89 auto e = std::min(arrayAttr.size(), args.size()); 90 for (unsigned i = 0; i < e; ++i) { 91 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 92 setNameFn(args[i], strAttr.getValue()); 93 } 94 } 95 }; 96 97 struct TestDialectFoldInterface : public DialectFoldInterface { 98 using DialectFoldInterface::DialectFoldInterface; 99 100 /// Registered hook to check if the given region, which is attached to an 101 /// operation that is *not* isolated from above, should be used when 102 /// materializing constants. 103 bool shouldMaterializeInto(Region *region) const final { 104 // If this is a one region operation, then insert into it. 105 return isa<OneRegionOp>(region->getParentOp()); 106 } 107 }; 108 109 /// This class defines the interface for handling inlining with standard 110 /// operations. 111 struct TestInlinerInterface : public DialectInlinerInterface { 112 using DialectInlinerInterface::DialectInlinerInterface; 113 114 //===--------------------------------------------------------------------===// 115 // Analysis Hooks 116 //===--------------------------------------------------------------------===// 117 118 bool isLegalToInline(Operation *call, Operation *callable, 119 bool wouldBeCloned) const final { 120 // Don't allow inlining calls that are marked `noinline`. 121 return !call->hasAttr("noinline"); 122 } 123 bool isLegalToInline(Region *, Region *, bool, 124 BlockAndValueMapping &) const final { 125 // Inlining into test dialect regions is legal. 126 return true; 127 } 128 bool isLegalToInline(Operation *, Region *, bool, 129 BlockAndValueMapping &) const final { 130 return true; 131 } 132 133 bool shouldAnalyzeRecursively(Operation *op) const final { 134 // Analyze recursively if this is not a functional region operation, it 135 // froms a separate functional scope. 136 return !isa<FunctionalRegionOp>(op); 137 } 138 139 //===--------------------------------------------------------------------===// 140 // Transformation Hooks 141 //===--------------------------------------------------------------------===// 142 143 /// Handle the given inlined terminator by replacing it with a new operation 144 /// as necessary. 145 void handleTerminator(Operation *op, 146 ArrayRef<Value> valuesToRepl) const final { 147 // Only handle "test.return" here. 148 auto returnOp = dyn_cast<TestReturnOp>(op); 149 if (!returnOp) 150 return; 151 152 // Replace the values directly with the return operands. 153 assert(returnOp.getNumOperands() == valuesToRepl.size()); 154 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 155 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 156 } 157 158 /// Attempt to materialize a conversion for a type mismatch between a call 159 /// from this dialect, and a callable region. This method should generate an 160 /// operation that takes 'input' as the only operand, and produces a single 161 /// result of 'resultType'. If a conversion can not be generated, nullptr 162 /// should be returned. 163 Operation *materializeCallConversion(OpBuilder &builder, Value input, 164 Type resultType, 165 Location conversionLoc) const final { 166 // Only allow conversion for i16/i32 types. 167 if (!(resultType.isSignlessInteger(16) || 168 resultType.isSignlessInteger(32)) || 169 !(input.getType().isSignlessInteger(16) || 170 input.getType().isSignlessInteger(32))) 171 return nullptr; 172 return builder.create<TestCastOp>(conversionLoc, resultType, input); 173 } 174 175 void processInlinedCallBlocks( 176 Operation *call, 177 iterator_range<Region::iterator> inlinedBlocks) const final { 178 if (!isa<ConversionCallOp>(call)) 179 return; 180 181 // Set attributed on all ops in the inlined blocks. 182 for (Block &block : inlinedBlocks) { 183 block.walk([&](Operation *op) { 184 op->setAttr("inlined_conversion", UnitAttr::get(call->getContext())); 185 }); 186 } 187 } 188 }; 189 190 struct TestReductionPatternInterface : public DialectReductionPatternInterface { 191 public: 192 TestReductionPatternInterface(Dialect *dialect) 193 : DialectReductionPatternInterface(dialect) {} 194 195 virtual void 196 populateReductionPatterns(RewritePatternSet &patterns) const final { 197 populateTestReductionPatterns(patterns); 198 } 199 }; 200 201 } // end anonymous namespace 202 203 //===----------------------------------------------------------------------===// 204 // TestDialect 205 //===----------------------------------------------------------------------===// 206 207 static void testSideEffectOpGetEffect( 208 Operation *op, 209 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects); 210 211 // This is the implementation of a dialect fallback for `TestEffectOpInterface`. 212 struct TestOpEffectInterfaceFallback 213 : public TestEffectOpInterface::FallbackModel< 214 TestOpEffectInterfaceFallback> { 215 static bool classof(Operation *op) { 216 bool isSupportedOp = 217 op->getName().getStringRef() == "test.unregistered_side_effect_op"; 218 assert(isSupportedOp && "Unexpected dispatch"); 219 return isSupportedOp; 220 } 221 222 void 223 getEffects(Operation *op, 224 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 225 &effects) const { 226 testSideEffectOpGetEffect(op, effects); 227 } 228 }; 229 230 void TestDialect::initialize() { 231 registerAttributes(); 232 registerTypes(); 233 addOperations< 234 #define GET_OP_LIST 235 #include "TestOps.cpp.inc" 236 >(); 237 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 238 TestInlinerInterface, TestReductionPatternInterface>(); 239 allowUnknownOperations(); 240 241 // Instantiate our fallback op interface that we'll use on specific 242 // unregistered op. 243 fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback; 244 } 245 TestDialect::~TestDialect() { 246 delete static_cast<TestOpEffectInterfaceFallback *>( 247 fallbackEffectOpInterfaces); 248 } 249 250 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 251 Type type, Location loc) { 252 return builder.create<TestOpConstant>(loc, type, value); 253 } 254 255 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, 256 OperationName opName) { 257 if (opName.getIdentifier() == "test.unregistered_side_effect_op" && 258 typeID == TypeID::get<TestEffectOpInterface>()) 259 return fallbackEffectOpInterfaces; 260 return nullptr; 261 } 262 263 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 264 NamedAttribute namedAttr) { 265 if (namedAttr.first == "test.invalid_attr") 266 return op->emitError() << "invalid to use 'test.invalid_attr'"; 267 return success(); 268 } 269 270 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 271 unsigned regionIndex, 272 unsigned argIndex, 273 NamedAttribute namedAttr) { 274 if (namedAttr.first == "test.invalid_attr") 275 return op->emitError() << "invalid to use 'test.invalid_attr'"; 276 return success(); 277 } 278 279 LogicalResult 280 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 281 unsigned resultIndex, 282 NamedAttribute namedAttr) { 283 if (namedAttr.first == "test.invalid_attr") 284 return op->emitError() << "invalid to use 'test.invalid_attr'"; 285 return success(); 286 } 287 288 Optional<Dialect::ParseOpHook> 289 TestDialect::getParseOperationHook(StringRef opName) const { 290 if (opName == "test.dialect_custom_printer") { 291 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 292 return parser.parseKeyword("custom_format"); 293 }}; 294 } 295 return None; 296 } 297 298 LogicalResult TestDialect::printOperation(Operation *op, 299 OpAsmPrinter &printer) const { 300 StringRef opName = op->getName().getStringRef(); 301 if (opName == "test.dialect_custom_printer") { 302 printer.getStream() << opName << " custom_format"; 303 return success(); 304 } 305 return failure(); 306 } 307 308 //===----------------------------------------------------------------------===// 309 // TestBranchOp 310 //===----------------------------------------------------------------------===// 311 312 Optional<MutableOperandRange> 313 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 314 assert(index == 0 && "invalid successor index"); 315 return targetOperandsMutable(); 316 } 317 318 //===----------------------------------------------------------------------===// 319 // TestDialectCanonicalizerOp 320 //===----------------------------------------------------------------------===// 321 322 static LogicalResult 323 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, 324 PatternRewriter &rewriter) { 325 rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(), 326 rewriter.getI32IntegerAttr(42)); 327 return success(); 328 } 329 330 void TestDialect::getCanonicalizationPatterns( 331 RewritePatternSet &results) const { 332 results.add(&dialectCanonicalizationPattern); 333 } 334 335 //===----------------------------------------------------------------------===// 336 // TestFoldToCallOp 337 //===----------------------------------------------------------------------===// 338 339 namespace { 340 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 341 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 342 343 LogicalResult matchAndRewrite(FoldToCallOp op, 344 PatternRewriter &rewriter) const override { 345 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(), 346 ValueRange()); 347 return success(); 348 } 349 }; 350 } // end anonymous namespace 351 352 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 353 MLIRContext *context) { 354 results.add<FoldToCallOpPattern>(context); 355 } 356 357 //===----------------------------------------------------------------------===// 358 // Test Format* operations 359 //===----------------------------------------------------------------------===// 360 361 //===----------------------------------------------------------------------===// 362 // Parsing 363 364 static ParseResult parseCustomDirectiveOperands( 365 OpAsmParser &parser, OpAsmParser::OperandType &operand, 366 Optional<OpAsmParser::OperandType> &optOperand, 367 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 368 if (parser.parseOperand(operand)) 369 return failure(); 370 if (succeeded(parser.parseOptionalComma())) { 371 optOperand.emplace(); 372 if (parser.parseOperand(*optOperand)) 373 return failure(); 374 } 375 if (parser.parseArrow() || parser.parseLParen() || 376 parser.parseOperandList(varOperands) || parser.parseRParen()) 377 return failure(); 378 return success(); 379 } 380 static ParseResult 381 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 382 Type &optOperandType, 383 SmallVectorImpl<Type> &varOperandTypes) { 384 if (parser.parseColon()) 385 return failure(); 386 387 if (parser.parseType(operandType)) 388 return failure(); 389 if (succeeded(parser.parseOptionalComma())) { 390 if (parser.parseType(optOperandType)) 391 return failure(); 392 } 393 if (parser.parseArrow() || parser.parseLParen() || 394 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 395 return failure(); 396 return success(); 397 } 398 static ParseResult 399 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 400 Type optOperandType, 401 const SmallVectorImpl<Type> &varOperandTypes) { 402 if (parser.parseKeyword("type_refs_capture")) 403 return failure(); 404 405 Type operandType2, optOperandType2; 406 SmallVector<Type, 1> varOperandTypes2; 407 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 408 varOperandTypes2)) 409 return failure(); 410 411 if (operandType != operandType2 || optOperandType != optOperandType2 || 412 varOperandTypes != varOperandTypes2) 413 return failure(); 414 415 return success(); 416 } 417 static ParseResult parseCustomDirectiveOperandsAndTypes( 418 OpAsmParser &parser, OpAsmParser::OperandType &operand, 419 Optional<OpAsmParser::OperandType> &optOperand, 420 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 421 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 422 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 423 parseCustomDirectiveResults(parser, operandType, optOperandType, 424 varOperandTypes)) 425 return failure(); 426 return success(); 427 } 428 static ParseResult parseCustomDirectiveRegions( 429 OpAsmParser &parser, Region ®ion, 430 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 431 if (parser.parseRegion(region)) 432 return failure(); 433 if (failed(parser.parseOptionalComma())) 434 return success(); 435 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 436 if (parser.parseRegion(*varRegion)) 437 return failure(); 438 varRegions.emplace_back(std::move(varRegion)); 439 return success(); 440 } 441 static ParseResult 442 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 443 SmallVectorImpl<Block *> &varSuccessors) { 444 if (parser.parseSuccessor(successor)) 445 return failure(); 446 if (failed(parser.parseOptionalComma())) 447 return success(); 448 Block *varSuccessor; 449 if (parser.parseSuccessor(varSuccessor)) 450 return failure(); 451 varSuccessors.append(2, varSuccessor); 452 return success(); 453 } 454 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 455 IntegerAttr &attr, 456 IntegerAttr &optAttr) { 457 if (parser.parseAttribute(attr)) 458 return failure(); 459 if (succeeded(parser.parseOptionalComma())) { 460 if (parser.parseAttribute(optAttr)) 461 return failure(); 462 } 463 return success(); 464 } 465 466 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 467 NamedAttrList &attrs) { 468 return parser.parseOptionalAttrDict(attrs); 469 } 470 static ParseResult parseCustomDirectiveOptionalOperandRef( 471 OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) { 472 int64_t operandCount = 0; 473 if (parser.parseInteger(operandCount)) 474 return failure(); 475 bool expectedOptionalOperand = operandCount == 0; 476 return success(expectedOptionalOperand != optOperand.hasValue()); 477 } 478 479 //===----------------------------------------------------------------------===// 480 // Printing 481 482 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 483 Value operand, Value optOperand, 484 OperandRange varOperands) { 485 printer << operand; 486 if (optOperand) 487 printer << ", " << optOperand; 488 printer << " -> (" << varOperands << ")"; 489 } 490 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 491 Type operandType, Type optOperandType, 492 TypeRange varOperandTypes) { 493 printer << " : " << operandType; 494 if (optOperandType) 495 printer << ", " << optOperandType; 496 printer << " -> (" << varOperandTypes << ")"; 497 } 498 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 499 Operation *op, Type operandType, 500 Type optOperandType, 501 TypeRange varOperandTypes) { 502 printer << " type_refs_capture "; 503 printCustomDirectiveResults(printer, op, operandType, optOperandType, 504 varOperandTypes); 505 } 506 static void printCustomDirectiveOperandsAndTypes( 507 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 508 OperandRange varOperands, Type operandType, Type optOperandType, 509 TypeRange varOperandTypes) { 510 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 511 printCustomDirectiveResults(printer, op, operandType, optOperandType, 512 varOperandTypes); 513 } 514 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 515 Region ®ion, 516 MutableArrayRef<Region> varRegions) { 517 printer.printRegion(region); 518 if (!varRegions.empty()) { 519 printer << ", "; 520 for (Region ®ion : varRegions) 521 printer.printRegion(region); 522 } 523 } 524 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 525 Block *successor, 526 SuccessorRange varSuccessors) { 527 printer << successor; 528 if (!varSuccessors.empty()) 529 printer << ", " << varSuccessors.front(); 530 } 531 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 532 Attribute attribute, 533 Attribute optAttribute) { 534 printer << attribute; 535 if (optAttribute) 536 printer << ", " << optAttribute; 537 } 538 539 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 540 DictionaryAttr attrs) { 541 printer.printOptionalAttrDict(attrs.getValue()); 542 } 543 544 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 545 Operation *op, 546 Value optOperand) { 547 printer << (optOperand ? "1" : "0"); 548 } 549 550 //===----------------------------------------------------------------------===// 551 // Test IsolatedRegionOp - parse passthrough region arguments. 552 //===----------------------------------------------------------------------===// 553 554 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 555 OperationState &result) { 556 OpAsmParser::OperandType argInfo; 557 Type argType = parser.getBuilder().getIndexType(); 558 559 // Parse the input operand. 560 if (parser.parseOperand(argInfo) || 561 parser.resolveOperand(argInfo, argType, result.operands)) 562 return failure(); 563 564 // Parse the body region, and reuse the operand info as the argument info. 565 Region *body = result.addRegion(); 566 return parser.parseRegion(*body, argInfo, argType, 567 /*enableNameShadowing=*/true); 568 } 569 570 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 571 p << "test.isolated_region "; 572 p.printOperand(op.getOperand()); 573 p.shadowRegionArgs(op.region(), op.getOperand()); 574 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 575 } 576 577 //===----------------------------------------------------------------------===// 578 // Test SSACFGRegionOp 579 //===----------------------------------------------------------------------===// 580 581 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 582 return RegionKind::SSACFG; 583 } 584 585 //===----------------------------------------------------------------------===// 586 // Test GraphRegionOp 587 //===----------------------------------------------------------------------===// 588 589 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 590 OperationState &result) { 591 // Parse the body region, and reuse the operand info as the argument info. 592 Region *body = result.addRegion(); 593 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 594 } 595 596 static void print(OpAsmPrinter &p, GraphRegionOp op) { 597 p << "test.graph_region "; 598 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 599 } 600 601 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 602 return RegionKind::Graph; 603 } 604 605 //===----------------------------------------------------------------------===// 606 // Test AffineScopeOp 607 //===----------------------------------------------------------------------===// 608 609 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 610 OperationState &result) { 611 // Parse the body region, and reuse the operand info as the argument info. 612 Region *body = result.addRegion(); 613 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 614 } 615 616 static void print(OpAsmPrinter &p, AffineScopeOp op) { 617 p << "test.affine_scope "; 618 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 619 } 620 621 //===----------------------------------------------------------------------===// 622 // Test parser. 623 //===----------------------------------------------------------------------===// 624 625 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, 626 OperationState &result) { 627 if (parser.parseOptionalColon()) 628 return success(); 629 uint64_t numResults; 630 if (parser.parseInteger(numResults)) 631 return failure(); 632 633 IndexType type = parser.getBuilder().getIndexType(); 634 for (unsigned i = 0; i < numResults; ++i) 635 result.addTypes(type); 636 return success(); 637 } 638 639 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { 640 p << ParseIntegerLiteralOp::getOperationName(); 641 if (unsigned numResults = op->getNumResults()) 642 p << " : " << numResults; 643 } 644 645 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, 646 OperationState &result) { 647 StringRef keyword; 648 if (parser.parseKeyword(&keyword)) 649 return failure(); 650 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 651 return success(); 652 } 653 654 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { 655 p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword(); 656 } 657 658 //===----------------------------------------------------------------------===// 659 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 660 661 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 662 OperationState &result) { 663 if (parser.parseKeyword("wraps")) 664 return failure(); 665 666 // Parse the wrapped op in a region 667 Region &body = *result.addRegion(); 668 body.push_back(new Block); 669 Block &block = body.back(); 670 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 671 if (!wrapped_op) 672 return failure(); 673 674 // Create a return terminator in the inner region, pass as operand to the 675 // terminator the returned values from the wrapped operation. 676 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 677 OpBuilder builder(parser.getBuilder().getContext()); 678 builder.setInsertionPointToEnd(&block); 679 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 680 681 // Get the results type for the wrapping op from the terminator operands. 682 Operation &return_op = body.back().back(); 683 result.types.append(return_op.operand_type_begin(), 684 return_op.operand_type_end()); 685 686 // Use the location of the wrapped op for the "test.wrapping_region" op. 687 result.location = wrapped_op->getLoc(); 688 689 return success(); 690 } 691 692 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 693 p << op.getOperationName() << " wraps "; 694 p.printGenericOp(&op.region().front().front()); 695 } 696 697 //===----------------------------------------------------------------------===// 698 // Test PolyForOp - parse list of region arguments. 699 //===----------------------------------------------------------------------===// 700 701 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 702 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 703 // Parse list of region arguments without a delimiter. 704 if (parser.parseRegionArgumentList(ivsInfo)) 705 return failure(); 706 707 // Parse the body region. 708 Region *body = result.addRegion(); 709 auto &builder = parser.getBuilder(); 710 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 711 return parser.parseRegion(*body, ivsInfo, argTypes); 712 } 713 714 //===----------------------------------------------------------------------===// 715 // Test removing op with inner ops. 716 //===----------------------------------------------------------------------===// 717 718 namespace { 719 struct TestRemoveOpWithInnerOps 720 : public OpRewritePattern<TestOpWithRegionPattern> { 721 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 722 723 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } 724 725 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 726 PatternRewriter &rewriter) const override { 727 rewriter.eraseOp(op); 728 return success(); 729 } 730 }; 731 } // end anonymous namespace 732 733 void TestOpWithRegionPattern::getCanonicalizationPatterns( 734 RewritePatternSet &results, MLIRContext *context) { 735 results.add<TestRemoveOpWithInnerOps>(context); 736 } 737 738 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 739 return operand(); 740 } 741 742 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 743 return getValue(); 744 } 745 746 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 747 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 748 for (Value input : this->operands()) { 749 results.push_back(input); 750 } 751 return success(); 752 } 753 754 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 755 assert(operands.size() == 1); 756 if (operands.front()) { 757 (*this)->setAttr("attr", operands.front()); 758 return getResult(); 759 } 760 return {}; 761 } 762 763 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 764 return getOperand(); 765 } 766 767 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 768 MLIRContext *, Optional<Location> location, ValueRange operands, 769 DictionaryAttr attributes, RegionRange regions, 770 SmallVectorImpl<Type> &inferredReturnTypes) { 771 if (operands[0].getType() != operands[1].getType()) { 772 return emitOptionalError(location, "operand type mismatch ", 773 operands[0].getType(), " vs ", 774 operands[1].getType()); 775 } 776 inferredReturnTypes.assign({operands[0].getType()}); 777 return success(); 778 } 779 780 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 781 MLIRContext *context, Optional<Location> location, ValueRange operands, 782 DictionaryAttr attributes, RegionRange regions, 783 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 784 // Create return type consisting of the last element of the first operand. 785 auto operandType = *operands.getTypes().begin(); 786 auto sval = operandType.dyn_cast<ShapedType>(); 787 if (!sval) { 788 return emitOptionalError(location, "only shaped type operands allowed"); 789 } 790 int64_t dim = 791 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 792 auto type = IntegerType::get(context, 17); 793 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 794 return success(); 795 } 796 797 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 798 OpBuilder &builder, ValueRange operands, 799 llvm::SmallVectorImpl<Value> &shapes) { 800 shapes = SmallVector<Value, 1>{ 801 builder.createOrFold<memref::DimOp>(getLoc(), operands.front(), 0)}; 802 return success(); 803 } 804 805 LogicalResult 806 OpWithResultShapePerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim( 807 OpBuilder &builder, 808 llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) { 809 SmallVector<Value> operand1Shape, operand2Shape; 810 Location loc = getLoc(); 811 for (auto i : 812 llvm::seq<int>(0, operand1().getType().cast<ShapedType>().getRank())) { 813 operand1Shape.push_back(builder.create<memref::DimOp>(loc, operand1(), i)); 814 } 815 for (auto i : 816 llvm::seq<int>(0, operand2().getType().cast<ShapedType>().getRank())) { 817 operand2Shape.push_back(builder.create<memref::DimOp>(loc, operand2(), i)); 818 } 819 shapes.emplace_back(std::move(operand2Shape)); 820 shapes.emplace_back(std::move(operand1Shape)); 821 return success(); 822 } 823 824 //===----------------------------------------------------------------------===// 825 // Test SideEffect interfaces 826 //===----------------------------------------------------------------------===// 827 828 namespace { 829 /// A test resource for side effects. 830 struct TestResource : public SideEffects::Resource::Base<TestResource> { 831 StringRef getName() final { return "<Test>"; } 832 }; 833 } // end anonymous namespace 834 835 static void testSideEffectOpGetEffect( 836 Operation *op, 837 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 838 &effects) { 839 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); 840 if (!effectsAttr) 841 return; 842 843 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 844 } 845 846 void SideEffectOp::getEffects( 847 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 848 // Check for an effects attribute on the op instance. 849 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 850 if (!effectsAttr) 851 return; 852 853 // If there is one, it is an array of dictionary attributes that hold 854 // information on the effects of this operation. 855 for (Attribute element : effectsAttr) { 856 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 857 858 // Get the specific memory effect. 859 MemoryEffects::Effect *effect = 860 StringSwitch<MemoryEffects::Effect *>( 861 effectElement.get("effect").cast<StringAttr>().getValue()) 862 .Case("allocate", MemoryEffects::Allocate::get()) 863 .Case("free", MemoryEffects::Free::get()) 864 .Case("read", MemoryEffects::Read::get()) 865 .Case("write", MemoryEffects::Write::get()); 866 867 // Check for a non-default resource to use. 868 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 869 if (effectElement.get("test_resource")) 870 resource = TestResource::get(); 871 872 // Check for a result to affect. 873 if (effectElement.get("on_result")) 874 effects.emplace_back(effect, getResult(), resource); 875 else if (Attribute ref = effectElement.get("on_reference")) 876 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 877 else 878 effects.emplace_back(effect, resource); 879 } 880 } 881 882 void SideEffectOp::getEffects( 883 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 884 testSideEffectOpGetEffect(getOperation(), effects); 885 } 886 887 //===----------------------------------------------------------------------===// 888 // StringAttrPrettyNameOp 889 //===----------------------------------------------------------------------===// 890 891 // This op has fancy handling of its SSA result name. 892 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 893 OperationState &result) { 894 // Add the result types. 895 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 896 result.addTypes(parser.getBuilder().getIntegerType(32)); 897 898 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 899 return failure(); 900 901 // If the attribute dictionary contains no 'names' attribute, infer it from 902 // the SSA name (if specified). 903 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 904 return attr.first == "names"; 905 }); 906 907 // If there was no name specified, check to see if there was a useful name 908 // specified in the asm file. 909 if (hadNames || parser.getNumResults() == 0) 910 return success(); 911 912 SmallVector<StringRef, 4> names; 913 auto *context = result.getContext(); 914 915 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 916 auto resultName = parser.getResultName(i); 917 StringRef nameStr; 918 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 919 nameStr = resultName.first; 920 921 names.push_back(nameStr); 922 } 923 924 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 925 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 926 return success(); 927 } 928 929 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 930 p << "test.string_attr_pretty_name"; 931 932 // Note that we only need to print the "name" attribute if the asmprinter 933 // result name disagrees with it. This can happen in strange cases, e.g. 934 // when there are conflicts. 935 bool namesDisagree = op.names().size() != op.getNumResults(); 936 937 SmallString<32> resultNameStr; 938 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 939 resultNameStr.clear(); 940 llvm::raw_svector_ostream tmpStream(resultNameStr); 941 p.printOperand(op.getResult(i), tmpStream); 942 943 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 944 if (!expectedName || 945 tmpStream.str().drop_front() != expectedName.getValue()) { 946 namesDisagree = true; 947 } 948 } 949 950 if (namesDisagree) 951 p.printOptionalAttrDictWithKeyword(op->getAttrs()); 952 else 953 p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"}); 954 } 955 956 // We set the SSA name in the asm syntax to the contents of the name 957 // attribute. 958 void StringAttrPrettyNameOp::getAsmResultNames( 959 function_ref<void(Value, StringRef)> setNameFn) { 960 961 auto value = names(); 962 for (size_t i = 0, e = value.size(); i != e; ++i) 963 if (auto str = value[i].dyn_cast<StringAttr>()) 964 if (!str.getValue().empty()) 965 setNameFn(getResult(i), str.getValue()); 966 } 967 968 //===----------------------------------------------------------------------===// 969 // RegionIfOp 970 //===----------------------------------------------------------------------===// 971 972 static void print(OpAsmPrinter &p, RegionIfOp op) { 973 p << RegionIfOp::getOperationName() << " "; 974 p.printOperands(op.getOperands()); 975 p << ": " << op.getOperandTypes(); 976 p.printArrowTypeList(op.getResultTypes()); 977 p << " then"; 978 p.printRegion(op.thenRegion(), 979 /*printEntryBlockArgs=*/true, 980 /*printBlockTerminators=*/true); 981 p << " else"; 982 p.printRegion(op.elseRegion(), 983 /*printEntryBlockArgs=*/true, 984 /*printBlockTerminators=*/true); 985 p << " join"; 986 p.printRegion(op.joinRegion(), 987 /*printEntryBlockArgs=*/true, 988 /*printBlockTerminators=*/true); 989 } 990 991 static ParseResult parseRegionIfOp(OpAsmParser &parser, 992 OperationState &result) { 993 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 994 SmallVector<Type, 2> operandTypes; 995 996 result.regions.reserve(3); 997 Region *thenRegion = result.addRegion(); 998 Region *elseRegion = result.addRegion(); 999 Region *joinRegion = result.addRegion(); 1000 1001 // Parse operand, type and arrow type lists. 1002 if (parser.parseOperandList(operandInfos) || 1003 parser.parseColonTypeList(operandTypes) || 1004 parser.parseArrowTypeList(result.types)) 1005 return failure(); 1006 1007 // Parse all attached regions. 1008 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 1009 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 1010 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 1011 return failure(); 1012 1013 return parser.resolveOperands(operandInfos, operandTypes, 1014 parser.getCurrentLocation(), result.operands); 1015 } 1016 1017 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 1018 assert(index < 2 && "invalid region index"); 1019 return getOperands(); 1020 } 1021 1022 void RegionIfOp::getSuccessorRegions( 1023 Optional<unsigned> index, ArrayRef<Attribute> operands, 1024 SmallVectorImpl<RegionSuccessor> ®ions) { 1025 // We always branch to the join region. 1026 if (index.hasValue()) { 1027 if (index.getValue() < 2) 1028 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 1029 else 1030 regions.push_back(RegionSuccessor(getResults())); 1031 return; 1032 } 1033 1034 // The then and else regions are the entry regions of this op. 1035 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 1036 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 1037 } 1038 1039 //===----------------------------------------------------------------------===// 1040 // SingleNoTerminatorCustomAsmOp 1041 //===----------------------------------------------------------------------===// 1042 1043 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser, 1044 OperationState &state) { 1045 Region *body = state.addRegion(); 1046 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) 1047 return failure(); 1048 return success(); 1049 } 1050 1051 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) { 1052 printer << op.getOperationName(); 1053 printer.printRegion( 1054 op.getRegion(), /*printEntryBlockArgs=*/false, 1055 // This op has a single block without terminators. But explicitly mark 1056 // as not printing block terminators for testing. 1057 /*printBlockTerminators=*/false); 1058 } 1059 1060 #include "TestOpEnums.cpp.inc" 1061 #include "TestOpInterfaces.cpp.inc" 1062 #include "TestOpStructs.cpp.inc" 1063 #include "TestTypeInterfaces.cpp.inc" 1064 1065 #define GET_OP_CLASSES 1066 #include "TestOps.cpp.inc" 1067