1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestInterfaces.h" 12 #include "TestTypes.h" 13 #include "mlir/Dialect/DLTI/DLTI.h" 14 #include "mlir/Dialect/MemRef/IR/MemRef.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/DialectImplementation.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Reducer/ReductionPatternInterface.h" 21 #include "mlir/Transforms/FoldUtils.h" 22 #include "mlir/Transforms/InliningUtils.h" 23 #include "llvm/ADT/StringSwitch.h" 24 25 using namespace mlir; 26 using namespace mlir::test; 27 28 void mlir::test::registerTestDialect(DialectRegistry ®istry) { 29 registry.insert<TestDialect>(); 30 } 31 32 //===----------------------------------------------------------------------===// 33 // TestDialect Interfaces 34 //===----------------------------------------------------------------------===// 35 36 namespace { 37 38 /// Testing the correctness of some traits. 39 static_assert( 40 llvm::is_detected<OpTrait::has_implicit_terminator_t, 41 SingleBlockImplicitTerminatorOp>::value, 42 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); 43 static_assert(OpTrait::hasSingleBlockImplicitTerminator< 44 SingleBlockImplicitTerminatorOp>::value, 45 "hasSingleBlockImplicitTerminator does not match " 46 "SingleBlockImplicitTerminatorOp"); 47 48 // Test support for interacting with the AsmPrinter. 49 struct TestOpAsmInterface : public OpAsmDialectInterface { 50 using OpAsmDialectInterface::OpAsmDialectInterface; 51 52 LogicalResult getAlias(Attribute attr, raw_ostream &os) const final { 53 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 54 if (!strAttr) 55 return failure(); 56 57 // Check the contents of the string attribute to see what the test alias 58 // should be named. 59 Optional<StringRef> aliasName = 60 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 61 .Case("alias_test:dot_in_name", StringRef("test.alias")) 62 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 63 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 64 .Case("alias_test:sanitize_conflict_a", 65 StringRef("test_alias_conflict0")) 66 .Case("alias_test:sanitize_conflict_b", 67 StringRef("test_alias_conflict0_")) 68 .Default(llvm::None); 69 if (!aliasName) 70 return failure(); 71 72 os << *aliasName; 73 return success(); 74 } 75 76 void getAsmResultNames(Operation *op, 77 OpAsmSetValueNameFn setNameFn) const final { 78 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 79 setNameFn(asmOp, "result"); 80 } 81 82 void getAsmBlockArgumentNames(Block *block, 83 OpAsmSetValueNameFn setNameFn) const final { 84 auto op = block->getParentOp(); 85 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 86 if (!arrayAttr) 87 return; 88 auto args = block->getArguments(); 89 auto e = std::min(arrayAttr.size(), args.size()); 90 for (unsigned i = 0; i < e; ++i) { 91 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 92 setNameFn(args[i], strAttr.getValue()); 93 } 94 } 95 }; 96 97 struct TestDialectFoldInterface : public DialectFoldInterface { 98 using DialectFoldInterface::DialectFoldInterface; 99 100 /// Registered hook to check if the given region, which is attached to an 101 /// operation that is *not* isolated from above, should be used when 102 /// materializing constants. 103 bool shouldMaterializeInto(Region *region) const final { 104 // If this is a one region operation, then insert into it. 105 return isa<OneRegionOp>(region->getParentOp()); 106 } 107 }; 108 109 /// This class defines the interface for handling inlining with standard 110 /// operations. 111 struct TestInlinerInterface : public DialectInlinerInterface { 112 using DialectInlinerInterface::DialectInlinerInterface; 113 114 //===--------------------------------------------------------------------===// 115 // Analysis Hooks 116 //===--------------------------------------------------------------------===// 117 118 bool isLegalToInline(Operation *call, Operation *callable, 119 bool wouldBeCloned) const final { 120 // Don't allow inlining calls that are marked `noinline`. 121 return !call->hasAttr("noinline"); 122 } 123 bool isLegalToInline(Region *, Region *, bool, 124 BlockAndValueMapping &) const final { 125 // Inlining into test dialect regions is legal. 126 return true; 127 } 128 bool isLegalToInline(Operation *, Region *, bool, 129 BlockAndValueMapping &) const final { 130 return true; 131 } 132 133 bool shouldAnalyzeRecursively(Operation *op) const final { 134 // Analyze recursively if this is not a functional region operation, it 135 // froms a separate functional scope. 136 return !isa<FunctionalRegionOp>(op); 137 } 138 139 //===--------------------------------------------------------------------===// 140 // Transformation Hooks 141 //===--------------------------------------------------------------------===// 142 143 /// Handle the given inlined terminator by replacing it with a new operation 144 /// as necessary. 145 void handleTerminator(Operation *op, 146 ArrayRef<Value> valuesToRepl) const final { 147 // Only handle "test.return" here. 148 auto returnOp = dyn_cast<TestReturnOp>(op); 149 if (!returnOp) 150 return; 151 152 // Replace the values directly with the return operands. 153 assert(returnOp.getNumOperands() == valuesToRepl.size()); 154 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 155 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 156 } 157 158 /// Attempt to materialize a conversion for a type mismatch between a call 159 /// from this dialect, and a callable region. This method should generate an 160 /// operation that takes 'input' as the only operand, and produces a single 161 /// result of 'resultType'. If a conversion can not be generated, nullptr 162 /// should be returned. 163 Operation *materializeCallConversion(OpBuilder &builder, Value input, 164 Type resultType, 165 Location conversionLoc) const final { 166 // Only allow conversion for i16/i32 types. 167 if (!(resultType.isSignlessInteger(16) || 168 resultType.isSignlessInteger(32)) || 169 !(input.getType().isSignlessInteger(16) || 170 input.getType().isSignlessInteger(32))) 171 return nullptr; 172 return builder.create<TestCastOp>(conversionLoc, resultType, input); 173 } 174 }; 175 176 struct TestReductionPatternInterface : public DialectReductionPatternInterface { 177 public: 178 TestReductionPatternInterface(Dialect *dialect) 179 : DialectReductionPatternInterface(dialect) {} 180 181 virtual void 182 populateReductionPatterns(RewritePatternSet &patterns) const final { 183 populateTestReductionPatterns(patterns); 184 } 185 }; 186 187 } // end anonymous namespace 188 189 //===----------------------------------------------------------------------===// 190 // TestDialect 191 //===----------------------------------------------------------------------===// 192 193 static void testSideEffectOpGetEffect( 194 Operation *op, 195 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects); 196 197 // This is the implementation of a dialect fallback for `TestEffectOpInterface`. 198 struct TestOpEffectInterfaceFallback 199 : public TestEffectOpInterface::FallbackModel< 200 TestOpEffectInterfaceFallback> { 201 static bool classof(Operation *op) { 202 bool isSupportedOp = 203 op->getName().getStringRef() == "test.unregistered_side_effect_op"; 204 assert(isSupportedOp && "Unexpected dispatch"); 205 return isSupportedOp; 206 } 207 208 void 209 getEffects(Operation *op, 210 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 211 &effects) const { 212 testSideEffectOpGetEffect(op, effects); 213 } 214 }; 215 216 void TestDialect::initialize() { 217 registerAttributes(); 218 registerTypes(); 219 addOperations< 220 #define GET_OP_LIST 221 #include "TestOps.cpp.inc" 222 >(); 223 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 224 TestInlinerInterface, TestReductionPatternInterface>(); 225 allowUnknownOperations(); 226 227 // Instantiate our fallback op interface that we'll use on specific 228 // unregistered op. 229 fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback; 230 } 231 TestDialect::~TestDialect() { 232 delete static_cast<TestOpEffectInterfaceFallback *>( 233 fallbackEffectOpInterfaces); 234 } 235 236 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 237 Type type, Location loc) { 238 return builder.create<TestOpConstant>(loc, type, value); 239 } 240 241 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, 242 OperationName opName) { 243 if (opName.getIdentifier() == "test.unregistered_side_effect_op" && 244 typeID == TypeID::get<TestEffectOpInterface>()) 245 return fallbackEffectOpInterfaces; 246 return nullptr; 247 } 248 249 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 250 NamedAttribute namedAttr) { 251 if (namedAttr.first == "test.invalid_attr") 252 return op->emitError() << "invalid to use 'test.invalid_attr'"; 253 return success(); 254 } 255 256 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 257 unsigned regionIndex, 258 unsigned argIndex, 259 NamedAttribute namedAttr) { 260 if (namedAttr.first == "test.invalid_attr") 261 return op->emitError() << "invalid to use 'test.invalid_attr'"; 262 return success(); 263 } 264 265 LogicalResult 266 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 267 unsigned resultIndex, 268 NamedAttribute namedAttr) { 269 if (namedAttr.first == "test.invalid_attr") 270 return op->emitError() << "invalid to use 'test.invalid_attr'"; 271 return success(); 272 } 273 274 Optional<Dialect::ParseOpHook> 275 TestDialect::getParseOperationHook(StringRef opName) const { 276 if (opName == "test.dialect_custom_printer") { 277 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 278 return parser.parseKeyword("custom_format"); 279 }}; 280 } 281 return None; 282 } 283 284 LogicalResult TestDialect::printOperation(Operation *op, 285 OpAsmPrinter &printer) const { 286 StringRef opName = op->getName().getStringRef(); 287 if (opName == "test.dialect_custom_printer") { 288 printer.getStream() << opName << " custom_format"; 289 return success(); 290 } 291 return failure(); 292 } 293 294 //===----------------------------------------------------------------------===// 295 // TestBranchOp 296 //===----------------------------------------------------------------------===// 297 298 Optional<MutableOperandRange> 299 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 300 assert(index == 0 && "invalid successor index"); 301 return targetOperandsMutable(); 302 } 303 304 //===----------------------------------------------------------------------===// 305 // TestDialectCanonicalizerOp 306 //===----------------------------------------------------------------------===// 307 308 static LogicalResult 309 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, 310 PatternRewriter &rewriter) { 311 rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(), 312 rewriter.getI32IntegerAttr(42)); 313 return success(); 314 } 315 316 void TestDialect::getCanonicalizationPatterns( 317 RewritePatternSet &results) const { 318 results.add(&dialectCanonicalizationPattern); 319 } 320 321 //===----------------------------------------------------------------------===// 322 // TestFoldToCallOp 323 //===----------------------------------------------------------------------===// 324 325 namespace { 326 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 327 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 328 329 LogicalResult matchAndRewrite(FoldToCallOp op, 330 PatternRewriter &rewriter) const override { 331 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(), 332 ValueRange()); 333 return success(); 334 } 335 }; 336 } // end anonymous namespace 337 338 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 339 MLIRContext *context) { 340 results.add<FoldToCallOpPattern>(context); 341 } 342 343 //===----------------------------------------------------------------------===// 344 // Test Format* operations 345 //===----------------------------------------------------------------------===// 346 347 //===----------------------------------------------------------------------===// 348 // Parsing 349 350 static ParseResult parseCustomDirectiveOperands( 351 OpAsmParser &parser, OpAsmParser::OperandType &operand, 352 Optional<OpAsmParser::OperandType> &optOperand, 353 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 354 if (parser.parseOperand(operand)) 355 return failure(); 356 if (succeeded(parser.parseOptionalComma())) { 357 optOperand.emplace(); 358 if (parser.parseOperand(*optOperand)) 359 return failure(); 360 } 361 if (parser.parseArrow() || parser.parseLParen() || 362 parser.parseOperandList(varOperands) || parser.parseRParen()) 363 return failure(); 364 return success(); 365 } 366 static ParseResult 367 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 368 Type &optOperandType, 369 SmallVectorImpl<Type> &varOperandTypes) { 370 if (parser.parseColon()) 371 return failure(); 372 373 if (parser.parseType(operandType)) 374 return failure(); 375 if (succeeded(parser.parseOptionalComma())) { 376 if (parser.parseType(optOperandType)) 377 return failure(); 378 } 379 if (parser.parseArrow() || parser.parseLParen() || 380 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 381 return failure(); 382 return success(); 383 } 384 static ParseResult 385 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 386 Type optOperandType, 387 const SmallVectorImpl<Type> &varOperandTypes) { 388 if (parser.parseKeyword("type_refs_capture")) 389 return failure(); 390 391 Type operandType2, optOperandType2; 392 SmallVector<Type, 1> varOperandTypes2; 393 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 394 varOperandTypes2)) 395 return failure(); 396 397 if (operandType != operandType2 || optOperandType != optOperandType2 || 398 varOperandTypes != varOperandTypes2) 399 return failure(); 400 401 return success(); 402 } 403 static ParseResult parseCustomDirectiveOperandsAndTypes( 404 OpAsmParser &parser, OpAsmParser::OperandType &operand, 405 Optional<OpAsmParser::OperandType> &optOperand, 406 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 407 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 408 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 409 parseCustomDirectiveResults(parser, operandType, optOperandType, 410 varOperandTypes)) 411 return failure(); 412 return success(); 413 } 414 static ParseResult parseCustomDirectiveRegions( 415 OpAsmParser &parser, Region ®ion, 416 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 417 if (parser.parseRegion(region)) 418 return failure(); 419 if (failed(parser.parseOptionalComma())) 420 return success(); 421 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 422 if (parser.parseRegion(*varRegion)) 423 return failure(); 424 varRegions.emplace_back(std::move(varRegion)); 425 return success(); 426 } 427 static ParseResult 428 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 429 SmallVectorImpl<Block *> &varSuccessors) { 430 if (parser.parseSuccessor(successor)) 431 return failure(); 432 if (failed(parser.parseOptionalComma())) 433 return success(); 434 Block *varSuccessor; 435 if (parser.parseSuccessor(varSuccessor)) 436 return failure(); 437 varSuccessors.append(2, varSuccessor); 438 return success(); 439 } 440 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 441 IntegerAttr &attr, 442 IntegerAttr &optAttr) { 443 if (parser.parseAttribute(attr)) 444 return failure(); 445 if (succeeded(parser.parseOptionalComma())) { 446 if (parser.parseAttribute(optAttr)) 447 return failure(); 448 } 449 return success(); 450 } 451 452 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 453 NamedAttrList &attrs) { 454 return parser.parseOptionalAttrDict(attrs); 455 } 456 static ParseResult parseCustomDirectiveOptionalOperandRef( 457 OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) { 458 int64_t operandCount = 0; 459 if (parser.parseInteger(operandCount)) 460 return failure(); 461 bool expectedOptionalOperand = operandCount == 0; 462 return success(expectedOptionalOperand != optOperand.hasValue()); 463 } 464 465 //===----------------------------------------------------------------------===// 466 // Printing 467 468 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 469 Value operand, Value optOperand, 470 OperandRange varOperands) { 471 printer << operand; 472 if (optOperand) 473 printer << ", " << optOperand; 474 printer << " -> (" << varOperands << ")"; 475 } 476 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 477 Type operandType, Type optOperandType, 478 TypeRange varOperandTypes) { 479 printer << " : " << operandType; 480 if (optOperandType) 481 printer << ", " << optOperandType; 482 printer << " -> (" << varOperandTypes << ")"; 483 } 484 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 485 Operation *op, Type operandType, 486 Type optOperandType, 487 TypeRange varOperandTypes) { 488 printer << " type_refs_capture "; 489 printCustomDirectiveResults(printer, op, operandType, optOperandType, 490 varOperandTypes); 491 } 492 static void printCustomDirectiveOperandsAndTypes( 493 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 494 OperandRange varOperands, Type operandType, Type optOperandType, 495 TypeRange varOperandTypes) { 496 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 497 printCustomDirectiveResults(printer, op, operandType, optOperandType, 498 varOperandTypes); 499 } 500 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 501 Region ®ion, 502 MutableArrayRef<Region> varRegions) { 503 printer.printRegion(region); 504 if (!varRegions.empty()) { 505 printer << ", "; 506 for (Region ®ion : varRegions) 507 printer.printRegion(region); 508 } 509 } 510 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 511 Block *successor, 512 SuccessorRange varSuccessors) { 513 printer << successor; 514 if (!varSuccessors.empty()) 515 printer << ", " << varSuccessors.front(); 516 } 517 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 518 Attribute attribute, 519 Attribute optAttribute) { 520 printer << attribute; 521 if (optAttribute) 522 printer << ", " << optAttribute; 523 } 524 525 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 526 DictionaryAttr attrs) { 527 printer.printOptionalAttrDict(attrs.getValue()); 528 } 529 530 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 531 Operation *op, 532 Value optOperand) { 533 printer << (optOperand ? "1" : "0"); 534 } 535 536 //===----------------------------------------------------------------------===// 537 // Test IsolatedRegionOp - parse passthrough region arguments. 538 //===----------------------------------------------------------------------===// 539 540 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 541 OperationState &result) { 542 OpAsmParser::OperandType argInfo; 543 Type argType = parser.getBuilder().getIndexType(); 544 545 // Parse the input operand. 546 if (parser.parseOperand(argInfo) || 547 parser.resolveOperand(argInfo, argType, result.operands)) 548 return failure(); 549 550 // Parse the body region, and reuse the operand info as the argument info. 551 Region *body = result.addRegion(); 552 return parser.parseRegion(*body, argInfo, argType, 553 /*enableNameShadowing=*/true); 554 } 555 556 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 557 p << "test.isolated_region "; 558 p.printOperand(op.getOperand()); 559 p.shadowRegionArgs(op.region(), op.getOperand()); 560 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 561 } 562 563 //===----------------------------------------------------------------------===// 564 // Test SSACFGRegionOp 565 //===----------------------------------------------------------------------===// 566 567 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 568 return RegionKind::SSACFG; 569 } 570 571 //===----------------------------------------------------------------------===// 572 // Test GraphRegionOp 573 //===----------------------------------------------------------------------===// 574 575 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 576 OperationState &result) { 577 // Parse the body region, and reuse the operand info as the argument info. 578 Region *body = result.addRegion(); 579 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 580 } 581 582 static void print(OpAsmPrinter &p, GraphRegionOp op) { 583 p << "test.graph_region "; 584 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 585 } 586 587 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 588 return RegionKind::Graph; 589 } 590 591 //===----------------------------------------------------------------------===// 592 // Test AffineScopeOp 593 //===----------------------------------------------------------------------===// 594 595 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 596 OperationState &result) { 597 // Parse the body region, and reuse the operand info as the argument info. 598 Region *body = result.addRegion(); 599 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 600 } 601 602 static void print(OpAsmPrinter &p, AffineScopeOp op) { 603 p << "test.affine_scope "; 604 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 605 } 606 607 //===----------------------------------------------------------------------===// 608 // Test parser. 609 //===----------------------------------------------------------------------===// 610 611 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, 612 OperationState &result) { 613 if (parser.parseOptionalColon()) 614 return success(); 615 uint64_t numResults; 616 if (parser.parseInteger(numResults)) 617 return failure(); 618 619 IndexType type = parser.getBuilder().getIndexType(); 620 for (unsigned i = 0; i < numResults; ++i) 621 result.addTypes(type); 622 return success(); 623 } 624 625 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { 626 p << ParseIntegerLiteralOp::getOperationName(); 627 if (unsigned numResults = op->getNumResults()) 628 p << " : " << numResults; 629 } 630 631 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, 632 OperationState &result) { 633 StringRef keyword; 634 if (parser.parseKeyword(&keyword)) 635 return failure(); 636 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 637 return success(); 638 } 639 640 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { 641 p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword(); 642 } 643 644 //===----------------------------------------------------------------------===// 645 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 646 647 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 648 OperationState &result) { 649 if (parser.parseKeyword("wraps")) 650 return failure(); 651 652 // Parse the wrapped op in a region 653 Region &body = *result.addRegion(); 654 body.push_back(new Block); 655 Block &block = body.back(); 656 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 657 if (!wrapped_op) 658 return failure(); 659 660 // Create a return terminator in the inner region, pass as operand to the 661 // terminator the returned values from the wrapped operation. 662 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 663 OpBuilder builder(parser.getBuilder().getContext()); 664 builder.setInsertionPointToEnd(&block); 665 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 666 667 // Get the results type for the wrapping op from the terminator operands. 668 Operation &return_op = body.back().back(); 669 result.types.append(return_op.operand_type_begin(), 670 return_op.operand_type_end()); 671 672 // Use the location of the wrapped op for the "test.wrapping_region" op. 673 result.location = wrapped_op->getLoc(); 674 675 return success(); 676 } 677 678 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 679 p << op.getOperationName() << " wraps "; 680 p.printGenericOp(&op.region().front().front()); 681 } 682 683 //===----------------------------------------------------------------------===// 684 // Test PolyForOp - parse list of region arguments. 685 //===----------------------------------------------------------------------===// 686 687 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 688 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 689 // Parse list of region arguments without a delimiter. 690 if (parser.parseRegionArgumentList(ivsInfo)) 691 return failure(); 692 693 // Parse the body region. 694 Region *body = result.addRegion(); 695 auto &builder = parser.getBuilder(); 696 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 697 return parser.parseRegion(*body, ivsInfo, argTypes); 698 } 699 700 //===----------------------------------------------------------------------===// 701 // Test removing op with inner ops. 702 //===----------------------------------------------------------------------===// 703 704 namespace { 705 struct TestRemoveOpWithInnerOps 706 : public OpRewritePattern<TestOpWithRegionPattern> { 707 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 708 709 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } 710 711 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 712 PatternRewriter &rewriter) const override { 713 rewriter.eraseOp(op); 714 return success(); 715 } 716 }; 717 } // end anonymous namespace 718 719 void TestOpWithRegionPattern::getCanonicalizationPatterns( 720 RewritePatternSet &results, MLIRContext *context) { 721 results.add<TestRemoveOpWithInnerOps>(context); 722 } 723 724 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 725 return operand(); 726 } 727 728 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 729 return getValue(); 730 } 731 732 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 733 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 734 for (Value input : this->operands()) { 735 results.push_back(input); 736 } 737 return success(); 738 } 739 740 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 741 assert(operands.size() == 1); 742 if (operands.front()) { 743 (*this)->setAttr("attr", operands.front()); 744 return getResult(); 745 } 746 return {}; 747 } 748 749 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 750 return getOperand(); 751 } 752 753 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 754 MLIRContext *, Optional<Location> location, ValueRange operands, 755 DictionaryAttr attributes, RegionRange regions, 756 SmallVectorImpl<Type> &inferredReturnTypes) { 757 if (operands[0].getType() != operands[1].getType()) { 758 return emitOptionalError(location, "operand type mismatch ", 759 operands[0].getType(), " vs ", 760 operands[1].getType()); 761 } 762 inferredReturnTypes.assign({operands[0].getType()}); 763 return success(); 764 } 765 766 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 767 MLIRContext *context, Optional<Location> location, ValueRange operands, 768 DictionaryAttr attributes, RegionRange regions, 769 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 770 // Create return type consisting of the last element of the first operand. 771 auto operandType = *operands.getTypes().begin(); 772 auto sval = operandType.dyn_cast<ShapedType>(); 773 if (!sval) { 774 return emitOptionalError(location, "only shaped type operands allowed"); 775 } 776 int64_t dim = 777 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 778 auto type = IntegerType::get(context, 17); 779 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 780 return success(); 781 } 782 783 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 784 OpBuilder &builder, ValueRange operands, 785 llvm::SmallVectorImpl<Value> &shapes) { 786 shapes = SmallVector<Value, 1>{ 787 builder.createOrFold<memref::DimOp>(getLoc(), operands.front(), 0)}; 788 return success(); 789 } 790 791 LogicalResult 792 OpWithResultShapePerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim( 793 OpBuilder &builder, 794 llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) { 795 SmallVector<Value> operand1Shape, operand2Shape; 796 Location loc = getLoc(); 797 for (auto i : 798 llvm::seq<int>(0, operand1().getType().cast<ShapedType>().getRank())) { 799 operand1Shape.push_back(builder.create<memref::DimOp>(loc, operand1(), i)); 800 } 801 for (auto i : 802 llvm::seq<int>(0, operand2().getType().cast<ShapedType>().getRank())) { 803 operand2Shape.push_back(builder.create<memref::DimOp>(loc, operand2(), i)); 804 } 805 shapes.emplace_back(std::move(operand2Shape)); 806 shapes.emplace_back(std::move(operand1Shape)); 807 return success(); 808 } 809 810 //===----------------------------------------------------------------------===// 811 // Test SideEffect interfaces 812 //===----------------------------------------------------------------------===// 813 814 namespace { 815 /// A test resource for side effects. 816 struct TestResource : public SideEffects::Resource::Base<TestResource> { 817 StringRef getName() final { return "<Test>"; } 818 }; 819 } // end anonymous namespace 820 821 static void testSideEffectOpGetEffect( 822 Operation *op, 823 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 824 &effects) { 825 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); 826 if (!effectsAttr) 827 return; 828 829 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 830 } 831 832 void SideEffectOp::getEffects( 833 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 834 // Check for an effects attribute on the op instance. 835 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 836 if (!effectsAttr) 837 return; 838 839 // If there is one, it is an array of dictionary attributes that hold 840 // information on the effects of this operation. 841 for (Attribute element : effectsAttr) { 842 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 843 844 // Get the specific memory effect. 845 MemoryEffects::Effect *effect = 846 StringSwitch<MemoryEffects::Effect *>( 847 effectElement.get("effect").cast<StringAttr>().getValue()) 848 .Case("allocate", MemoryEffects::Allocate::get()) 849 .Case("free", MemoryEffects::Free::get()) 850 .Case("read", MemoryEffects::Read::get()) 851 .Case("write", MemoryEffects::Write::get()); 852 853 // Check for a non-default resource to use. 854 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 855 if (effectElement.get("test_resource")) 856 resource = TestResource::get(); 857 858 // Check for a result to affect. 859 if (effectElement.get("on_result")) 860 effects.emplace_back(effect, getResult(), resource); 861 else if (Attribute ref = effectElement.get("on_reference")) 862 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 863 else 864 effects.emplace_back(effect, resource); 865 } 866 } 867 868 void SideEffectOp::getEffects( 869 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 870 testSideEffectOpGetEffect(getOperation(), effects); 871 } 872 873 //===----------------------------------------------------------------------===// 874 // StringAttrPrettyNameOp 875 //===----------------------------------------------------------------------===// 876 877 // This op has fancy handling of its SSA result name. 878 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 879 OperationState &result) { 880 // Add the result types. 881 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 882 result.addTypes(parser.getBuilder().getIntegerType(32)); 883 884 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 885 return failure(); 886 887 // If the attribute dictionary contains no 'names' attribute, infer it from 888 // the SSA name (if specified). 889 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 890 return attr.first == "names"; 891 }); 892 893 // If there was no name specified, check to see if there was a useful name 894 // specified in the asm file. 895 if (hadNames || parser.getNumResults() == 0) 896 return success(); 897 898 SmallVector<StringRef, 4> names; 899 auto *context = result.getContext(); 900 901 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 902 auto resultName = parser.getResultName(i); 903 StringRef nameStr; 904 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 905 nameStr = resultName.first; 906 907 names.push_back(nameStr); 908 } 909 910 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 911 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 912 return success(); 913 } 914 915 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 916 p << "test.string_attr_pretty_name"; 917 918 // Note that we only need to print the "name" attribute if the asmprinter 919 // result name disagrees with it. This can happen in strange cases, e.g. 920 // when there are conflicts. 921 bool namesDisagree = op.names().size() != op.getNumResults(); 922 923 SmallString<32> resultNameStr; 924 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 925 resultNameStr.clear(); 926 llvm::raw_svector_ostream tmpStream(resultNameStr); 927 p.printOperand(op.getResult(i), tmpStream); 928 929 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 930 if (!expectedName || 931 tmpStream.str().drop_front() != expectedName.getValue()) { 932 namesDisagree = true; 933 } 934 } 935 936 if (namesDisagree) 937 p.printOptionalAttrDictWithKeyword(op->getAttrs()); 938 else 939 p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"}); 940 } 941 942 // We set the SSA name in the asm syntax to the contents of the name 943 // attribute. 944 void StringAttrPrettyNameOp::getAsmResultNames( 945 function_ref<void(Value, StringRef)> setNameFn) { 946 947 auto value = names(); 948 for (size_t i = 0, e = value.size(); i != e; ++i) 949 if (auto str = value[i].dyn_cast<StringAttr>()) 950 if (!str.getValue().empty()) 951 setNameFn(getResult(i), str.getValue()); 952 } 953 954 //===----------------------------------------------------------------------===// 955 // RegionIfOp 956 //===----------------------------------------------------------------------===// 957 958 static void print(OpAsmPrinter &p, RegionIfOp op) { 959 p << RegionIfOp::getOperationName() << " "; 960 p.printOperands(op.getOperands()); 961 p << ": " << op.getOperandTypes(); 962 p.printArrowTypeList(op.getResultTypes()); 963 p << " then"; 964 p.printRegion(op.thenRegion(), 965 /*printEntryBlockArgs=*/true, 966 /*printBlockTerminators=*/true); 967 p << " else"; 968 p.printRegion(op.elseRegion(), 969 /*printEntryBlockArgs=*/true, 970 /*printBlockTerminators=*/true); 971 p << " join"; 972 p.printRegion(op.joinRegion(), 973 /*printEntryBlockArgs=*/true, 974 /*printBlockTerminators=*/true); 975 } 976 977 static ParseResult parseRegionIfOp(OpAsmParser &parser, 978 OperationState &result) { 979 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 980 SmallVector<Type, 2> operandTypes; 981 982 result.regions.reserve(3); 983 Region *thenRegion = result.addRegion(); 984 Region *elseRegion = result.addRegion(); 985 Region *joinRegion = result.addRegion(); 986 987 // Parse operand, type and arrow type lists. 988 if (parser.parseOperandList(operandInfos) || 989 parser.parseColonTypeList(operandTypes) || 990 parser.parseArrowTypeList(result.types)) 991 return failure(); 992 993 // Parse all attached regions. 994 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 995 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 996 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 997 return failure(); 998 999 return parser.resolveOperands(operandInfos, operandTypes, 1000 parser.getCurrentLocation(), result.operands); 1001 } 1002 1003 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 1004 assert(index < 2 && "invalid region index"); 1005 return getOperands(); 1006 } 1007 1008 void RegionIfOp::getSuccessorRegions( 1009 Optional<unsigned> index, ArrayRef<Attribute> operands, 1010 SmallVectorImpl<RegionSuccessor> ®ions) { 1011 // We always branch to the join region. 1012 if (index.hasValue()) { 1013 if (index.getValue() < 2) 1014 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 1015 else 1016 regions.push_back(RegionSuccessor(getResults())); 1017 return; 1018 } 1019 1020 // The then and else regions are the entry regions of this op. 1021 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 1022 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 1023 } 1024 1025 //===----------------------------------------------------------------------===// 1026 // SingleNoTerminatorCustomAsmOp 1027 //===----------------------------------------------------------------------===// 1028 1029 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser, 1030 OperationState &state) { 1031 Region *body = state.addRegion(); 1032 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) 1033 return failure(); 1034 return success(); 1035 } 1036 1037 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) { 1038 printer << op.getOperationName(); 1039 printer.printRegion( 1040 op.getRegion(), /*printEntryBlockArgs=*/false, 1041 // This op has a single block without terminators. But explicitly mark 1042 // as not printing block terminators for testing. 1043 /*printBlockTerminators=*/false); 1044 } 1045 1046 #include "TestOpEnums.cpp.inc" 1047 #include "TestOpInterfaces.cpp.inc" 1048 #include "TestOpStructs.cpp.inc" 1049 #include "TestTypeInterfaces.cpp.inc" 1050 1051 #define GET_OP_CLASSES 1052 #include "TestOps.cpp.inc" 1053