1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestInterfaces.h" 12 #include "TestTypes.h" 13 #include "mlir/Dialect/DLTI/DLTI.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/DialectImplementation.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Reducer/ReductionPatternInterface.h" 21 #include "mlir/Transforms/FoldUtils.h" 22 #include "mlir/Transforms/InliningUtils.h" 23 #include "llvm/ADT/StringSwitch.h" 24 25 using namespace mlir; 26 using namespace mlir::test; 27 28 #include "TestOpsDialect.cpp.inc" 29 30 void mlir::test::registerTestDialect(DialectRegistry ®istry) { 31 registry.insert<TestDialect>(); 32 } 33 34 //===----------------------------------------------------------------------===// 35 // TestDialect Interfaces 36 //===----------------------------------------------------------------------===// 37 38 namespace { 39 40 /// Testing the correctness of some traits. 41 static_assert( 42 llvm::is_detected<OpTrait::has_implicit_terminator_t, 43 SingleBlockImplicitTerminatorOp>::value, 44 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); 45 static_assert(OpTrait::hasSingleBlockImplicitTerminator< 46 SingleBlockImplicitTerminatorOp>::value, 47 "hasSingleBlockImplicitTerminator does not match " 48 "SingleBlockImplicitTerminatorOp"); 49 50 // Test support for interacting with the AsmPrinter. 51 struct TestOpAsmInterface : public OpAsmDialectInterface { 52 using OpAsmDialectInterface::OpAsmDialectInterface; 53 54 LogicalResult getAlias(Attribute attr, raw_ostream &os) const final { 55 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 56 if (!strAttr) 57 return failure(); 58 59 // Check the contents of the string attribute to see what the test alias 60 // should be named. 61 Optional<StringRef> aliasName = 62 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 63 .Case("alias_test:dot_in_name", StringRef("test.alias")) 64 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 65 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 66 .Case("alias_test:sanitize_conflict_a", 67 StringRef("test_alias_conflict0")) 68 .Case("alias_test:sanitize_conflict_b", 69 StringRef("test_alias_conflict0_")) 70 .Default(llvm::None); 71 if (!aliasName) 72 return failure(); 73 74 os << *aliasName; 75 return success(); 76 } 77 78 void getAsmResultNames(Operation *op, 79 OpAsmSetValueNameFn setNameFn) const final { 80 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 81 setNameFn(asmOp, "result"); 82 } 83 84 void getAsmBlockArgumentNames(Block *block, 85 OpAsmSetValueNameFn setNameFn) const final { 86 auto op = block->getParentOp(); 87 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 88 if (!arrayAttr) 89 return; 90 auto args = block->getArguments(); 91 auto e = std::min(arrayAttr.size(), args.size()); 92 for (unsigned i = 0; i < e; ++i) { 93 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 94 setNameFn(args[i], strAttr.getValue()); 95 } 96 } 97 }; 98 99 struct TestDialectFoldInterface : public DialectFoldInterface { 100 using DialectFoldInterface::DialectFoldInterface; 101 102 /// Registered hook to check if the given region, which is attached to an 103 /// operation that is *not* isolated from above, should be used when 104 /// materializing constants. 105 bool shouldMaterializeInto(Region *region) const final { 106 // If this is a one region operation, then insert into it. 107 return isa<OneRegionOp>(region->getParentOp()); 108 } 109 }; 110 111 /// This class defines the interface for handling inlining with standard 112 /// operations. 113 struct TestInlinerInterface : public DialectInlinerInterface { 114 using DialectInlinerInterface::DialectInlinerInterface; 115 116 //===--------------------------------------------------------------------===// 117 // Analysis Hooks 118 //===--------------------------------------------------------------------===// 119 120 bool isLegalToInline(Operation *call, Operation *callable, 121 bool wouldBeCloned) const final { 122 // Don't allow inlining calls that are marked `noinline`. 123 return !call->hasAttr("noinline"); 124 } 125 bool isLegalToInline(Region *, Region *, bool, 126 BlockAndValueMapping &) const final { 127 // Inlining into test dialect regions is legal. 128 return true; 129 } 130 bool isLegalToInline(Operation *, Region *, bool, 131 BlockAndValueMapping &) const final { 132 return true; 133 } 134 135 bool shouldAnalyzeRecursively(Operation *op) const final { 136 // Analyze recursively if this is not a functional region operation, it 137 // froms a separate functional scope. 138 return !isa<FunctionalRegionOp>(op); 139 } 140 141 //===--------------------------------------------------------------------===// 142 // Transformation Hooks 143 //===--------------------------------------------------------------------===// 144 145 /// Handle the given inlined terminator by replacing it with a new operation 146 /// as necessary. 147 void handleTerminator(Operation *op, 148 ArrayRef<Value> valuesToRepl) const final { 149 // Only handle "test.return" here. 150 auto returnOp = dyn_cast<TestReturnOp>(op); 151 if (!returnOp) 152 return; 153 154 // Replace the values directly with the return operands. 155 assert(returnOp.getNumOperands() == valuesToRepl.size()); 156 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 157 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 158 } 159 160 /// Attempt to materialize a conversion for a type mismatch between a call 161 /// from this dialect, and a callable region. This method should generate an 162 /// operation that takes 'input' as the only operand, and produces a single 163 /// result of 'resultType'. If a conversion can not be generated, nullptr 164 /// should be returned. 165 Operation *materializeCallConversion(OpBuilder &builder, Value input, 166 Type resultType, 167 Location conversionLoc) const final { 168 // Only allow conversion for i16/i32 types. 169 if (!(resultType.isSignlessInteger(16) || 170 resultType.isSignlessInteger(32)) || 171 !(input.getType().isSignlessInteger(16) || 172 input.getType().isSignlessInteger(32))) 173 return nullptr; 174 return builder.create<TestCastOp>(conversionLoc, resultType, input); 175 } 176 177 void processInlinedCallBlocks( 178 Operation *call, 179 iterator_range<Region::iterator> inlinedBlocks) const final { 180 if (!isa<ConversionCallOp>(call)) 181 return; 182 183 // Set attributed on all ops in the inlined blocks. 184 for (Block &block : inlinedBlocks) { 185 block.walk([&](Operation *op) { 186 op->setAttr("inlined_conversion", UnitAttr::get(call->getContext())); 187 }); 188 } 189 } 190 }; 191 192 struct TestReductionPatternInterface : public DialectReductionPatternInterface { 193 public: 194 TestReductionPatternInterface(Dialect *dialect) 195 : DialectReductionPatternInterface(dialect) {} 196 197 virtual void 198 populateReductionPatterns(RewritePatternSet &patterns) const final { 199 populateTestReductionPatterns(patterns); 200 } 201 }; 202 203 } // end anonymous namespace 204 205 //===----------------------------------------------------------------------===// 206 // TestDialect 207 //===----------------------------------------------------------------------===// 208 209 static void testSideEffectOpGetEffect( 210 Operation *op, 211 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects); 212 213 // This is the implementation of a dialect fallback for `TestEffectOpInterface`. 214 struct TestOpEffectInterfaceFallback 215 : public TestEffectOpInterface::FallbackModel< 216 TestOpEffectInterfaceFallback> { 217 static bool classof(Operation *op) { 218 bool isSupportedOp = 219 op->getName().getStringRef() == "test.unregistered_side_effect_op"; 220 assert(isSupportedOp && "Unexpected dispatch"); 221 return isSupportedOp; 222 } 223 224 void 225 getEffects(Operation *op, 226 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 227 &effects) const { 228 testSideEffectOpGetEffect(op, effects); 229 } 230 }; 231 232 void TestDialect::initialize() { 233 registerAttributes(); 234 registerTypes(); 235 addOperations< 236 #define GET_OP_LIST 237 #include "TestOps.cpp.inc" 238 >(); 239 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 240 TestInlinerInterface, TestReductionPatternInterface>(); 241 allowUnknownOperations(); 242 243 // Instantiate our fallback op interface that we'll use on specific 244 // unregistered op. 245 fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback; 246 } 247 TestDialect::~TestDialect() { 248 delete static_cast<TestOpEffectInterfaceFallback *>( 249 fallbackEffectOpInterfaces); 250 } 251 252 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 253 Type type, Location loc) { 254 return builder.create<TestOpConstant>(loc, type, value); 255 } 256 257 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, 258 OperationName opName) { 259 if (opName.getIdentifier() == "test.unregistered_side_effect_op" && 260 typeID == TypeID::get<TestEffectOpInterface>()) 261 return fallbackEffectOpInterfaces; 262 return nullptr; 263 } 264 265 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 266 NamedAttribute namedAttr) { 267 if (namedAttr.first == "test.invalid_attr") 268 return op->emitError() << "invalid to use 'test.invalid_attr'"; 269 return success(); 270 } 271 272 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 273 unsigned regionIndex, 274 unsigned argIndex, 275 NamedAttribute namedAttr) { 276 if (namedAttr.first == "test.invalid_attr") 277 return op->emitError() << "invalid to use 'test.invalid_attr'"; 278 return success(); 279 } 280 281 LogicalResult 282 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 283 unsigned resultIndex, 284 NamedAttribute namedAttr) { 285 if (namedAttr.first == "test.invalid_attr") 286 return op->emitError() << "invalid to use 'test.invalid_attr'"; 287 return success(); 288 } 289 290 Optional<Dialect::ParseOpHook> 291 TestDialect::getParseOperationHook(StringRef opName) const { 292 if (opName == "test.dialect_custom_printer") { 293 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 294 return parser.parseKeyword("custom_format"); 295 }}; 296 } 297 return None; 298 } 299 300 LogicalResult TestDialect::printOperation(Operation *op, 301 OpAsmPrinter &printer) const { 302 StringRef opName = op->getName().getStringRef(); 303 if (opName == "test.dialect_custom_printer") { 304 printer.getStream() << opName << " custom_format"; 305 return success(); 306 } 307 return failure(); 308 } 309 310 //===----------------------------------------------------------------------===// 311 // TestBranchOp 312 //===----------------------------------------------------------------------===// 313 314 Optional<MutableOperandRange> 315 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 316 assert(index == 0 && "invalid successor index"); 317 return targetOperandsMutable(); 318 } 319 320 //===----------------------------------------------------------------------===// 321 // TestDialectCanonicalizerOp 322 //===----------------------------------------------------------------------===// 323 324 static LogicalResult 325 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, 326 PatternRewriter &rewriter) { 327 rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(), 328 rewriter.getI32IntegerAttr(42)); 329 return success(); 330 } 331 332 void TestDialect::getCanonicalizationPatterns( 333 RewritePatternSet &results) const { 334 results.add(&dialectCanonicalizationPattern); 335 } 336 337 //===----------------------------------------------------------------------===// 338 // TestFoldToCallOp 339 //===----------------------------------------------------------------------===// 340 341 namespace { 342 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 343 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 344 345 LogicalResult matchAndRewrite(FoldToCallOp op, 346 PatternRewriter &rewriter) const override { 347 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(), 348 ValueRange()); 349 return success(); 350 } 351 }; 352 } // end anonymous namespace 353 354 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 355 MLIRContext *context) { 356 results.add<FoldToCallOpPattern>(context); 357 } 358 359 //===----------------------------------------------------------------------===// 360 // Test Format* operations 361 //===----------------------------------------------------------------------===// 362 363 //===----------------------------------------------------------------------===// 364 // Parsing 365 366 static ParseResult parseCustomDirectiveOperands( 367 OpAsmParser &parser, OpAsmParser::OperandType &operand, 368 Optional<OpAsmParser::OperandType> &optOperand, 369 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 370 if (parser.parseOperand(operand)) 371 return failure(); 372 if (succeeded(parser.parseOptionalComma())) { 373 optOperand.emplace(); 374 if (parser.parseOperand(*optOperand)) 375 return failure(); 376 } 377 if (parser.parseArrow() || parser.parseLParen() || 378 parser.parseOperandList(varOperands) || parser.parseRParen()) 379 return failure(); 380 return success(); 381 } 382 static ParseResult 383 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 384 Type &optOperandType, 385 SmallVectorImpl<Type> &varOperandTypes) { 386 if (parser.parseColon()) 387 return failure(); 388 389 if (parser.parseType(operandType)) 390 return failure(); 391 if (succeeded(parser.parseOptionalComma())) { 392 if (parser.parseType(optOperandType)) 393 return failure(); 394 } 395 if (parser.parseArrow() || parser.parseLParen() || 396 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 397 return failure(); 398 return success(); 399 } 400 static ParseResult 401 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 402 Type optOperandType, 403 const SmallVectorImpl<Type> &varOperandTypes) { 404 if (parser.parseKeyword("type_refs_capture")) 405 return failure(); 406 407 Type operandType2, optOperandType2; 408 SmallVector<Type, 1> varOperandTypes2; 409 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 410 varOperandTypes2)) 411 return failure(); 412 413 if (operandType != operandType2 || optOperandType != optOperandType2 || 414 varOperandTypes != varOperandTypes2) 415 return failure(); 416 417 return success(); 418 } 419 static ParseResult parseCustomDirectiveOperandsAndTypes( 420 OpAsmParser &parser, OpAsmParser::OperandType &operand, 421 Optional<OpAsmParser::OperandType> &optOperand, 422 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 423 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 424 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 425 parseCustomDirectiveResults(parser, operandType, optOperandType, 426 varOperandTypes)) 427 return failure(); 428 return success(); 429 } 430 static ParseResult parseCustomDirectiveRegions( 431 OpAsmParser &parser, Region ®ion, 432 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 433 if (parser.parseRegion(region)) 434 return failure(); 435 if (failed(parser.parseOptionalComma())) 436 return success(); 437 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 438 if (parser.parseRegion(*varRegion)) 439 return failure(); 440 varRegions.emplace_back(std::move(varRegion)); 441 return success(); 442 } 443 static ParseResult 444 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 445 SmallVectorImpl<Block *> &varSuccessors) { 446 if (parser.parseSuccessor(successor)) 447 return failure(); 448 if (failed(parser.parseOptionalComma())) 449 return success(); 450 Block *varSuccessor; 451 if (parser.parseSuccessor(varSuccessor)) 452 return failure(); 453 varSuccessors.append(2, varSuccessor); 454 return success(); 455 } 456 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 457 IntegerAttr &attr, 458 IntegerAttr &optAttr) { 459 if (parser.parseAttribute(attr)) 460 return failure(); 461 if (succeeded(parser.parseOptionalComma())) { 462 if (parser.parseAttribute(optAttr)) 463 return failure(); 464 } 465 return success(); 466 } 467 468 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 469 NamedAttrList &attrs) { 470 return parser.parseOptionalAttrDict(attrs); 471 } 472 static ParseResult parseCustomDirectiveOptionalOperandRef( 473 OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) { 474 int64_t operandCount = 0; 475 if (parser.parseInteger(operandCount)) 476 return failure(); 477 bool expectedOptionalOperand = operandCount == 0; 478 return success(expectedOptionalOperand != optOperand.hasValue()); 479 } 480 481 //===----------------------------------------------------------------------===// 482 // Printing 483 484 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 485 Value operand, Value optOperand, 486 OperandRange varOperands) { 487 printer << operand; 488 if (optOperand) 489 printer << ", " << optOperand; 490 printer << " -> (" << varOperands << ")"; 491 } 492 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 493 Type operandType, Type optOperandType, 494 TypeRange varOperandTypes) { 495 printer << " : " << operandType; 496 if (optOperandType) 497 printer << ", " << optOperandType; 498 printer << " -> (" << varOperandTypes << ")"; 499 } 500 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 501 Operation *op, Type operandType, 502 Type optOperandType, 503 TypeRange varOperandTypes) { 504 printer << " type_refs_capture "; 505 printCustomDirectiveResults(printer, op, operandType, optOperandType, 506 varOperandTypes); 507 } 508 static void printCustomDirectiveOperandsAndTypes( 509 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 510 OperandRange varOperands, Type operandType, Type optOperandType, 511 TypeRange varOperandTypes) { 512 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 513 printCustomDirectiveResults(printer, op, operandType, optOperandType, 514 varOperandTypes); 515 } 516 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 517 Region ®ion, 518 MutableArrayRef<Region> varRegions) { 519 printer.printRegion(region); 520 if (!varRegions.empty()) { 521 printer << ", "; 522 for (Region ®ion : varRegions) 523 printer.printRegion(region); 524 } 525 } 526 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 527 Block *successor, 528 SuccessorRange varSuccessors) { 529 printer << successor; 530 if (!varSuccessors.empty()) 531 printer << ", " << varSuccessors.front(); 532 } 533 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 534 Attribute attribute, 535 Attribute optAttribute) { 536 printer << attribute; 537 if (optAttribute) 538 printer << ", " << optAttribute; 539 } 540 541 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 542 DictionaryAttr attrs) { 543 printer.printOptionalAttrDict(attrs.getValue()); 544 } 545 546 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 547 Operation *op, 548 Value optOperand) { 549 printer << (optOperand ? "1" : "0"); 550 } 551 552 //===----------------------------------------------------------------------===// 553 // Test IsolatedRegionOp - parse passthrough region arguments. 554 //===----------------------------------------------------------------------===// 555 556 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 557 OperationState &result) { 558 OpAsmParser::OperandType argInfo; 559 Type argType = parser.getBuilder().getIndexType(); 560 561 // Parse the input operand. 562 if (parser.parseOperand(argInfo) || 563 parser.resolveOperand(argInfo, argType, result.operands)) 564 return failure(); 565 566 // Parse the body region, and reuse the operand info as the argument info. 567 Region *body = result.addRegion(); 568 return parser.parseRegion(*body, argInfo, argType, 569 /*enableNameShadowing=*/true); 570 } 571 572 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 573 p << "test.isolated_region "; 574 p.printOperand(op.getOperand()); 575 p.shadowRegionArgs(op.region(), op.getOperand()); 576 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 577 } 578 579 //===----------------------------------------------------------------------===// 580 // Test SSACFGRegionOp 581 //===----------------------------------------------------------------------===// 582 583 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 584 return RegionKind::SSACFG; 585 } 586 587 //===----------------------------------------------------------------------===// 588 // Test GraphRegionOp 589 //===----------------------------------------------------------------------===// 590 591 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 592 OperationState &result) { 593 // Parse the body region, and reuse the operand info as the argument info. 594 Region *body = result.addRegion(); 595 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 596 } 597 598 static void print(OpAsmPrinter &p, GraphRegionOp op) { 599 p << "test.graph_region "; 600 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 601 } 602 603 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 604 return RegionKind::Graph; 605 } 606 607 //===----------------------------------------------------------------------===// 608 // Test AffineScopeOp 609 //===----------------------------------------------------------------------===// 610 611 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 612 OperationState &result) { 613 // Parse the body region, and reuse the operand info as the argument info. 614 Region *body = result.addRegion(); 615 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 616 } 617 618 static void print(OpAsmPrinter &p, AffineScopeOp op) { 619 p << "test.affine_scope "; 620 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 621 } 622 623 //===----------------------------------------------------------------------===// 624 // Test parser. 625 //===----------------------------------------------------------------------===// 626 627 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, 628 OperationState &result) { 629 if (parser.parseOptionalColon()) 630 return success(); 631 uint64_t numResults; 632 if (parser.parseInteger(numResults)) 633 return failure(); 634 635 IndexType type = parser.getBuilder().getIndexType(); 636 for (unsigned i = 0; i < numResults; ++i) 637 result.addTypes(type); 638 return success(); 639 } 640 641 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { 642 p << ParseIntegerLiteralOp::getOperationName(); 643 if (unsigned numResults = op->getNumResults()) 644 p << " : " << numResults; 645 } 646 647 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, 648 OperationState &result) { 649 StringRef keyword; 650 if (parser.parseKeyword(&keyword)) 651 return failure(); 652 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 653 return success(); 654 } 655 656 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { 657 p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword(); 658 } 659 660 //===----------------------------------------------------------------------===// 661 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 662 663 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 664 OperationState &result) { 665 if (parser.parseKeyword("wraps")) 666 return failure(); 667 668 // Parse the wrapped op in a region 669 Region &body = *result.addRegion(); 670 body.push_back(new Block); 671 Block &block = body.back(); 672 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 673 if (!wrapped_op) 674 return failure(); 675 676 // Create a return terminator in the inner region, pass as operand to the 677 // terminator the returned values from the wrapped operation. 678 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 679 OpBuilder builder(parser.getBuilder().getContext()); 680 builder.setInsertionPointToEnd(&block); 681 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 682 683 // Get the results type for the wrapping op from the terminator operands. 684 Operation &return_op = body.back().back(); 685 result.types.append(return_op.operand_type_begin(), 686 return_op.operand_type_end()); 687 688 // Use the location of the wrapped op for the "test.wrapping_region" op. 689 result.location = wrapped_op->getLoc(); 690 691 return success(); 692 } 693 694 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 695 p << op.getOperationName() << " wraps "; 696 p.printGenericOp(&op.region().front().front()); 697 } 698 699 //===----------------------------------------------------------------------===// 700 // Test PolyForOp - parse list of region arguments. 701 //===----------------------------------------------------------------------===// 702 703 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 704 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 705 // Parse list of region arguments without a delimiter. 706 if (parser.parseRegionArgumentList(ivsInfo)) 707 return failure(); 708 709 // Parse the body region. 710 Region *body = result.addRegion(); 711 auto &builder = parser.getBuilder(); 712 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 713 return parser.parseRegion(*body, ivsInfo, argTypes); 714 } 715 716 //===----------------------------------------------------------------------===// 717 // Test removing op with inner ops. 718 //===----------------------------------------------------------------------===// 719 720 namespace { 721 struct TestRemoveOpWithInnerOps 722 : public OpRewritePattern<TestOpWithRegionPattern> { 723 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 724 725 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } 726 727 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 728 PatternRewriter &rewriter) const override { 729 rewriter.eraseOp(op); 730 return success(); 731 } 732 }; 733 } // end anonymous namespace 734 735 void TestOpWithRegionPattern::getCanonicalizationPatterns( 736 RewritePatternSet &results, MLIRContext *context) { 737 results.add<TestRemoveOpWithInnerOps>(context); 738 } 739 740 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 741 return operand(); 742 } 743 744 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 745 return getValue(); 746 } 747 748 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 749 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 750 for (Value input : this->operands()) { 751 results.push_back(input); 752 } 753 return success(); 754 } 755 756 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 757 assert(operands.size() == 1); 758 if (operands.front()) { 759 (*this)->setAttr("attr", operands.front()); 760 return getResult(); 761 } 762 return {}; 763 } 764 765 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 766 return getOperand(); 767 } 768 769 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 770 MLIRContext *, Optional<Location> location, ValueRange operands, 771 DictionaryAttr attributes, RegionRange regions, 772 SmallVectorImpl<Type> &inferredReturnTypes) { 773 if (operands[0].getType() != operands[1].getType()) { 774 return emitOptionalError(location, "operand type mismatch ", 775 operands[0].getType(), " vs ", 776 operands[1].getType()); 777 } 778 inferredReturnTypes.assign({operands[0].getType()}); 779 return success(); 780 } 781 782 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 783 MLIRContext *context, Optional<Location> location, ValueRange operands, 784 DictionaryAttr attributes, RegionRange regions, 785 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 786 // Create return type consisting of the last element of the first operand. 787 auto operandType = *operands.getTypes().begin(); 788 auto sval = operandType.dyn_cast<ShapedType>(); 789 if (!sval) { 790 return emitOptionalError(location, "only shaped type operands allowed"); 791 } 792 int64_t dim = 793 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 794 auto type = IntegerType::get(context, 17); 795 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 796 return success(); 797 } 798 799 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 800 OpBuilder &builder, ValueRange operands, 801 llvm::SmallVectorImpl<Value> &shapes) { 802 shapes = SmallVector<Value, 1>{ 803 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; 804 return success(); 805 } 806 807 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( 808 OpBuilder &builder, ValueRange operands, 809 llvm::SmallVectorImpl<Value> &shapes) { 810 Location loc = getLoc(); 811 shapes.reserve(operands.size()); 812 for (Value operand : llvm::reverse(operands)) { 813 auto currShape = llvm::to_vector<4>(llvm::map_range( 814 llvm::seq<int64_t>( 815 0, operand.getType().cast<RankedTensorType>().getRank()), 816 [&](int64_t dim) -> Value { 817 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 818 })); 819 shapes.push_back(builder.create<tensor::FromElementsOp>( 820 getLoc(), builder.getIndexType(), currShape)); 821 } 822 return success(); 823 } 824 825 LogicalResult 826 OpWithResultShapePerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim( 827 OpBuilder &builder, 828 llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) { 829 Location loc = getLoc(); 830 shapes.reserve(getNumOperands()); 831 for (Value operand : llvm::reverse(getOperands())) { 832 auto currShape = llvm::to_vector<4>(llvm::map_range( 833 llvm::seq<int64_t>( 834 0, operand.getType().cast<RankedTensorType>().getRank()), 835 [&](int64_t dim) -> Value { 836 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 837 })); 838 shapes.emplace_back(std::move(currShape)); 839 } 840 return success(); 841 } 842 843 LogicalResult OpWithResultShapeAndPerDimInterfaceOp::reifyReturnTypeShapes( 844 OpBuilder &builder, ValueRange operands, 845 llvm::SmallVectorImpl<Value> &shapes) { 846 Location loc = getLoc(); 847 shapes.reserve(operands.size()); 848 for (Value operand : llvm::reverse(operands)) { 849 auto currShape = llvm::to_vector<4>(llvm::map_range( 850 llvm::seq<int64_t>( 851 0, operand.getType().cast<RankedTensorType>().getRank()), 852 [&](int64_t dim) -> Value { 853 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 854 })); 855 shapes.push_back(builder.create<tensor::FromElementsOp>( 856 getLoc(), builder.getIndexType(), currShape)); 857 } 858 return success(); 859 } 860 861 LogicalResult 862 OpWithResultShapeAndPerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim( 863 OpBuilder &builder, 864 llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) { 865 Location loc = getLoc(); 866 shapes.reserve(getNumOperands()); 867 for (Value operand : llvm::reverse(getOperands())) { 868 auto currShape = llvm::to_vector<4>(llvm::map_range( 869 llvm::seq<int64_t>( 870 0, operand.getType().cast<RankedTensorType>().getRank()), 871 [&](int64_t dim) -> Value { 872 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 873 })); 874 shapes.emplace_back(std::move(currShape)); 875 } 876 return success(); 877 } 878 879 //===----------------------------------------------------------------------===// 880 // Test SideEffect interfaces 881 //===----------------------------------------------------------------------===// 882 883 namespace { 884 /// A test resource for side effects. 885 struct TestResource : public SideEffects::Resource::Base<TestResource> { 886 StringRef getName() final { return "<Test>"; } 887 }; 888 } // end anonymous namespace 889 890 static void testSideEffectOpGetEffect( 891 Operation *op, 892 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 893 &effects) { 894 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); 895 if (!effectsAttr) 896 return; 897 898 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 899 } 900 901 void SideEffectOp::getEffects( 902 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 903 // Check for an effects attribute on the op instance. 904 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 905 if (!effectsAttr) 906 return; 907 908 // If there is one, it is an array of dictionary attributes that hold 909 // information on the effects of this operation. 910 for (Attribute element : effectsAttr) { 911 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 912 913 // Get the specific memory effect. 914 MemoryEffects::Effect *effect = 915 StringSwitch<MemoryEffects::Effect *>( 916 effectElement.get("effect").cast<StringAttr>().getValue()) 917 .Case("allocate", MemoryEffects::Allocate::get()) 918 .Case("free", MemoryEffects::Free::get()) 919 .Case("read", MemoryEffects::Read::get()) 920 .Case("write", MemoryEffects::Write::get()); 921 922 // Check for a non-default resource to use. 923 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 924 if (effectElement.get("test_resource")) 925 resource = TestResource::get(); 926 927 // Check for a result to affect. 928 if (effectElement.get("on_result")) 929 effects.emplace_back(effect, getResult(), resource); 930 else if (Attribute ref = effectElement.get("on_reference")) 931 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 932 else 933 effects.emplace_back(effect, resource); 934 } 935 } 936 937 void SideEffectOp::getEffects( 938 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 939 testSideEffectOpGetEffect(getOperation(), effects); 940 } 941 942 //===----------------------------------------------------------------------===// 943 // StringAttrPrettyNameOp 944 //===----------------------------------------------------------------------===// 945 946 // This op has fancy handling of its SSA result name. 947 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 948 OperationState &result) { 949 // Add the result types. 950 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 951 result.addTypes(parser.getBuilder().getIntegerType(32)); 952 953 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 954 return failure(); 955 956 // If the attribute dictionary contains no 'names' attribute, infer it from 957 // the SSA name (if specified). 958 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 959 return attr.first == "names"; 960 }); 961 962 // If there was no name specified, check to see if there was a useful name 963 // specified in the asm file. 964 if (hadNames || parser.getNumResults() == 0) 965 return success(); 966 967 SmallVector<StringRef, 4> names; 968 auto *context = result.getContext(); 969 970 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 971 auto resultName = parser.getResultName(i); 972 StringRef nameStr; 973 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 974 nameStr = resultName.first; 975 976 names.push_back(nameStr); 977 } 978 979 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 980 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 981 return success(); 982 } 983 984 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 985 p << "test.string_attr_pretty_name"; 986 987 // Note that we only need to print the "name" attribute if the asmprinter 988 // result name disagrees with it. This can happen in strange cases, e.g. 989 // when there are conflicts. 990 bool namesDisagree = op.names().size() != op.getNumResults(); 991 992 SmallString<32> resultNameStr; 993 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 994 resultNameStr.clear(); 995 llvm::raw_svector_ostream tmpStream(resultNameStr); 996 p.printOperand(op.getResult(i), tmpStream); 997 998 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 999 if (!expectedName || 1000 tmpStream.str().drop_front() != expectedName.getValue()) { 1001 namesDisagree = true; 1002 } 1003 } 1004 1005 if (namesDisagree) 1006 p.printOptionalAttrDictWithKeyword(op->getAttrs()); 1007 else 1008 p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"}); 1009 } 1010 1011 // We set the SSA name in the asm syntax to the contents of the name 1012 // attribute. 1013 void StringAttrPrettyNameOp::getAsmResultNames( 1014 function_ref<void(Value, StringRef)> setNameFn) { 1015 1016 auto value = names(); 1017 for (size_t i = 0, e = value.size(); i != e; ++i) 1018 if (auto str = value[i].dyn_cast<StringAttr>()) 1019 if (!str.getValue().empty()) 1020 setNameFn(getResult(i), str.getValue()); 1021 } 1022 1023 //===----------------------------------------------------------------------===// 1024 // RegionIfOp 1025 //===----------------------------------------------------------------------===// 1026 1027 static void print(OpAsmPrinter &p, RegionIfOp op) { 1028 p << RegionIfOp::getOperationName() << " "; 1029 p.printOperands(op.getOperands()); 1030 p << ": " << op.getOperandTypes(); 1031 p.printArrowTypeList(op.getResultTypes()); 1032 p << " then"; 1033 p.printRegion(op.thenRegion(), 1034 /*printEntryBlockArgs=*/true, 1035 /*printBlockTerminators=*/true); 1036 p << " else"; 1037 p.printRegion(op.elseRegion(), 1038 /*printEntryBlockArgs=*/true, 1039 /*printBlockTerminators=*/true); 1040 p << " join"; 1041 p.printRegion(op.joinRegion(), 1042 /*printEntryBlockArgs=*/true, 1043 /*printBlockTerminators=*/true); 1044 } 1045 1046 static ParseResult parseRegionIfOp(OpAsmParser &parser, 1047 OperationState &result) { 1048 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 1049 SmallVector<Type, 2> operandTypes; 1050 1051 result.regions.reserve(3); 1052 Region *thenRegion = result.addRegion(); 1053 Region *elseRegion = result.addRegion(); 1054 Region *joinRegion = result.addRegion(); 1055 1056 // Parse operand, type and arrow type lists. 1057 if (parser.parseOperandList(operandInfos) || 1058 parser.parseColonTypeList(operandTypes) || 1059 parser.parseArrowTypeList(result.types)) 1060 return failure(); 1061 1062 // Parse all attached regions. 1063 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 1064 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 1065 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 1066 return failure(); 1067 1068 return parser.resolveOperands(operandInfos, operandTypes, 1069 parser.getCurrentLocation(), result.operands); 1070 } 1071 1072 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 1073 assert(index < 2 && "invalid region index"); 1074 return getOperands(); 1075 } 1076 1077 void RegionIfOp::getSuccessorRegions( 1078 Optional<unsigned> index, ArrayRef<Attribute> operands, 1079 SmallVectorImpl<RegionSuccessor> ®ions) { 1080 // We always branch to the join region. 1081 if (index.hasValue()) { 1082 if (index.getValue() < 2) 1083 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 1084 else 1085 regions.push_back(RegionSuccessor(getResults())); 1086 return; 1087 } 1088 1089 // The then and else regions are the entry regions of this op. 1090 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 1091 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 1092 } 1093 1094 //===----------------------------------------------------------------------===// 1095 // SingleNoTerminatorCustomAsmOp 1096 //===----------------------------------------------------------------------===// 1097 1098 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser, 1099 OperationState &state) { 1100 Region *body = state.addRegion(); 1101 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) 1102 return failure(); 1103 return success(); 1104 } 1105 1106 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) { 1107 printer << op.getOperationName(); 1108 printer.printRegion( 1109 op.getRegion(), /*printEntryBlockArgs=*/false, 1110 // This op has a single block without terminators. But explicitly mark 1111 // as not printing block terminators for testing. 1112 /*printBlockTerminators=*/false); 1113 } 1114 1115 #include "TestOpEnums.cpp.inc" 1116 #include "TestOpInterfaces.cpp.inc" 1117 #include "TestOpStructs.cpp.inc" 1118 #include "TestTypeInterfaces.cpp.inc" 1119 1120 #define GET_OP_CLASSES 1121 #include "TestOps.cpp.inc" 1122