1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestInterfaces.h" 12 #include "TestTypes.h" 13 #include "mlir/Dialect/DLTI/DLTI.h" 14 #include "mlir/Dialect/MemRef/IR/MemRef.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/DialectImplementation.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/IR/TypeUtilities.h" 21 #include "mlir/Reducer/ReductionPatternInterface.h" 22 #include "mlir/Transforms/FoldUtils.h" 23 #include "mlir/Transforms/InliningUtils.h" 24 #include "llvm/ADT/StringSwitch.h" 25 26 using namespace mlir; 27 using namespace mlir::test; 28 29 void mlir::test::registerTestDialect(DialectRegistry ®istry) { 30 registry.insert<TestDialect>(); 31 } 32 33 //===----------------------------------------------------------------------===// 34 // TestDialect Interfaces 35 //===----------------------------------------------------------------------===// 36 37 namespace { 38 39 /// Testing the correctness of some traits. 40 static_assert( 41 llvm::is_detected<OpTrait::has_implicit_terminator_t, 42 SingleBlockImplicitTerminatorOp>::value, 43 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); 44 static_assert(OpTrait::hasSingleBlockImplicitTerminator< 45 SingleBlockImplicitTerminatorOp>::value, 46 "hasSingleBlockImplicitTerminator does not match " 47 "SingleBlockImplicitTerminatorOp"); 48 49 // Test support for interacting with the AsmPrinter. 50 struct TestOpAsmInterface : public OpAsmDialectInterface { 51 using OpAsmDialectInterface::OpAsmDialectInterface; 52 53 LogicalResult getAlias(Attribute attr, raw_ostream &os) const final { 54 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 55 if (!strAttr) 56 return failure(); 57 58 // Check the contents of the string attribute to see what the test alias 59 // should be named. 60 Optional<StringRef> aliasName = 61 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 62 .Case("alias_test:dot_in_name", StringRef("test.alias")) 63 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 64 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 65 .Case("alias_test:sanitize_conflict_a", 66 StringRef("test_alias_conflict0")) 67 .Case("alias_test:sanitize_conflict_b", 68 StringRef("test_alias_conflict0_")) 69 .Default(llvm::None); 70 if (!aliasName) 71 return failure(); 72 73 os << *aliasName; 74 return success(); 75 } 76 77 void getAsmResultNames(Operation *op, 78 OpAsmSetValueNameFn setNameFn) const final { 79 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 80 setNameFn(asmOp, "result"); 81 } 82 83 void getAsmBlockArgumentNames(Block *block, 84 OpAsmSetValueNameFn setNameFn) const final { 85 auto op = block->getParentOp(); 86 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 87 if (!arrayAttr) 88 return; 89 auto args = block->getArguments(); 90 auto e = std::min(arrayAttr.size(), args.size()); 91 for (unsigned i = 0; i < e; ++i) { 92 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 93 setNameFn(args[i], strAttr.getValue()); 94 } 95 } 96 }; 97 98 struct TestDialectFoldInterface : public DialectFoldInterface { 99 using DialectFoldInterface::DialectFoldInterface; 100 101 /// Registered hook to check if the given region, which is attached to an 102 /// operation that is *not* isolated from above, should be used when 103 /// materializing constants. 104 bool shouldMaterializeInto(Region *region) const final { 105 // If this is a one region operation, then insert into it. 106 return isa<OneRegionOp>(region->getParentOp()); 107 } 108 }; 109 110 /// This class defines the interface for handling inlining with standard 111 /// operations. 112 struct TestInlinerInterface : public DialectInlinerInterface { 113 using DialectInlinerInterface::DialectInlinerInterface; 114 115 //===--------------------------------------------------------------------===// 116 // Analysis Hooks 117 //===--------------------------------------------------------------------===// 118 119 bool isLegalToInline(Operation *call, Operation *callable, 120 bool wouldBeCloned) const final { 121 // Don't allow inlining calls that are marked `noinline`. 122 return !call->hasAttr("noinline"); 123 } 124 bool isLegalToInline(Region *, Region *, bool, 125 BlockAndValueMapping &) const final { 126 // Inlining into test dialect regions is legal. 127 return true; 128 } 129 bool isLegalToInline(Operation *, Region *, bool, 130 BlockAndValueMapping &) const final { 131 return true; 132 } 133 134 bool shouldAnalyzeRecursively(Operation *op) const final { 135 // Analyze recursively if this is not a functional region operation, it 136 // froms a separate functional scope. 137 return !isa<FunctionalRegionOp>(op); 138 } 139 140 //===--------------------------------------------------------------------===// 141 // Transformation Hooks 142 //===--------------------------------------------------------------------===// 143 144 /// Handle the given inlined terminator by replacing it with a new operation 145 /// as necessary. 146 void handleTerminator(Operation *op, 147 ArrayRef<Value> valuesToRepl) const final { 148 // Only handle "test.return" here. 149 auto returnOp = dyn_cast<TestReturnOp>(op); 150 if (!returnOp) 151 return; 152 153 // Replace the values directly with the return operands. 154 assert(returnOp.getNumOperands() == valuesToRepl.size()); 155 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 156 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 157 } 158 159 /// Attempt to materialize a conversion for a type mismatch between a call 160 /// from this dialect, and a callable region. This method should generate an 161 /// operation that takes 'input' as the only operand, and produces a single 162 /// result of 'resultType'. If a conversion can not be generated, nullptr 163 /// should be returned. 164 Operation *materializeCallConversion(OpBuilder &builder, Value input, 165 Type resultType, 166 Location conversionLoc) const final { 167 // Only allow conversion for i16/i32 types. 168 if (!(resultType.isSignlessInteger(16) || 169 resultType.isSignlessInteger(32)) || 170 !(input.getType().isSignlessInteger(16) || 171 input.getType().isSignlessInteger(32))) 172 return nullptr; 173 return builder.create<TestCastOp>(conversionLoc, resultType, input); 174 } 175 176 void processInlinedCallBlocks( 177 Operation *call, 178 iterator_range<Region::iterator> inlinedBlocks) const final { 179 if (!isa<ConversionCallOp>(call)) 180 return; 181 182 // Set attributed on all ops in the inlined blocks. 183 for (Block &block : inlinedBlocks) { 184 block.walk([&](Operation *op) { 185 op->setAttr("inlined_conversion", UnitAttr::get(call->getContext())); 186 }); 187 } 188 } 189 }; 190 191 struct TestReductionPatternInterface : public DialectReductionPatternInterface { 192 public: 193 TestReductionPatternInterface(Dialect *dialect) 194 : DialectReductionPatternInterface(dialect) {} 195 196 virtual void 197 populateReductionPatterns(RewritePatternSet &patterns) const final { 198 populateTestReductionPatterns(patterns); 199 } 200 }; 201 202 } // end anonymous namespace 203 204 //===----------------------------------------------------------------------===// 205 // TestDialect 206 //===----------------------------------------------------------------------===// 207 208 static void testSideEffectOpGetEffect( 209 Operation *op, 210 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects); 211 212 // This is the implementation of a dialect fallback for `TestEffectOpInterface`. 213 struct TestOpEffectInterfaceFallback 214 : public TestEffectOpInterface::FallbackModel< 215 TestOpEffectInterfaceFallback> { 216 static bool classof(Operation *op) { 217 bool isSupportedOp = 218 op->getName().getStringRef() == "test.unregistered_side_effect_op"; 219 assert(isSupportedOp && "Unexpected dispatch"); 220 return isSupportedOp; 221 } 222 223 void 224 getEffects(Operation *op, 225 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 226 &effects) const { 227 testSideEffectOpGetEffect(op, effects); 228 } 229 }; 230 231 void TestDialect::initialize() { 232 registerAttributes(); 233 registerTypes(); 234 addOperations< 235 #define GET_OP_LIST 236 #include "TestOps.cpp.inc" 237 >(); 238 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 239 TestInlinerInterface, TestReductionPatternInterface>(); 240 allowUnknownOperations(); 241 242 // Instantiate our fallback op interface that we'll use on specific 243 // unregistered op. 244 fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback; 245 } 246 TestDialect::~TestDialect() { 247 delete static_cast<TestOpEffectInterfaceFallback *>( 248 fallbackEffectOpInterfaces); 249 } 250 251 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 252 Type type, Location loc) { 253 return builder.create<TestOpConstant>(loc, type, value); 254 } 255 256 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, 257 OperationName opName) { 258 if (opName.getIdentifier() == "test.unregistered_side_effect_op" && 259 typeID == TypeID::get<TestEffectOpInterface>()) 260 return fallbackEffectOpInterfaces; 261 return nullptr; 262 } 263 264 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 265 NamedAttribute namedAttr) { 266 if (namedAttr.first == "test.invalid_attr") 267 return op->emitError() << "invalid to use 'test.invalid_attr'"; 268 return success(); 269 } 270 271 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 272 unsigned regionIndex, 273 unsigned argIndex, 274 NamedAttribute namedAttr) { 275 if (namedAttr.first == "test.invalid_attr") 276 return op->emitError() << "invalid to use 'test.invalid_attr'"; 277 return success(); 278 } 279 280 LogicalResult 281 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 282 unsigned resultIndex, 283 NamedAttribute namedAttr) { 284 if (namedAttr.first == "test.invalid_attr") 285 return op->emitError() << "invalid to use 'test.invalid_attr'"; 286 return success(); 287 } 288 289 Optional<Dialect::ParseOpHook> 290 TestDialect::getParseOperationHook(StringRef opName) const { 291 if (opName == "test.dialect_custom_printer") { 292 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 293 return parser.parseKeyword("custom_format"); 294 }}; 295 } 296 return None; 297 } 298 299 LogicalResult TestDialect::printOperation(Operation *op, 300 OpAsmPrinter &printer) const { 301 StringRef opName = op->getName().getStringRef(); 302 if (opName == "test.dialect_custom_printer") { 303 printer.getStream() << opName << " custom_format"; 304 return success(); 305 } 306 return failure(); 307 } 308 309 //===----------------------------------------------------------------------===// 310 // TestBranchOp 311 //===----------------------------------------------------------------------===// 312 313 Optional<MutableOperandRange> 314 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 315 assert(index == 0 && "invalid successor index"); 316 return targetOperandsMutable(); 317 } 318 319 //===----------------------------------------------------------------------===// 320 // TestDialectCanonicalizerOp 321 //===----------------------------------------------------------------------===// 322 323 static LogicalResult 324 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, 325 PatternRewriter &rewriter) { 326 rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(), 327 rewriter.getI32IntegerAttr(42)); 328 return success(); 329 } 330 331 void TestDialect::getCanonicalizationPatterns( 332 RewritePatternSet &results) const { 333 results.add(&dialectCanonicalizationPattern); 334 } 335 336 //===----------------------------------------------------------------------===// 337 // TestFoldToCallOp 338 //===----------------------------------------------------------------------===// 339 340 namespace { 341 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 342 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 343 344 LogicalResult matchAndRewrite(FoldToCallOp op, 345 PatternRewriter &rewriter) const override { 346 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(), 347 ValueRange()); 348 return success(); 349 } 350 }; 351 } // end anonymous namespace 352 353 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 354 MLIRContext *context) { 355 results.add<FoldToCallOpPattern>(context); 356 } 357 358 //===----------------------------------------------------------------------===// 359 // Test Format* operations 360 //===----------------------------------------------------------------------===// 361 362 //===----------------------------------------------------------------------===// 363 // Parsing 364 365 static ParseResult parseCustomDirectiveOperands( 366 OpAsmParser &parser, OpAsmParser::OperandType &operand, 367 Optional<OpAsmParser::OperandType> &optOperand, 368 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 369 if (parser.parseOperand(operand)) 370 return failure(); 371 if (succeeded(parser.parseOptionalComma())) { 372 optOperand.emplace(); 373 if (parser.parseOperand(*optOperand)) 374 return failure(); 375 } 376 if (parser.parseArrow() || parser.parseLParen() || 377 parser.parseOperandList(varOperands) || parser.parseRParen()) 378 return failure(); 379 return success(); 380 } 381 static ParseResult 382 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 383 Type &optOperandType, 384 SmallVectorImpl<Type> &varOperandTypes) { 385 if (parser.parseColon()) 386 return failure(); 387 388 if (parser.parseType(operandType)) 389 return failure(); 390 if (succeeded(parser.parseOptionalComma())) { 391 if (parser.parseType(optOperandType)) 392 return failure(); 393 } 394 if (parser.parseArrow() || parser.parseLParen() || 395 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 396 return failure(); 397 return success(); 398 } 399 static ParseResult 400 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 401 Type optOperandType, 402 const SmallVectorImpl<Type> &varOperandTypes) { 403 if (parser.parseKeyword("type_refs_capture")) 404 return failure(); 405 406 Type operandType2, optOperandType2; 407 SmallVector<Type, 1> varOperandTypes2; 408 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 409 varOperandTypes2)) 410 return failure(); 411 412 if (operandType != operandType2 || optOperandType != optOperandType2 || 413 varOperandTypes != varOperandTypes2) 414 return failure(); 415 416 return success(); 417 } 418 static ParseResult parseCustomDirectiveOperandsAndTypes( 419 OpAsmParser &parser, OpAsmParser::OperandType &operand, 420 Optional<OpAsmParser::OperandType> &optOperand, 421 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 422 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 423 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 424 parseCustomDirectiveResults(parser, operandType, optOperandType, 425 varOperandTypes)) 426 return failure(); 427 return success(); 428 } 429 static ParseResult parseCustomDirectiveRegions( 430 OpAsmParser &parser, Region ®ion, 431 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 432 if (parser.parseRegion(region)) 433 return failure(); 434 if (failed(parser.parseOptionalComma())) 435 return success(); 436 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 437 if (parser.parseRegion(*varRegion)) 438 return failure(); 439 varRegions.emplace_back(std::move(varRegion)); 440 return success(); 441 } 442 static ParseResult 443 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 444 SmallVectorImpl<Block *> &varSuccessors) { 445 if (parser.parseSuccessor(successor)) 446 return failure(); 447 if (failed(parser.parseOptionalComma())) 448 return success(); 449 Block *varSuccessor; 450 if (parser.parseSuccessor(varSuccessor)) 451 return failure(); 452 varSuccessors.append(2, varSuccessor); 453 return success(); 454 } 455 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 456 IntegerAttr &attr, 457 IntegerAttr &optAttr) { 458 if (parser.parseAttribute(attr)) 459 return failure(); 460 if (succeeded(parser.parseOptionalComma())) { 461 if (parser.parseAttribute(optAttr)) 462 return failure(); 463 } 464 return success(); 465 } 466 467 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 468 NamedAttrList &attrs) { 469 return parser.parseOptionalAttrDict(attrs); 470 } 471 static ParseResult parseCustomDirectiveOptionalOperandRef( 472 OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) { 473 int64_t operandCount = 0; 474 if (parser.parseInteger(operandCount)) 475 return failure(); 476 bool expectedOptionalOperand = operandCount == 0; 477 return success(expectedOptionalOperand != optOperand.hasValue()); 478 } 479 480 //===----------------------------------------------------------------------===// 481 // Printing 482 483 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 484 Value operand, Value optOperand, 485 OperandRange varOperands) { 486 printer << operand; 487 if (optOperand) 488 printer << ", " << optOperand; 489 printer << " -> (" << varOperands << ")"; 490 } 491 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 492 Type operandType, Type optOperandType, 493 TypeRange varOperandTypes) { 494 printer << " : " << operandType; 495 if (optOperandType) 496 printer << ", " << optOperandType; 497 printer << " -> (" << varOperandTypes << ")"; 498 } 499 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 500 Operation *op, Type operandType, 501 Type optOperandType, 502 TypeRange varOperandTypes) { 503 printer << " type_refs_capture "; 504 printCustomDirectiveResults(printer, op, operandType, optOperandType, 505 varOperandTypes); 506 } 507 static void printCustomDirectiveOperandsAndTypes( 508 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 509 OperandRange varOperands, Type operandType, Type optOperandType, 510 TypeRange varOperandTypes) { 511 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 512 printCustomDirectiveResults(printer, op, operandType, optOperandType, 513 varOperandTypes); 514 } 515 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 516 Region ®ion, 517 MutableArrayRef<Region> varRegions) { 518 printer.printRegion(region); 519 if (!varRegions.empty()) { 520 printer << ", "; 521 for (Region ®ion : varRegions) 522 printer.printRegion(region); 523 } 524 } 525 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 526 Block *successor, 527 SuccessorRange varSuccessors) { 528 printer << successor; 529 if (!varSuccessors.empty()) 530 printer << ", " << varSuccessors.front(); 531 } 532 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 533 Attribute attribute, 534 Attribute optAttribute) { 535 printer << attribute; 536 if (optAttribute) 537 printer << ", " << optAttribute; 538 } 539 540 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 541 DictionaryAttr attrs) { 542 printer.printOptionalAttrDict(attrs.getValue()); 543 } 544 545 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 546 Operation *op, 547 Value optOperand) { 548 printer << (optOperand ? "1" : "0"); 549 } 550 551 //===----------------------------------------------------------------------===// 552 // Test IsolatedRegionOp - parse passthrough region arguments. 553 //===----------------------------------------------------------------------===// 554 555 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 556 OperationState &result) { 557 OpAsmParser::OperandType argInfo; 558 Type argType = parser.getBuilder().getIndexType(); 559 560 // Parse the input operand. 561 if (parser.parseOperand(argInfo) || 562 parser.resolveOperand(argInfo, argType, result.operands)) 563 return failure(); 564 565 // Parse the body region, and reuse the operand info as the argument info. 566 Region *body = result.addRegion(); 567 return parser.parseRegion(*body, argInfo, argType, 568 /*enableNameShadowing=*/true); 569 } 570 571 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 572 p << "test.isolated_region "; 573 p.printOperand(op.getOperand()); 574 p.shadowRegionArgs(op.region(), op.getOperand()); 575 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 576 } 577 578 //===----------------------------------------------------------------------===// 579 // Test SSACFGRegionOp 580 //===----------------------------------------------------------------------===// 581 582 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 583 return RegionKind::SSACFG; 584 } 585 586 //===----------------------------------------------------------------------===// 587 // Test GraphRegionOp 588 //===----------------------------------------------------------------------===// 589 590 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 591 OperationState &result) { 592 // Parse the body region, and reuse the operand info as the argument info. 593 Region *body = result.addRegion(); 594 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 595 } 596 597 static void print(OpAsmPrinter &p, GraphRegionOp op) { 598 p << "test.graph_region "; 599 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 600 } 601 602 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 603 return RegionKind::Graph; 604 } 605 606 //===----------------------------------------------------------------------===// 607 // Test AffineScopeOp 608 //===----------------------------------------------------------------------===// 609 610 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 611 OperationState &result) { 612 // Parse the body region, and reuse the operand info as the argument info. 613 Region *body = result.addRegion(); 614 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 615 } 616 617 static void print(OpAsmPrinter &p, AffineScopeOp op) { 618 p << "test.affine_scope "; 619 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 620 } 621 622 //===----------------------------------------------------------------------===// 623 // Test parser. 624 //===----------------------------------------------------------------------===// 625 626 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, 627 OperationState &result) { 628 if (parser.parseOptionalColon()) 629 return success(); 630 uint64_t numResults; 631 if (parser.parseInteger(numResults)) 632 return failure(); 633 634 IndexType type = parser.getBuilder().getIndexType(); 635 for (unsigned i = 0; i < numResults; ++i) 636 result.addTypes(type); 637 return success(); 638 } 639 640 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { 641 p << ParseIntegerLiteralOp::getOperationName(); 642 if (unsigned numResults = op->getNumResults()) 643 p << " : " << numResults; 644 } 645 646 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, 647 OperationState &result) { 648 StringRef keyword; 649 if (parser.parseKeyword(&keyword)) 650 return failure(); 651 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 652 return success(); 653 } 654 655 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { 656 p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword(); 657 } 658 659 //===----------------------------------------------------------------------===// 660 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 661 662 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 663 OperationState &result) { 664 if (parser.parseKeyword("wraps")) 665 return failure(); 666 667 // Parse the wrapped op in a region 668 Region &body = *result.addRegion(); 669 body.push_back(new Block); 670 Block &block = body.back(); 671 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 672 if (!wrapped_op) 673 return failure(); 674 675 // Create a return terminator in the inner region, pass as operand to the 676 // terminator the returned values from the wrapped operation. 677 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 678 OpBuilder builder(parser.getBuilder().getContext()); 679 builder.setInsertionPointToEnd(&block); 680 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 681 682 // Get the results type for the wrapping op from the terminator operands. 683 Operation &return_op = body.back().back(); 684 result.types.append(return_op.operand_type_begin(), 685 return_op.operand_type_end()); 686 687 // Use the location of the wrapped op for the "test.wrapping_region" op. 688 result.location = wrapped_op->getLoc(); 689 690 return success(); 691 } 692 693 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 694 p << op.getOperationName() << " wraps "; 695 p.printGenericOp(&op.region().front().front()); 696 } 697 698 //===----------------------------------------------------------------------===// 699 // Test PolyForOp - parse list of region arguments. 700 //===----------------------------------------------------------------------===// 701 702 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 703 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 704 // Parse list of region arguments without a delimiter. 705 if (parser.parseRegionArgumentList(ivsInfo)) 706 return failure(); 707 708 // Parse the body region. 709 Region *body = result.addRegion(); 710 auto &builder = parser.getBuilder(); 711 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 712 return parser.parseRegion(*body, ivsInfo, argTypes); 713 } 714 715 //===----------------------------------------------------------------------===// 716 // Test removing op with inner ops. 717 //===----------------------------------------------------------------------===// 718 719 namespace { 720 struct TestRemoveOpWithInnerOps 721 : public OpRewritePattern<TestOpWithRegionPattern> { 722 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 723 724 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } 725 726 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 727 PatternRewriter &rewriter) const override { 728 rewriter.eraseOp(op); 729 return success(); 730 } 731 }; 732 } // end anonymous namespace 733 734 void TestOpWithRegionPattern::getCanonicalizationPatterns( 735 RewritePatternSet &results, MLIRContext *context) { 736 results.add<TestRemoveOpWithInnerOps>(context); 737 } 738 739 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 740 return operand(); 741 } 742 743 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 744 return getValue(); 745 } 746 747 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 748 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 749 for (Value input : this->operands()) { 750 results.push_back(input); 751 } 752 return success(); 753 } 754 755 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 756 assert(operands.size() == 1); 757 if (operands.front()) { 758 (*this)->setAttr("attr", operands.front()); 759 return getResult(); 760 } 761 return {}; 762 } 763 764 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 765 return getOperand(); 766 } 767 768 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 769 MLIRContext *, Optional<Location> location, ValueRange operands, 770 DictionaryAttr attributes, RegionRange regions, 771 SmallVectorImpl<Type> &inferredReturnTypes) { 772 if (operands[0].getType() != operands[1].getType()) { 773 return emitOptionalError(location, "operand type mismatch ", 774 operands[0].getType(), " vs ", 775 operands[1].getType()); 776 } 777 inferredReturnTypes.assign({operands[0].getType()}); 778 return success(); 779 } 780 781 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 782 MLIRContext *context, Optional<Location> location, ValueRange operands, 783 DictionaryAttr attributes, RegionRange regions, 784 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 785 // Create return type consisting of the last element of the first operand. 786 auto operandType = *operands.getTypes().begin(); 787 auto sval = operandType.dyn_cast<ShapedType>(); 788 if (!sval) { 789 return emitOptionalError(location, "only shaped type operands allowed"); 790 } 791 int64_t dim = 792 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 793 auto type = IntegerType::get(context, 17); 794 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 795 return success(); 796 } 797 798 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 799 OpBuilder &builder, ValueRange operands, 800 llvm::SmallVectorImpl<Value> &shapes) { 801 shapes = SmallVector<Value, 1>{ 802 builder.createOrFold<memref::DimOp>(getLoc(), operands.front(), 0)}; 803 return success(); 804 } 805 806 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( 807 OpBuilder &builder, ValueRange operands, 808 llvm::SmallVectorImpl<Value> &shapes) { 809 Location loc = getLoc(); 810 shapes.reserve(operands.size()); 811 for (Value operand : llvm::reverse(operands)) { 812 auto currShape = llvm::to_vector<4>(llvm::map_range( 813 llvm::seq<int64_t>( 814 0, operand.getType().cast<RankedTensorType>().getRank()), 815 [&](int64_t dim) -> Value { 816 return builder.createOrFold<memref::DimOp>(loc, operand, dim); 817 })); 818 shapes.push_back(builder.create<tensor::FromElementsOp>( 819 getLoc(), builder.getIndexType(), currShape)); 820 } 821 return success(); 822 } 823 824 LogicalResult 825 OpWithResultShapePerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim( 826 OpBuilder &builder, 827 llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) { 828 Location loc = getLoc(); 829 shapes.reserve(getNumOperands()); 830 for (Value operand : llvm::reverse(getOperands())) { 831 auto currShape = llvm::to_vector<4>(llvm::map_range( 832 llvm::seq<int64_t>( 833 0, operand.getType().cast<RankedTensorType>().getRank()), 834 [&](int64_t dim) -> Value { 835 return builder.createOrFold<memref::DimOp>(loc, operand, dim); 836 })); 837 shapes.emplace_back(std::move(currShape)); 838 } 839 return success(); 840 } 841 842 LogicalResult OpWithResultShapeAndPerDimInterfaceOp::reifyReturnTypeShapes( 843 OpBuilder &builder, ValueRange operands, 844 llvm::SmallVectorImpl<Value> &shapes) { 845 Location loc = getLoc(); 846 shapes.reserve(operands.size()); 847 for (Value operand : llvm::reverse(operands)) { 848 auto currShape = llvm::to_vector<4>(llvm::map_range( 849 llvm::seq<int64_t>( 850 0, operand.getType().cast<RankedTensorType>().getRank()), 851 [&](int64_t dim) -> Value { 852 return builder.createOrFold<memref::DimOp>(loc, operand, dim); 853 })); 854 shapes.push_back(builder.create<tensor::FromElementsOp>( 855 getLoc(), builder.getIndexType(), currShape)); 856 } 857 return success(); 858 } 859 860 LogicalResult 861 OpWithResultShapeAndPerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim( 862 OpBuilder &builder, 863 llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) { 864 Location loc = getLoc(); 865 shapes.reserve(getNumOperands()); 866 for (Value operand : llvm::reverse(getOperands())) { 867 auto currShape = llvm::to_vector<4>(llvm::map_range( 868 llvm::seq<int64_t>( 869 0, operand.getType().cast<RankedTensorType>().getRank()), 870 [&](int64_t dim) -> Value { 871 return builder.createOrFold<memref::DimOp>(loc, operand, dim); 872 })); 873 shapes.emplace_back(std::move(currShape)); 874 } 875 return success(); 876 } 877 878 //===----------------------------------------------------------------------===// 879 // Test SideEffect interfaces 880 //===----------------------------------------------------------------------===// 881 882 namespace { 883 /// A test resource for side effects. 884 struct TestResource : public SideEffects::Resource::Base<TestResource> { 885 StringRef getName() final { return "<Test>"; } 886 }; 887 } // end anonymous namespace 888 889 static void testSideEffectOpGetEffect( 890 Operation *op, 891 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 892 &effects) { 893 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); 894 if (!effectsAttr) 895 return; 896 897 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 898 } 899 900 void SideEffectOp::getEffects( 901 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 902 // Check for an effects attribute on the op instance. 903 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 904 if (!effectsAttr) 905 return; 906 907 // If there is one, it is an array of dictionary attributes that hold 908 // information on the effects of this operation. 909 for (Attribute element : effectsAttr) { 910 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 911 912 // Get the specific memory effect. 913 MemoryEffects::Effect *effect = 914 StringSwitch<MemoryEffects::Effect *>( 915 effectElement.get("effect").cast<StringAttr>().getValue()) 916 .Case("allocate", MemoryEffects::Allocate::get()) 917 .Case("free", MemoryEffects::Free::get()) 918 .Case("read", MemoryEffects::Read::get()) 919 .Case("write", MemoryEffects::Write::get()); 920 921 // Check for a non-default resource to use. 922 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 923 if (effectElement.get("test_resource")) 924 resource = TestResource::get(); 925 926 // Check for a result to affect. 927 if (effectElement.get("on_result")) 928 effects.emplace_back(effect, getResult(), resource); 929 else if (Attribute ref = effectElement.get("on_reference")) 930 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 931 else 932 effects.emplace_back(effect, resource); 933 } 934 } 935 936 void SideEffectOp::getEffects( 937 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 938 testSideEffectOpGetEffect(getOperation(), effects); 939 } 940 941 //===----------------------------------------------------------------------===// 942 // StringAttrPrettyNameOp 943 //===----------------------------------------------------------------------===// 944 945 // This op has fancy handling of its SSA result name. 946 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 947 OperationState &result) { 948 // Add the result types. 949 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 950 result.addTypes(parser.getBuilder().getIntegerType(32)); 951 952 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 953 return failure(); 954 955 // If the attribute dictionary contains no 'names' attribute, infer it from 956 // the SSA name (if specified). 957 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 958 return attr.first == "names"; 959 }); 960 961 // If there was no name specified, check to see if there was a useful name 962 // specified in the asm file. 963 if (hadNames || parser.getNumResults() == 0) 964 return success(); 965 966 SmallVector<StringRef, 4> names; 967 auto *context = result.getContext(); 968 969 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 970 auto resultName = parser.getResultName(i); 971 StringRef nameStr; 972 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 973 nameStr = resultName.first; 974 975 names.push_back(nameStr); 976 } 977 978 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 979 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 980 return success(); 981 } 982 983 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 984 p << "test.string_attr_pretty_name"; 985 986 // Note that we only need to print the "name" attribute if the asmprinter 987 // result name disagrees with it. This can happen in strange cases, e.g. 988 // when there are conflicts. 989 bool namesDisagree = op.names().size() != op.getNumResults(); 990 991 SmallString<32> resultNameStr; 992 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 993 resultNameStr.clear(); 994 llvm::raw_svector_ostream tmpStream(resultNameStr); 995 p.printOperand(op.getResult(i), tmpStream); 996 997 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 998 if (!expectedName || 999 tmpStream.str().drop_front() != expectedName.getValue()) { 1000 namesDisagree = true; 1001 } 1002 } 1003 1004 if (namesDisagree) 1005 p.printOptionalAttrDictWithKeyword(op->getAttrs()); 1006 else 1007 p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"}); 1008 } 1009 1010 // We set the SSA name in the asm syntax to the contents of the name 1011 // attribute. 1012 void StringAttrPrettyNameOp::getAsmResultNames( 1013 function_ref<void(Value, StringRef)> setNameFn) { 1014 1015 auto value = names(); 1016 for (size_t i = 0, e = value.size(); i != e; ++i) 1017 if (auto str = value[i].dyn_cast<StringAttr>()) 1018 if (!str.getValue().empty()) 1019 setNameFn(getResult(i), str.getValue()); 1020 } 1021 1022 //===----------------------------------------------------------------------===// 1023 // RegionIfOp 1024 //===----------------------------------------------------------------------===// 1025 1026 static void print(OpAsmPrinter &p, RegionIfOp op) { 1027 p << RegionIfOp::getOperationName() << " "; 1028 p.printOperands(op.getOperands()); 1029 p << ": " << op.getOperandTypes(); 1030 p.printArrowTypeList(op.getResultTypes()); 1031 p << " then"; 1032 p.printRegion(op.thenRegion(), 1033 /*printEntryBlockArgs=*/true, 1034 /*printBlockTerminators=*/true); 1035 p << " else"; 1036 p.printRegion(op.elseRegion(), 1037 /*printEntryBlockArgs=*/true, 1038 /*printBlockTerminators=*/true); 1039 p << " join"; 1040 p.printRegion(op.joinRegion(), 1041 /*printEntryBlockArgs=*/true, 1042 /*printBlockTerminators=*/true); 1043 } 1044 1045 static ParseResult parseRegionIfOp(OpAsmParser &parser, 1046 OperationState &result) { 1047 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 1048 SmallVector<Type, 2> operandTypes; 1049 1050 result.regions.reserve(3); 1051 Region *thenRegion = result.addRegion(); 1052 Region *elseRegion = result.addRegion(); 1053 Region *joinRegion = result.addRegion(); 1054 1055 // Parse operand, type and arrow type lists. 1056 if (parser.parseOperandList(operandInfos) || 1057 parser.parseColonTypeList(operandTypes) || 1058 parser.parseArrowTypeList(result.types)) 1059 return failure(); 1060 1061 // Parse all attached regions. 1062 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 1063 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 1064 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 1065 return failure(); 1066 1067 return parser.resolveOperands(operandInfos, operandTypes, 1068 parser.getCurrentLocation(), result.operands); 1069 } 1070 1071 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 1072 assert(index < 2 && "invalid region index"); 1073 return getOperands(); 1074 } 1075 1076 void RegionIfOp::getSuccessorRegions( 1077 Optional<unsigned> index, ArrayRef<Attribute> operands, 1078 SmallVectorImpl<RegionSuccessor> ®ions) { 1079 // We always branch to the join region. 1080 if (index.hasValue()) { 1081 if (index.getValue() < 2) 1082 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 1083 else 1084 regions.push_back(RegionSuccessor(getResults())); 1085 return; 1086 } 1087 1088 // The then and else regions are the entry regions of this op. 1089 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 1090 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 1091 } 1092 1093 //===----------------------------------------------------------------------===// 1094 // SingleNoTerminatorCustomAsmOp 1095 //===----------------------------------------------------------------------===// 1096 1097 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser, 1098 OperationState &state) { 1099 Region *body = state.addRegion(); 1100 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) 1101 return failure(); 1102 return success(); 1103 } 1104 1105 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) { 1106 printer << op.getOperationName(); 1107 printer.printRegion( 1108 op.getRegion(), /*printEntryBlockArgs=*/false, 1109 // This op has a single block without terminators. But explicitly mark 1110 // as not printing block terminators for testing. 1111 /*printBlockTerminators=*/false); 1112 } 1113 1114 #include "TestOpEnums.cpp.inc" 1115 #include "TestOpInterfaces.cpp.inc" 1116 #include "TestOpStructs.cpp.inc" 1117 #include "TestTypeInterfaces.cpp.inc" 1118 1119 #define GET_OP_CLASSES 1120 #include "TestOps.cpp.inc" 1121