1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestTypes.h" 12 #include "mlir/Dialect/DLTI/DLTI.h" 13 #include "mlir/Dialect/MemRef/IR/MemRef.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/IR/BuiltinOps.h" 16 #include "mlir/IR/DialectImplementation.h" 17 #include "mlir/IR/PatternMatch.h" 18 #include "mlir/IR/TypeUtilities.h" 19 #include "mlir/Transforms/FoldUtils.h" 20 #include "mlir/Transforms/InliningUtils.h" 21 #include "llvm/ADT/StringSwitch.h" 22 23 using namespace mlir; 24 using namespace mlir::test; 25 26 void mlir::test::registerTestDialect(DialectRegistry ®istry) { 27 registry.insert<TestDialect>(); 28 } 29 30 //===----------------------------------------------------------------------===// 31 // TestDialect Interfaces 32 //===----------------------------------------------------------------------===// 33 34 namespace { 35 36 // Test support for interacting with the AsmPrinter. 37 struct TestOpAsmInterface : public OpAsmDialectInterface { 38 using OpAsmDialectInterface::OpAsmDialectInterface; 39 40 LogicalResult getAlias(Attribute attr, raw_ostream &os) const final { 41 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 42 if (!strAttr) 43 return failure(); 44 45 // Check the contents of the string attribute to see what the test alias 46 // should be named. 47 Optional<StringRef> aliasName = 48 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 49 .Case("alias_test:dot_in_name", StringRef("test.alias")) 50 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 51 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 52 .Case("alias_test:sanitize_conflict_a", 53 StringRef("test_alias_conflict0")) 54 .Case("alias_test:sanitize_conflict_b", 55 StringRef("test_alias_conflict0_")) 56 .Default(llvm::None); 57 if (!aliasName) 58 return failure(); 59 60 os << *aliasName; 61 return success(); 62 } 63 64 void getAsmResultNames(Operation *op, 65 OpAsmSetValueNameFn setNameFn) const final { 66 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 67 setNameFn(asmOp, "result"); 68 } 69 70 void getAsmBlockArgumentNames(Block *block, 71 OpAsmSetValueNameFn setNameFn) const final { 72 auto op = block->getParentOp(); 73 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 74 if (!arrayAttr) 75 return; 76 auto args = block->getArguments(); 77 auto e = std::min(arrayAttr.size(), args.size()); 78 for (unsigned i = 0; i < e; ++i) { 79 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 80 setNameFn(args[i], strAttr.getValue()); 81 } 82 } 83 }; 84 85 struct TestDialectFoldInterface : public DialectFoldInterface { 86 using DialectFoldInterface::DialectFoldInterface; 87 88 /// Registered hook to check if the given region, which is attached to an 89 /// operation that is *not* isolated from above, should be used when 90 /// materializing constants. 91 bool shouldMaterializeInto(Region *region) const final { 92 // If this is a one region operation, then insert into it. 93 return isa<OneRegionOp>(region->getParentOp()); 94 } 95 }; 96 97 /// This class defines the interface for handling inlining with standard 98 /// operations. 99 struct TestInlinerInterface : public DialectInlinerInterface { 100 using DialectInlinerInterface::DialectInlinerInterface; 101 102 //===--------------------------------------------------------------------===// 103 // Analysis Hooks 104 //===--------------------------------------------------------------------===// 105 106 bool isLegalToInline(Operation *call, Operation *callable, 107 bool wouldBeCloned) const final { 108 // Don't allow inlining calls that are marked `noinline`. 109 return !call->hasAttr("noinline"); 110 } 111 bool isLegalToInline(Region *, Region *, bool, 112 BlockAndValueMapping &) const final { 113 // Inlining into test dialect regions is legal. 114 return true; 115 } 116 bool isLegalToInline(Operation *, Region *, bool, 117 BlockAndValueMapping &) const final { 118 return true; 119 } 120 121 bool shouldAnalyzeRecursively(Operation *op) const final { 122 // Analyze recursively if this is not a functional region operation, it 123 // froms a separate functional scope. 124 return !isa<FunctionalRegionOp>(op); 125 } 126 127 //===--------------------------------------------------------------------===// 128 // Transformation Hooks 129 //===--------------------------------------------------------------------===// 130 131 /// Handle the given inlined terminator by replacing it with a new operation 132 /// as necessary. 133 void handleTerminator(Operation *op, 134 ArrayRef<Value> valuesToRepl) const final { 135 // Only handle "test.return" here. 136 auto returnOp = dyn_cast<TestReturnOp>(op); 137 if (!returnOp) 138 return; 139 140 // Replace the values directly with the return operands. 141 assert(returnOp.getNumOperands() == valuesToRepl.size()); 142 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 143 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 144 } 145 146 /// Attempt to materialize a conversion for a type mismatch between a call 147 /// from this dialect, and a callable region. This method should generate an 148 /// operation that takes 'input' as the only operand, and produces a single 149 /// result of 'resultType'. If a conversion can not be generated, nullptr 150 /// should be returned. 151 Operation *materializeCallConversion(OpBuilder &builder, Value input, 152 Type resultType, 153 Location conversionLoc) const final { 154 // Only allow conversion for i16/i32 types. 155 if (!(resultType.isSignlessInteger(16) || 156 resultType.isSignlessInteger(32)) || 157 !(input.getType().isSignlessInteger(16) || 158 input.getType().isSignlessInteger(32))) 159 return nullptr; 160 return builder.create<TestCastOp>(conversionLoc, resultType, input); 161 } 162 }; 163 } // end anonymous namespace 164 165 //===----------------------------------------------------------------------===// 166 // TestDialect 167 //===----------------------------------------------------------------------===// 168 169 static void testSideEffectOpGetEffect( 170 Operation *op, 171 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects); 172 173 // This is the implementation of a dialect fallback for `TestEffectOpInterface`. 174 struct TestOpEffectInterfaceFallback 175 : public TestEffectOpInterface::FallbackModel< 176 TestOpEffectInterfaceFallback> { 177 static bool classof(Operation *op) { 178 bool isSupportedOp = 179 op->getName().getStringRef() == "test.unregistered_side_effect_op"; 180 assert(isSupportedOp && "Unexpected dispatch"); 181 return isSupportedOp; 182 } 183 184 void 185 getEffects(Operation *op, 186 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 187 &effects) const { 188 testSideEffectOpGetEffect(op, effects); 189 } 190 }; 191 192 void TestDialect::initialize() { 193 registerAttributes(); 194 registerTypes(); 195 addOperations< 196 #define GET_OP_LIST 197 #include "TestOps.cpp.inc" 198 >(); 199 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 200 TestInlinerInterface>(); 201 allowUnknownOperations(); 202 203 // Instantiate our fallback op interface that we'll use on specific 204 // unregistered op. 205 fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback; 206 } 207 TestDialect::~TestDialect() { 208 delete static_cast<TestOpEffectInterfaceFallback *>( 209 fallbackEffectOpInterfaces); 210 } 211 212 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 213 Type type, Location loc) { 214 return builder.create<TestOpConstant>(loc, type, value); 215 } 216 217 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, 218 OperationName opName) { 219 if (opName.getIdentifier() == "test.unregistered_side_effect_op" && 220 typeID == TypeID::get<TestEffectOpInterface>()) 221 return fallbackEffectOpInterfaces; 222 return nullptr; 223 } 224 225 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 226 NamedAttribute namedAttr) { 227 if (namedAttr.first == "test.invalid_attr") 228 return op->emitError() << "invalid to use 'test.invalid_attr'"; 229 return success(); 230 } 231 232 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 233 unsigned regionIndex, 234 unsigned argIndex, 235 NamedAttribute namedAttr) { 236 if (namedAttr.first == "test.invalid_attr") 237 return op->emitError() << "invalid to use 'test.invalid_attr'"; 238 return success(); 239 } 240 241 LogicalResult 242 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 243 unsigned resultIndex, 244 NamedAttribute namedAttr) { 245 if (namedAttr.first == "test.invalid_attr") 246 return op->emitError() << "invalid to use 'test.invalid_attr'"; 247 return success(); 248 } 249 250 Optional<Dialect::ParseOpHook> 251 TestDialect::getParseOperationHook(StringRef opName) const { 252 if (opName == "test.dialect_custom_printer") { 253 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 254 return parser.parseKeyword("custom_format"); 255 }}; 256 } 257 return None; 258 } 259 260 LogicalResult TestDialect::printOperation(Operation *op, 261 OpAsmPrinter &printer) const { 262 StringRef opName = op->getName().getStringRef(); 263 if (opName == "test.dialect_custom_printer") { 264 printer.getStream() << opName << " custom_format"; 265 return success(); 266 } 267 return failure(); 268 } 269 270 //===----------------------------------------------------------------------===// 271 // TestBranchOp 272 //===----------------------------------------------------------------------===// 273 274 Optional<MutableOperandRange> 275 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 276 assert(index == 0 && "invalid successor index"); 277 return targetOperandsMutable(); 278 } 279 280 //===----------------------------------------------------------------------===// 281 // TestFoldToCallOp 282 //===----------------------------------------------------------------------===// 283 284 namespace { 285 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 286 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 287 288 LogicalResult matchAndRewrite(FoldToCallOp op, 289 PatternRewriter &rewriter) const override { 290 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(), 291 ValueRange()); 292 return success(); 293 } 294 }; 295 } // end anonymous namespace 296 297 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 298 MLIRContext *context) { 299 results.add<FoldToCallOpPattern>(context); 300 } 301 302 //===----------------------------------------------------------------------===// 303 // Test Format* operations 304 //===----------------------------------------------------------------------===// 305 306 //===----------------------------------------------------------------------===// 307 // Parsing 308 309 static ParseResult parseCustomDirectiveOperands( 310 OpAsmParser &parser, OpAsmParser::OperandType &operand, 311 Optional<OpAsmParser::OperandType> &optOperand, 312 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 313 if (parser.parseOperand(operand)) 314 return failure(); 315 if (succeeded(parser.parseOptionalComma())) { 316 optOperand.emplace(); 317 if (parser.parseOperand(*optOperand)) 318 return failure(); 319 } 320 if (parser.parseArrow() || parser.parseLParen() || 321 parser.parseOperandList(varOperands) || parser.parseRParen()) 322 return failure(); 323 return success(); 324 } 325 static ParseResult 326 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 327 Type &optOperandType, 328 SmallVectorImpl<Type> &varOperandTypes) { 329 if (parser.parseColon()) 330 return failure(); 331 332 if (parser.parseType(operandType)) 333 return failure(); 334 if (succeeded(parser.parseOptionalComma())) { 335 if (parser.parseType(optOperandType)) 336 return failure(); 337 } 338 if (parser.parseArrow() || parser.parseLParen() || 339 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 340 return failure(); 341 return success(); 342 } 343 static ParseResult 344 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 345 Type optOperandType, 346 const SmallVectorImpl<Type> &varOperandTypes) { 347 if (parser.parseKeyword("type_refs_capture")) 348 return failure(); 349 350 Type operandType2, optOperandType2; 351 SmallVector<Type, 1> varOperandTypes2; 352 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 353 varOperandTypes2)) 354 return failure(); 355 356 if (operandType != operandType2 || optOperandType != optOperandType2 || 357 varOperandTypes != varOperandTypes2) 358 return failure(); 359 360 return success(); 361 } 362 static ParseResult parseCustomDirectiveOperandsAndTypes( 363 OpAsmParser &parser, OpAsmParser::OperandType &operand, 364 Optional<OpAsmParser::OperandType> &optOperand, 365 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 366 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 367 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 368 parseCustomDirectiveResults(parser, operandType, optOperandType, 369 varOperandTypes)) 370 return failure(); 371 return success(); 372 } 373 static ParseResult parseCustomDirectiveRegions( 374 OpAsmParser &parser, Region ®ion, 375 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 376 if (parser.parseRegion(region)) 377 return failure(); 378 if (failed(parser.parseOptionalComma())) 379 return success(); 380 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 381 if (parser.parseRegion(*varRegion)) 382 return failure(); 383 varRegions.emplace_back(std::move(varRegion)); 384 return success(); 385 } 386 static ParseResult 387 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 388 SmallVectorImpl<Block *> &varSuccessors) { 389 if (parser.parseSuccessor(successor)) 390 return failure(); 391 if (failed(parser.parseOptionalComma())) 392 return success(); 393 Block *varSuccessor; 394 if (parser.parseSuccessor(varSuccessor)) 395 return failure(); 396 varSuccessors.append(2, varSuccessor); 397 return success(); 398 } 399 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 400 IntegerAttr &attr, 401 IntegerAttr &optAttr) { 402 if (parser.parseAttribute(attr)) 403 return failure(); 404 if (succeeded(parser.parseOptionalComma())) { 405 if (parser.parseAttribute(optAttr)) 406 return failure(); 407 } 408 return success(); 409 } 410 411 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 412 NamedAttrList &attrs) { 413 return parser.parseOptionalAttrDict(attrs); 414 } 415 static ParseResult parseCustomDirectiveOptionalOperandRef( 416 OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) { 417 int64_t operandCount = 0; 418 if (parser.parseInteger(operandCount)) 419 return failure(); 420 bool expectedOptionalOperand = operandCount == 0; 421 return success(expectedOptionalOperand != optOperand.hasValue()); 422 } 423 424 //===----------------------------------------------------------------------===// 425 // Printing 426 427 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 428 Value operand, Value optOperand, 429 OperandRange varOperands) { 430 printer << operand; 431 if (optOperand) 432 printer << ", " << optOperand; 433 printer << " -> (" << varOperands << ")"; 434 } 435 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 436 Type operandType, Type optOperandType, 437 TypeRange varOperandTypes) { 438 printer << " : " << operandType; 439 if (optOperandType) 440 printer << ", " << optOperandType; 441 printer << " -> (" << varOperandTypes << ")"; 442 } 443 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 444 Operation *op, Type operandType, 445 Type optOperandType, 446 TypeRange varOperandTypes) { 447 printer << " type_refs_capture "; 448 printCustomDirectiveResults(printer, op, operandType, optOperandType, 449 varOperandTypes); 450 } 451 static void printCustomDirectiveOperandsAndTypes( 452 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 453 OperandRange varOperands, Type operandType, Type optOperandType, 454 TypeRange varOperandTypes) { 455 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 456 printCustomDirectiveResults(printer, op, operandType, optOperandType, 457 varOperandTypes); 458 } 459 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 460 Region ®ion, 461 MutableArrayRef<Region> varRegions) { 462 printer.printRegion(region); 463 if (!varRegions.empty()) { 464 printer << ", "; 465 for (Region ®ion : varRegions) 466 printer.printRegion(region); 467 } 468 } 469 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 470 Block *successor, 471 SuccessorRange varSuccessors) { 472 printer << successor; 473 if (!varSuccessors.empty()) 474 printer << ", " << varSuccessors.front(); 475 } 476 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 477 Attribute attribute, 478 Attribute optAttribute) { 479 printer << attribute; 480 if (optAttribute) 481 printer << ", " << optAttribute; 482 } 483 484 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 485 DictionaryAttr attrs) { 486 printer.printOptionalAttrDict(attrs.getValue()); 487 } 488 489 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 490 Operation *op, 491 Value optOperand) { 492 printer << (optOperand ? "1" : "0"); 493 } 494 495 //===----------------------------------------------------------------------===// 496 // Test IsolatedRegionOp - parse passthrough region arguments. 497 //===----------------------------------------------------------------------===// 498 499 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 500 OperationState &result) { 501 OpAsmParser::OperandType argInfo; 502 Type argType = parser.getBuilder().getIndexType(); 503 504 // Parse the input operand. 505 if (parser.parseOperand(argInfo) || 506 parser.resolveOperand(argInfo, argType, result.operands)) 507 return failure(); 508 509 // Parse the body region, and reuse the operand info as the argument info. 510 Region *body = result.addRegion(); 511 return parser.parseRegion(*body, argInfo, argType, 512 /*enableNameShadowing=*/true); 513 } 514 515 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 516 p << "test.isolated_region "; 517 p.printOperand(op.getOperand()); 518 p.shadowRegionArgs(op.region(), op.getOperand()); 519 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 520 } 521 522 //===----------------------------------------------------------------------===// 523 // Test SSACFGRegionOp 524 //===----------------------------------------------------------------------===// 525 526 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 527 return RegionKind::SSACFG; 528 } 529 530 //===----------------------------------------------------------------------===// 531 // Test GraphRegionOp 532 //===----------------------------------------------------------------------===// 533 534 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 535 OperationState &result) { 536 // Parse the body region, and reuse the operand info as the argument info. 537 Region *body = result.addRegion(); 538 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 539 } 540 541 static void print(OpAsmPrinter &p, GraphRegionOp op) { 542 p << "test.graph_region "; 543 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 544 } 545 546 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 547 return RegionKind::Graph; 548 } 549 550 //===----------------------------------------------------------------------===// 551 // Test AffineScopeOp 552 //===----------------------------------------------------------------------===// 553 554 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 555 OperationState &result) { 556 // Parse the body region, and reuse the operand info as the argument info. 557 Region *body = result.addRegion(); 558 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 559 } 560 561 static void print(OpAsmPrinter &p, AffineScopeOp op) { 562 p << "test.affine_scope "; 563 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 564 } 565 566 //===----------------------------------------------------------------------===// 567 // Test parser. 568 //===----------------------------------------------------------------------===// 569 570 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, 571 OperationState &result) { 572 if (parser.parseOptionalColon()) 573 return success(); 574 uint64_t numResults; 575 if (parser.parseInteger(numResults)) 576 return failure(); 577 578 IndexType type = parser.getBuilder().getIndexType(); 579 for (unsigned i = 0; i < numResults; ++i) 580 result.addTypes(type); 581 return success(); 582 } 583 584 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { 585 p << ParseIntegerLiteralOp::getOperationName(); 586 if (unsigned numResults = op->getNumResults()) 587 p << " : " << numResults; 588 } 589 590 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, 591 OperationState &result) { 592 StringRef keyword; 593 if (parser.parseKeyword(&keyword)) 594 return failure(); 595 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 596 return success(); 597 } 598 599 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { 600 p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword(); 601 } 602 603 //===----------------------------------------------------------------------===// 604 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 605 606 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 607 OperationState &result) { 608 if (parser.parseKeyword("wraps")) 609 return failure(); 610 611 // Parse the wrapped op in a region 612 Region &body = *result.addRegion(); 613 body.push_back(new Block); 614 Block &block = body.back(); 615 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 616 if (!wrapped_op) 617 return failure(); 618 619 // Create a return terminator in the inner region, pass as operand to the 620 // terminator the returned values from the wrapped operation. 621 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 622 OpBuilder builder(parser.getBuilder().getContext()); 623 builder.setInsertionPointToEnd(&block); 624 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 625 626 // Get the results type for the wrapping op from the terminator operands. 627 Operation &return_op = body.back().back(); 628 result.types.append(return_op.operand_type_begin(), 629 return_op.operand_type_end()); 630 631 // Use the location of the wrapped op for the "test.wrapping_region" op. 632 result.location = wrapped_op->getLoc(); 633 634 return success(); 635 } 636 637 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 638 p << op.getOperationName() << " wraps "; 639 p.printGenericOp(&op.region().front().front()); 640 } 641 642 //===----------------------------------------------------------------------===// 643 // Test PolyForOp - parse list of region arguments. 644 //===----------------------------------------------------------------------===// 645 646 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 647 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 648 // Parse list of region arguments without a delimiter. 649 if (parser.parseRegionArgumentList(ivsInfo)) 650 return failure(); 651 652 // Parse the body region. 653 Region *body = result.addRegion(); 654 auto &builder = parser.getBuilder(); 655 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 656 return parser.parseRegion(*body, ivsInfo, argTypes); 657 } 658 659 //===----------------------------------------------------------------------===// 660 // Test removing op with inner ops. 661 //===----------------------------------------------------------------------===// 662 663 namespace { 664 struct TestRemoveOpWithInnerOps 665 : public OpRewritePattern<TestOpWithRegionPattern> { 666 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 667 668 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 669 PatternRewriter &rewriter) const override { 670 rewriter.eraseOp(op); 671 return success(); 672 } 673 }; 674 } // end anonymous namespace 675 676 void TestOpWithRegionPattern::getCanonicalizationPatterns( 677 RewritePatternSet &results, MLIRContext *context) { 678 results.add<TestRemoveOpWithInnerOps>(context); 679 } 680 681 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 682 return operand(); 683 } 684 685 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 686 return getValue(); 687 } 688 689 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 690 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 691 for (Value input : this->operands()) { 692 results.push_back(input); 693 } 694 return success(); 695 } 696 697 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 698 assert(operands.size() == 1); 699 if (operands.front()) { 700 (*this)->setAttr("attr", operands.front()); 701 return getResult(); 702 } 703 return {}; 704 } 705 706 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 707 return getOperand(); 708 } 709 710 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 711 MLIRContext *, Optional<Location> location, ValueRange operands, 712 DictionaryAttr attributes, RegionRange regions, 713 SmallVectorImpl<Type> &inferredReturnTypes) { 714 if (operands[0].getType() != operands[1].getType()) { 715 return emitOptionalError(location, "operand type mismatch ", 716 operands[0].getType(), " vs ", 717 operands[1].getType()); 718 } 719 inferredReturnTypes.assign({operands[0].getType()}); 720 return success(); 721 } 722 723 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 724 MLIRContext *context, Optional<Location> location, ValueRange operands, 725 DictionaryAttr attributes, RegionRange regions, 726 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 727 // Create return type consisting of the last element of the first operand. 728 auto operandType = *operands.getTypes().begin(); 729 auto sval = operandType.dyn_cast<ShapedType>(); 730 if (!sval) { 731 return emitOptionalError(location, "only shaped type operands allowed"); 732 } 733 int64_t dim = 734 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 735 auto type = IntegerType::get(context, 17); 736 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 737 return success(); 738 } 739 740 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 741 OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) { 742 shapes = SmallVector<Value, 1>{ 743 builder.createOrFold<memref::DimOp>(getLoc(), getOperand(0), 0)}; 744 return success(); 745 } 746 747 //===----------------------------------------------------------------------===// 748 // Test SideEffect interfaces 749 //===----------------------------------------------------------------------===// 750 751 namespace { 752 /// A test resource for side effects. 753 struct TestResource : public SideEffects::Resource::Base<TestResource> { 754 StringRef getName() final { return "<Test>"; } 755 }; 756 } // end anonymous namespace 757 758 static void testSideEffectOpGetEffect( 759 Operation *op, 760 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 761 &effects) { 762 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); 763 if (!effectsAttr) 764 return; 765 766 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 767 } 768 769 void SideEffectOp::getEffects( 770 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 771 // Check for an effects attribute on the op instance. 772 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 773 if (!effectsAttr) 774 return; 775 776 // If there is one, it is an array of dictionary attributes that hold 777 // information on the effects of this operation. 778 for (Attribute element : effectsAttr) { 779 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 780 781 // Get the specific memory effect. 782 MemoryEffects::Effect *effect = 783 StringSwitch<MemoryEffects::Effect *>( 784 effectElement.get("effect").cast<StringAttr>().getValue()) 785 .Case("allocate", MemoryEffects::Allocate::get()) 786 .Case("free", MemoryEffects::Free::get()) 787 .Case("read", MemoryEffects::Read::get()) 788 .Case("write", MemoryEffects::Write::get()); 789 790 // Check for a non-default resource to use. 791 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 792 if (effectElement.get("test_resource")) 793 resource = TestResource::get(); 794 795 // Check for a result to affect. 796 if (effectElement.get("on_result")) 797 effects.emplace_back(effect, getResult(), resource); 798 else if (Attribute ref = effectElement.get("on_reference")) 799 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 800 else 801 effects.emplace_back(effect, resource); 802 } 803 } 804 805 void SideEffectOp::getEffects( 806 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 807 testSideEffectOpGetEffect(getOperation(), effects); 808 } 809 810 //===----------------------------------------------------------------------===// 811 // StringAttrPrettyNameOp 812 //===----------------------------------------------------------------------===// 813 814 // This op has fancy handling of its SSA result name. 815 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 816 OperationState &result) { 817 // Add the result types. 818 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 819 result.addTypes(parser.getBuilder().getIntegerType(32)); 820 821 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 822 return failure(); 823 824 // If the attribute dictionary contains no 'names' attribute, infer it from 825 // the SSA name (if specified). 826 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 827 return attr.first == "names"; 828 }); 829 830 // If there was no name specified, check to see if there was a useful name 831 // specified in the asm file. 832 if (hadNames || parser.getNumResults() == 0) 833 return success(); 834 835 SmallVector<StringRef, 4> names; 836 auto *context = result.getContext(); 837 838 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 839 auto resultName = parser.getResultName(i); 840 StringRef nameStr; 841 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 842 nameStr = resultName.first; 843 844 names.push_back(nameStr); 845 } 846 847 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 848 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 849 return success(); 850 } 851 852 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 853 p << "test.string_attr_pretty_name"; 854 855 // Note that we only need to print the "name" attribute if the asmprinter 856 // result name disagrees with it. This can happen in strange cases, e.g. 857 // when there are conflicts. 858 bool namesDisagree = op.names().size() != op.getNumResults(); 859 860 SmallString<32> resultNameStr; 861 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 862 resultNameStr.clear(); 863 llvm::raw_svector_ostream tmpStream(resultNameStr); 864 p.printOperand(op.getResult(i), tmpStream); 865 866 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 867 if (!expectedName || 868 tmpStream.str().drop_front() != expectedName.getValue()) { 869 namesDisagree = true; 870 } 871 } 872 873 if (namesDisagree) 874 p.printOptionalAttrDictWithKeyword(op->getAttrs()); 875 else 876 p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"}); 877 } 878 879 // We set the SSA name in the asm syntax to the contents of the name 880 // attribute. 881 void StringAttrPrettyNameOp::getAsmResultNames( 882 function_ref<void(Value, StringRef)> setNameFn) { 883 884 auto value = names(); 885 for (size_t i = 0, e = value.size(); i != e; ++i) 886 if (auto str = value[i].dyn_cast<StringAttr>()) 887 if (!str.getValue().empty()) 888 setNameFn(getResult(i), str.getValue()); 889 } 890 891 //===----------------------------------------------------------------------===// 892 // RegionIfOp 893 //===----------------------------------------------------------------------===// 894 895 static void print(OpAsmPrinter &p, RegionIfOp op) { 896 p << RegionIfOp::getOperationName() << " "; 897 p.printOperands(op.getOperands()); 898 p << ": " << op.getOperandTypes(); 899 p.printArrowTypeList(op.getResultTypes()); 900 p << " then"; 901 p.printRegion(op.thenRegion(), 902 /*printEntryBlockArgs=*/true, 903 /*printBlockTerminators=*/true); 904 p << " else"; 905 p.printRegion(op.elseRegion(), 906 /*printEntryBlockArgs=*/true, 907 /*printBlockTerminators=*/true); 908 p << " join"; 909 p.printRegion(op.joinRegion(), 910 /*printEntryBlockArgs=*/true, 911 /*printBlockTerminators=*/true); 912 } 913 914 static ParseResult parseRegionIfOp(OpAsmParser &parser, 915 OperationState &result) { 916 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 917 SmallVector<Type, 2> operandTypes; 918 919 result.regions.reserve(3); 920 Region *thenRegion = result.addRegion(); 921 Region *elseRegion = result.addRegion(); 922 Region *joinRegion = result.addRegion(); 923 924 // Parse operand, type and arrow type lists. 925 if (parser.parseOperandList(operandInfos) || 926 parser.parseColonTypeList(operandTypes) || 927 parser.parseArrowTypeList(result.types)) 928 return failure(); 929 930 // Parse all attached regions. 931 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 932 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 933 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 934 return failure(); 935 936 return parser.resolveOperands(operandInfos, operandTypes, 937 parser.getCurrentLocation(), result.operands); 938 } 939 940 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 941 assert(index < 2 && "invalid region index"); 942 return getOperands(); 943 } 944 945 void RegionIfOp::getSuccessorRegions( 946 Optional<unsigned> index, ArrayRef<Attribute> operands, 947 SmallVectorImpl<RegionSuccessor> ®ions) { 948 // We always branch to the join region. 949 if (index.hasValue()) { 950 if (index.getValue() < 2) 951 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 952 else 953 regions.push_back(RegionSuccessor(getResults())); 954 return; 955 } 956 957 // The then and else regions are the entry regions of this op. 958 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 959 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 960 } 961 962 #include "TestOpEnums.cpp.inc" 963 #include "TestOpInterfaces.cpp.inc" 964 #include "TestOpStructs.cpp.inc" 965 #include "TestTypeInterfaces.cpp.inc" 966 967 #define GET_OP_CLASSES 968 #include "TestOps.cpp.inc" 969