1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestTypes.h" 12 #include "mlir/Dialect/DLTI/DLTI.h" 13 #include "mlir/Dialect/MemRef/IR/MemRef.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/IR/BuiltinOps.h" 16 #include "mlir/IR/DialectImplementation.h" 17 #include "mlir/IR/PatternMatch.h" 18 #include "mlir/IR/TypeUtilities.h" 19 #include "mlir/Transforms/FoldUtils.h" 20 #include "mlir/Transforms/InliningUtils.h" 21 #include "llvm/ADT/StringSwitch.h" 22 23 using namespace mlir; 24 using namespace mlir::test; 25 26 void mlir::test::registerTestDialect(DialectRegistry ®istry) { 27 registry.insert<TestDialect>(); 28 } 29 30 //===----------------------------------------------------------------------===// 31 // TestDialect Interfaces 32 //===----------------------------------------------------------------------===// 33 34 namespace { 35 36 /// Testing the correctness of some traits. 37 static_assert( 38 llvm::is_detected<OpTrait::has_implicit_terminator_t, 39 SingleBlockImplicitTerminatorOp>::value, 40 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); 41 static_assert(OpTrait::hasSingleBlockImplicitTerminator< 42 SingleBlockImplicitTerminatorOp>::value, 43 "hasSingleBlockImplicitTerminator does not match " 44 "SingleBlockImplicitTerminatorOp"); 45 46 // Test support for interacting with the AsmPrinter. 
47 struct TestOpAsmInterface : public OpAsmDialectInterface { 48 using OpAsmDialectInterface::OpAsmDialectInterface; 49 50 LogicalResult getAlias(Attribute attr, raw_ostream &os) const final { 51 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 52 if (!strAttr) 53 return failure(); 54 55 // Check the contents of the string attribute to see what the test alias 56 // should be named. 57 Optional<StringRef> aliasName = 58 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 59 .Case("alias_test:dot_in_name", StringRef("test.alias")) 60 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 61 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 62 .Case("alias_test:sanitize_conflict_a", 63 StringRef("test_alias_conflict0")) 64 .Case("alias_test:sanitize_conflict_b", 65 StringRef("test_alias_conflict0_")) 66 .Default(llvm::None); 67 if (!aliasName) 68 return failure(); 69 70 os << *aliasName; 71 return success(); 72 } 73 74 void getAsmResultNames(Operation *op, 75 OpAsmSetValueNameFn setNameFn) const final { 76 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 77 setNameFn(asmOp, "result"); 78 } 79 80 void getAsmBlockArgumentNames(Block *block, 81 OpAsmSetValueNameFn setNameFn) const final { 82 auto op = block->getParentOp(); 83 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 84 if (!arrayAttr) 85 return; 86 auto args = block->getArguments(); 87 auto e = std::min(arrayAttr.size(), args.size()); 88 for (unsigned i = 0; i < e; ++i) { 89 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 90 setNameFn(args[i], strAttr.getValue()); 91 } 92 } 93 }; 94 95 struct TestDialectFoldInterface : public DialectFoldInterface { 96 using DialectFoldInterface::DialectFoldInterface; 97 98 /// Registered hook to check if the given region, which is attached to an 99 /// operation that is *not* isolated from above, should be used when 100 /// materializing constants. 
101 bool shouldMaterializeInto(Region *region) const final { 102 // If this is a one region operation, then insert into it. 103 return isa<OneRegionOp>(region->getParentOp()); 104 } 105 }; 106 107 /// This class defines the interface for handling inlining with standard 108 /// operations. 109 struct TestInlinerInterface : public DialectInlinerInterface { 110 using DialectInlinerInterface::DialectInlinerInterface; 111 112 //===--------------------------------------------------------------------===// 113 // Analysis Hooks 114 //===--------------------------------------------------------------------===// 115 116 bool isLegalToInline(Operation *call, Operation *callable, 117 bool wouldBeCloned) const final { 118 // Don't allow inlining calls that are marked `noinline`. 119 return !call->hasAttr("noinline"); 120 } 121 bool isLegalToInline(Region *, Region *, bool, 122 BlockAndValueMapping &) const final { 123 // Inlining into test dialect regions is legal. 124 return true; 125 } 126 bool isLegalToInline(Operation *, Region *, bool, 127 BlockAndValueMapping &) const final { 128 return true; 129 } 130 131 bool shouldAnalyzeRecursively(Operation *op) const final { 132 // Analyze recursively if this is not a functional region operation, it 133 // froms a separate functional scope. 134 return !isa<FunctionalRegionOp>(op); 135 } 136 137 //===--------------------------------------------------------------------===// 138 // Transformation Hooks 139 //===--------------------------------------------------------------------===// 140 141 /// Handle the given inlined terminator by replacing it with a new operation 142 /// as necessary. 143 void handleTerminator(Operation *op, 144 ArrayRef<Value> valuesToRepl) const final { 145 // Only handle "test.return" here. 146 auto returnOp = dyn_cast<TestReturnOp>(op); 147 if (!returnOp) 148 return; 149 150 // Replace the values directly with the return operands. 
151 assert(returnOp.getNumOperands() == valuesToRepl.size()); 152 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 153 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 154 } 155 156 /// Attempt to materialize a conversion for a type mismatch between a call 157 /// from this dialect, and a callable region. This method should generate an 158 /// operation that takes 'input' as the only operand, and produces a single 159 /// result of 'resultType'. If a conversion can not be generated, nullptr 160 /// should be returned. 161 Operation *materializeCallConversion(OpBuilder &builder, Value input, 162 Type resultType, 163 Location conversionLoc) const final { 164 // Only allow conversion for i16/i32 types. 165 if (!(resultType.isSignlessInteger(16) || 166 resultType.isSignlessInteger(32)) || 167 !(input.getType().isSignlessInteger(16) || 168 input.getType().isSignlessInteger(32))) 169 return nullptr; 170 return builder.create<TestCastOp>(conversionLoc, resultType, input); 171 } 172 }; 173 } // end anonymous namespace 174 175 //===----------------------------------------------------------------------===// 176 // TestDialect 177 //===----------------------------------------------------------------------===// 178 179 static void testSideEffectOpGetEffect( 180 Operation *op, 181 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects); 182 183 // This is the implementation of a dialect fallback for `TestEffectOpInterface`. 
184 struct TestOpEffectInterfaceFallback 185 : public TestEffectOpInterface::FallbackModel< 186 TestOpEffectInterfaceFallback> { 187 static bool classof(Operation *op) { 188 bool isSupportedOp = 189 op->getName().getStringRef() == "test.unregistered_side_effect_op"; 190 assert(isSupportedOp && "Unexpected dispatch"); 191 return isSupportedOp; 192 } 193 194 void 195 getEffects(Operation *op, 196 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 197 &effects) const { 198 testSideEffectOpGetEffect(op, effects); 199 } 200 }; 201 202 void TestDialect::initialize() { 203 registerAttributes(); 204 registerTypes(); 205 addOperations< 206 #define GET_OP_LIST 207 #include "TestOps.cpp.inc" 208 >(); 209 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 210 TestInlinerInterface>(); 211 allowUnknownOperations(); 212 213 // Instantiate our fallback op interface that we'll use on specific 214 // unregistered op. 215 fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback; 216 } 217 TestDialect::~TestDialect() { 218 delete static_cast<TestOpEffectInterfaceFallback *>( 219 fallbackEffectOpInterfaces); 220 } 221 222 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 223 Type type, Location loc) { 224 return builder.create<TestOpConstant>(loc, type, value); 225 } 226 227 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, 228 OperationName opName) { 229 if (opName.getIdentifier() == "test.unregistered_side_effect_op" && 230 typeID == TypeID::get<TestEffectOpInterface>()) 231 return fallbackEffectOpInterfaces; 232 return nullptr; 233 } 234 235 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 236 NamedAttribute namedAttr) { 237 if (namedAttr.first == "test.invalid_attr") 238 return op->emitError() << "invalid to use 'test.invalid_attr'"; 239 return success(); 240 } 241 242 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 243 unsigned regionIndex, 244 unsigned argIndex, 245 
NamedAttribute namedAttr) { 246 if (namedAttr.first == "test.invalid_attr") 247 return op->emitError() << "invalid to use 'test.invalid_attr'"; 248 return success(); 249 } 250 251 LogicalResult 252 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 253 unsigned resultIndex, 254 NamedAttribute namedAttr) { 255 if (namedAttr.first == "test.invalid_attr") 256 return op->emitError() << "invalid to use 'test.invalid_attr'"; 257 return success(); 258 } 259 260 Optional<Dialect::ParseOpHook> 261 TestDialect::getParseOperationHook(StringRef opName) const { 262 if (opName == "test.dialect_custom_printer") { 263 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 264 return parser.parseKeyword("custom_format"); 265 }}; 266 } 267 return None; 268 } 269 270 LogicalResult TestDialect::printOperation(Operation *op, 271 OpAsmPrinter &printer) const { 272 StringRef opName = op->getName().getStringRef(); 273 if (opName == "test.dialect_custom_printer") { 274 printer.getStream() << opName << " custom_format"; 275 return success(); 276 } 277 return failure(); 278 } 279 280 //===----------------------------------------------------------------------===// 281 // TestBranchOp 282 //===----------------------------------------------------------------------===// 283 284 Optional<MutableOperandRange> 285 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 286 assert(index == 0 && "invalid successor index"); 287 return targetOperandsMutable(); 288 } 289 290 //===----------------------------------------------------------------------===// 291 // TestDialectCanonicalizerOp 292 //===----------------------------------------------------------------------===// 293 294 static LogicalResult 295 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, 296 PatternRewriter &rewriter) { 297 rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(), 298 rewriter.getI32IntegerAttr(42)); 299 return success(); 300 } 301 302 void 
TestDialect::getCanonicalizationPatterns( 303 RewritePatternSet &results) const { 304 results.add(&dialectCanonicalizationPattern); 305 } 306 307 //===----------------------------------------------------------------------===// 308 // TestFoldToCallOp 309 //===----------------------------------------------------------------------===// 310 311 namespace { 312 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 313 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 314 315 LogicalResult matchAndRewrite(FoldToCallOp op, 316 PatternRewriter &rewriter) const override { 317 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(), 318 ValueRange()); 319 return success(); 320 } 321 }; 322 } // end anonymous namespace 323 324 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 325 MLIRContext *context) { 326 results.add<FoldToCallOpPattern>(context); 327 } 328 329 //===----------------------------------------------------------------------===// 330 // Test Format* operations 331 //===----------------------------------------------------------------------===// 332 333 //===----------------------------------------------------------------------===// 334 // Parsing 335 336 static ParseResult parseCustomDirectiveOperands( 337 OpAsmParser &parser, OpAsmParser::OperandType &operand, 338 Optional<OpAsmParser::OperandType> &optOperand, 339 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 340 if (parser.parseOperand(operand)) 341 return failure(); 342 if (succeeded(parser.parseOptionalComma())) { 343 optOperand.emplace(); 344 if (parser.parseOperand(*optOperand)) 345 return failure(); 346 } 347 if (parser.parseArrow() || parser.parseLParen() || 348 parser.parseOperandList(varOperands) || parser.parseRParen()) 349 return failure(); 350 return success(); 351 } 352 static ParseResult 353 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 354 Type &optOperandType, 355 SmallVectorImpl<Type> &varOperandTypes) 
{ 356 if (parser.parseColon()) 357 return failure(); 358 359 if (parser.parseType(operandType)) 360 return failure(); 361 if (succeeded(parser.parseOptionalComma())) { 362 if (parser.parseType(optOperandType)) 363 return failure(); 364 } 365 if (parser.parseArrow() || parser.parseLParen() || 366 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 367 return failure(); 368 return success(); 369 } 370 static ParseResult 371 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 372 Type optOperandType, 373 const SmallVectorImpl<Type> &varOperandTypes) { 374 if (parser.parseKeyword("type_refs_capture")) 375 return failure(); 376 377 Type operandType2, optOperandType2; 378 SmallVector<Type, 1> varOperandTypes2; 379 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 380 varOperandTypes2)) 381 return failure(); 382 383 if (operandType != operandType2 || optOperandType != optOperandType2 || 384 varOperandTypes != varOperandTypes2) 385 return failure(); 386 387 return success(); 388 } 389 static ParseResult parseCustomDirectiveOperandsAndTypes( 390 OpAsmParser &parser, OpAsmParser::OperandType &operand, 391 Optional<OpAsmParser::OperandType> &optOperand, 392 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 393 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 394 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 395 parseCustomDirectiveResults(parser, operandType, optOperandType, 396 varOperandTypes)) 397 return failure(); 398 return success(); 399 } 400 static ParseResult parseCustomDirectiveRegions( 401 OpAsmParser &parser, Region ®ion, 402 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 403 if (parser.parseRegion(region)) 404 return failure(); 405 if (failed(parser.parseOptionalComma())) 406 return success(); 407 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 408 if (parser.parseRegion(*varRegion)) 409 return failure(); 410 
varRegions.emplace_back(std::move(varRegion)); 411 return success(); 412 } 413 static ParseResult 414 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 415 SmallVectorImpl<Block *> &varSuccessors) { 416 if (parser.parseSuccessor(successor)) 417 return failure(); 418 if (failed(parser.parseOptionalComma())) 419 return success(); 420 Block *varSuccessor; 421 if (parser.parseSuccessor(varSuccessor)) 422 return failure(); 423 varSuccessors.append(2, varSuccessor); 424 return success(); 425 } 426 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 427 IntegerAttr &attr, 428 IntegerAttr &optAttr) { 429 if (parser.parseAttribute(attr)) 430 return failure(); 431 if (succeeded(parser.parseOptionalComma())) { 432 if (parser.parseAttribute(optAttr)) 433 return failure(); 434 } 435 return success(); 436 } 437 438 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 439 NamedAttrList &attrs) { 440 return parser.parseOptionalAttrDict(attrs); 441 } 442 static ParseResult parseCustomDirectiveOptionalOperandRef( 443 OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) { 444 int64_t operandCount = 0; 445 if (parser.parseInteger(operandCount)) 446 return failure(); 447 bool expectedOptionalOperand = operandCount == 0; 448 return success(expectedOptionalOperand != optOperand.hasValue()); 449 } 450 451 //===----------------------------------------------------------------------===// 452 // Printing 453 454 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 455 Value operand, Value optOperand, 456 OperandRange varOperands) { 457 printer << operand; 458 if (optOperand) 459 printer << ", " << optOperand; 460 printer << " -> (" << varOperands << ")"; 461 } 462 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 463 Type operandType, Type optOperandType, 464 TypeRange varOperandTypes) { 465 printer << " : " << operandType; 466 if (optOperandType) 467 printer << ", " << 
optOperandType; 468 printer << " -> (" << varOperandTypes << ")"; 469 } 470 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 471 Operation *op, Type operandType, 472 Type optOperandType, 473 TypeRange varOperandTypes) { 474 printer << " type_refs_capture "; 475 printCustomDirectiveResults(printer, op, operandType, optOperandType, 476 varOperandTypes); 477 } 478 static void printCustomDirectiveOperandsAndTypes( 479 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 480 OperandRange varOperands, Type operandType, Type optOperandType, 481 TypeRange varOperandTypes) { 482 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 483 printCustomDirectiveResults(printer, op, operandType, optOperandType, 484 varOperandTypes); 485 } 486 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 487 Region ®ion, 488 MutableArrayRef<Region> varRegions) { 489 printer.printRegion(region); 490 if (!varRegions.empty()) { 491 printer << ", "; 492 for (Region ®ion : varRegions) 493 printer.printRegion(region); 494 } 495 } 496 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 497 Block *successor, 498 SuccessorRange varSuccessors) { 499 printer << successor; 500 if (!varSuccessors.empty()) 501 printer << ", " << varSuccessors.front(); 502 } 503 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 504 Attribute attribute, 505 Attribute optAttribute) { 506 printer << attribute; 507 if (optAttribute) 508 printer << ", " << optAttribute; 509 } 510 511 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 512 DictionaryAttr attrs) { 513 printer.printOptionalAttrDict(attrs.getValue()); 514 } 515 516 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 517 Operation *op, 518 Value optOperand) { 519 printer << (optOperand ? 
"1" : "0"); 520 } 521 522 //===----------------------------------------------------------------------===// 523 // Test IsolatedRegionOp - parse passthrough region arguments. 524 //===----------------------------------------------------------------------===// 525 526 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 527 OperationState &result) { 528 OpAsmParser::OperandType argInfo; 529 Type argType = parser.getBuilder().getIndexType(); 530 531 // Parse the input operand. 532 if (parser.parseOperand(argInfo) || 533 parser.resolveOperand(argInfo, argType, result.operands)) 534 return failure(); 535 536 // Parse the body region, and reuse the operand info as the argument info. 537 Region *body = result.addRegion(); 538 return parser.parseRegion(*body, argInfo, argType, 539 /*enableNameShadowing=*/true); 540 } 541 542 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 543 p << "test.isolated_region "; 544 p.printOperand(op.getOperand()); 545 p.shadowRegionArgs(op.region(), op.getOperand()); 546 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 547 } 548 549 //===----------------------------------------------------------------------===// 550 // Test SSACFGRegionOp 551 //===----------------------------------------------------------------------===// 552 553 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 554 return RegionKind::SSACFG; 555 } 556 557 //===----------------------------------------------------------------------===// 558 // Test GraphRegionOp 559 //===----------------------------------------------------------------------===// 560 561 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 562 OperationState &result) { 563 // Parse the body region, and reuse the operand info as the argument info. 
564 Region *body = result.addRegion(); 565 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 566 } 567 568 static void print(OpAsmPrinter &p, GraphRegionOp op) { 569 p << "test.graph_region "; 570 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 571 } 572 573 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 574 return RegionKind::Graph; 575 } 576 577 //===----------------------------------------------------------------------===// 578 // Test AffineScopeOp 579 //===----------------------------------------------------------------------===// 580 581 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 582 OperationState &result) { 583 // Parse the body region, and reuse the operand info as the argument info. 584 Region *body = result.addRegion(); 585 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 586 } 587 588 static void print(OpAsmPrinter &p, AffineScopeOp op) { 589 p << "test.affine_scope "; 590 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 591 } 592 593 //===----------------------------------------------------------------------===// 594 // Test parser. 
595 //===----------------------------------------------------------------------===// 596 597 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, 598 OperationState &result) { 599 if (parser.parseOptionalColon()) 600 return success(); 601 uint64_t numResults; 602 if (parser.parseInteger(numResults)) 603 return failure(); 604 605 IndexType type = parser.getBuilder().getIndexType(); 606 for (unsigned i = 0; i < numResults; ++i) 607 result.addTypes(type); 608 return success(); 609 } 610 611 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { 612 p << ParseIntegerLiteralOp::getOperationName(); 613 if (unsigned numResults = op->getNumResults()) 614 p << " : " << numResults; 615 } 616 617 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, 618 OperationState &result) { 619 StringRef keyword; 620 if (parser.parseKeyword(&keyword)) 621 return failure(); 622 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 623 return success(); 624 } 625 626 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { 627 p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword(); 628 } 629 630 //===----------------------------------------------------------------------===// 631 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 632 633 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 634 OperationState &result) { 635 if (parser.parseKeyword("wraps")) 636 return failure(); 637 638 // Parse the wrapped op in a region 639 Region &body = *result.addRegion(); 640 body.push_back(new Block); 641 Block &block = body.back(); 642 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 643 if (!wrapped_op) 644 return failure(); 645 646 // Create a return terminator in the inner region, pass as operand to the 647 // terminator the returned values from the wrapped operation. 
648 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 649 OpBuilder builder(parser.getBuilder().getContext()); 650 builder.setInsertionPointToEnd(&block); 651 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 652 653 // Get the results type for the wrapping op from the terminator operands. 654 Operation &return_op = body.back().back(); 655 result.types.append(return_op.operand_type_begin(), 656 return_op.operand_type_end()); 657 658 // Use the location of the wrapped op for the "test.wrapping_region" op. 659 result.location = wrapped_op->getLoc(); 660 661 return success(); 662 } 663 664 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 665 p << op.getOperationName() << " wraps "; 666 p.printGenericOp(&op.region().front().front()); 667 } 668 669 //===----------------------------------------------------------------------===// 670 // Test PolyForOp - parse list of region arguments. 671 //===----------------------------------------------------------------------===// 672 673 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 674 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 675 // Parse list of region arguments without a delimiter. 676 if (parser.parseRegionArgumentList(ivsInfo)) 677 return failure(); 678 679 // Parse the body region. 680 Region *body = result.addRegion(); 681 auto &builder = parser.getBuilder(); 682 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 683 return parser.parseRegion(*body, ivsInfo, argTypes); 684 } 685 686 //===----------------------------------------------------------------------===// 687 // Test removing op with inner ops. 
688 //===----------------------------------------------------------------------===// 689 690 namespace { 691 struct TestRemoveOpWithInnerOps 692 : public OpRewritePattern<TestOpWithRegionPattern> { 693 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 694 695 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 696 PatternRewriter &rewriter) const override { 697 rewriter.eraseOp(op); 698 return success(); 699 } 700 }; 701 } // end anonymous namespace 702 703 void TestOpWithRegionPattern::getCanonicalizationPatterns( 704 RewritePatternSet &results, MLIRContext *context) { 705 results.add<TestRemoveOpWithInnerOps>(context); 706 } 707 708 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 709 return operand(); 710 } 711 712 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 713 return getValue(); 714 } 715 716 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 717 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 718 for (Value input : this->operands()) { 719 results.push_back(input); 720 } 721 return success(); 722 } 723 724 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 725 assert(operands.size() == 1); 726 if (operands.front()) { 727 (*this)->setAttr("attr", operands.front()); 728 return getResult(); 729 } 730 return {}; 731 } 732 733 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 734 return getOperand(); 735 } 736 737 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 738 MLIRContext *, Optional<Location> location, ValueRange operands, 739 DictionaryAttr attributes, RegionRange regions, 740 SmallVectorImpl<Type> &inferredReturnTypes) { 741 if (operands[0].getType() != operands[1].getType()) { 742 return emitOptionalError(location, "operand type mismatch ", 743 operands[0].getType(), " vs ", 744 operands[1].getType()); 745 } 746 inferredReturnTypes.assign({operands[0].getType()}); 747 return success(); 748 } 749 750 
/// Infer a single result whose shape is `[d]` with element type i17, where
/// `d` is the leading dimension of the first operand (dynamic when the operand
/// is unranked or 0-D).
LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Dereferencing the begin of an empty type range would be undefined.
  if (operands.empty())
    return emitOptionalError(location, "expected at least one operand");

  // Create return type consisting of the first dimension of the first operand.
  auto operandType = *operands.getTypes().begin();
  auto sval = operandType.dyn_cast<ShapedType>();
  if (!sval) {
    return emitOptionalError(location, "only shaped type operands allowed");
  }
  // A 0-D shape is empty, so `front()` must not be called on it; treat it
  // like the unranked case.
  int64_t dim = (sval.hasRank() && sval.getRank() > 0)
                    ? sval.getShape().front()
                    : ShapedType::kDynamicSize;
  auto type = IntegerType::get(context, 17);
  inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
  return success();
}

/// Reify the result shape as the size of dimension 0 of the first operand.
LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
    OpBuilder &builder, ValueRange operands,
    llvm::SmallVectorImpl<Value> &shapes) {
  shapes = SmallVector<Value, 1>{
      builder.createOrFold<memref::DimOp>(getLoc(), operands.front(), 0)};
  return success();
}

/// Reify per-dimension result shapes: the first entry is built from
/// `operand2`'s dimensions and the second from `operand1`'s.
/// NOTE(review): the operand order is swapped relative to the results; this
/// appears intentional for the test op — confirm against its definition.
LogicalResult
OpWithResultShapePerDimInterfaceOp::reifyReturnTypeShapesPerResultDim(
    OpBuilder &builder,
    llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
  SmallVector<Value> operand1Shape, operand2Shape;
  Location loc = getLoc();
  for (auto i :
       llvm::seq<int>(0, operand1().getType().cast<ShapedType>().getRank())) {
    operand1Shape.push_back(builder.create<memref::DimOp>(loc, operand1(), i));
  }
  for (auto i :
       llvm::seq<int>(0, operand2().getType().cast<ShapedType>().getRank())) {
    operand2Shape.push_back(builder.create<memref::DimOp>(loc, operand2(), i));
  }
  shapes.emplace_back(std::move(operand2Shape));
  shapes.emplace_back(std::move(operand1Shape));
  return success();
}

//===----------------------------------------------------------------------===//
// Test SideEffect interfaces
//===----------------------------------------------------------------------===// 797 798 namespace { 799 /// A test resource for side effects. 800 struct TestResource : public SideEffects::Resource::Base<TestResource> { 801 StringRef getName() final { return "<Test>"; } 802 }; 803 } // end anonymous namespace 804 805 static void testSideEffectOpGetEffect( 806 Operation *op, 807 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 808 &effects) { 809 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); 810 if (!effectsAttr) 811 return; 812 813 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 814 } 815 816 void SideEffectOp::getEffects( 817 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 818 // Check for an effects attribute on the op instance. 819 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 820 if (!effectsAttr) 821 return; 822 823 // If there is one, it is an array of dictionary attributes that hold 824 // information on the effects of this operation. 825 for (Attribute element : effectsAttr) { 826 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 827 828 // Get the specific memory effect. 829 MemoryEffects::Effect *effect = 830 StringSwitch<MemoryEffects::Effect *>( 831 effectElement.get("effect").cast<StringAttr>().getValue()) 832 .Case("allocate", MemoryEffects::Allocate::get()) 833 .Case("free", MemoryEffects::Free::get()) 834 .Case("read", MemoryEffects::Read::get()) 835 .Case("write", MemoryEffects::Write::get()); 836 837 // Check for a non-default resource to use. 838 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 839 if (effectElement.get("test_resource")) 840 resource = TestResource::get(); 841 842 // Check for a result to affect. 
843 if (effectElement.get("on_result")) 844 effects.emplace_back(effect, getResult(), resource); 845 else if (Attribute ref = effectElement.get("on_reference")) 846 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 847 else 848 effects.emplace_back(effect, resource); 849 } 850 } 851 852 void SideEffectOp::getEffects( 853 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 854 testSideEffectOpGetEffect(getOperation(), effects); 855 } 856 857 //===----------------------------------------------------------------------===// 858 // StringAttrPrettyNameOp 859 //===----------------------------------------------------------------------===// 860 861 // This op has fancy handling of its SSA result name. 862 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 863 OperationState &result) { 864 // Add the result types. 865 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 866 result.addTypes(parser.getBuilder().getIntegerType(32)); 867 868 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 869 return failure(); 870 871 // If the attribute dictionary contains no 'names' attribute, infer it from 872 // the SSA name (if specified). 873 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 874 return attr.first == "names"; 875 }); 876 877 // If there was no name specified, check to see if there was a useful name 878 // specified in the asm file. 
879 if (hadNames || parser.getNumResults() == 0) 880 return success(); 881 882 SmallVector<StringRef, 4> names; 883 auto *context = result.getContext(); 884 885 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 886 auto resultName = parser.getResultName(i); 887 StringRef nameStr; 888 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 889 nameStr = resultName.first; 890 891 names.push_back(nameStr); 892 } 893 894 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 895 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 896 return success(); 897 } 898 899 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 900 p << "test.string_attr_pretty_name"; 901 902 // Note that we only need to print the "name" attribute if the asmprinter 903 // result name disagrees with it. This can happen in strange cases, e.g. 904 // when there are conflicts. 905 bool namesDisagree = op.names().size() != op.getNumResults(); 906 907 SmallString<32> resultNameStr; 908 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 909 resultNameStr.clear(); 910 llvm::raw_svector_ostream tmpStream(resultNameStr); 911 p.printOperand(op.getResult(i), tmpStream); 912 913 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 914 if (!expectedName || 915 tmpStream.str().drop_front() != expectedName.getValue()) { 916 namesDisagree = true; 917 } 918 } 919 920 if (namesDisagree) 921 p.printOptionalAttrDictWithKeyword(op->getAttrs()); 922 else 923 p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"}); 924 } 925 926 // We set the SSA name in the asm syntax to the contents of the name 927 // attribute. 
928 void StringAttrPrettyNameOp::getAsmResultNames( 929 function_ref<void(Value, StringRef)> setNameFn) { 930 931 auto value = names(); 932 for (size_t i = 0, e = value.size(); i != e; ++i) 933 if (auto str = value[i].dyn_cast<StringAttr>()) 934 if (!str.getValue().empty()) 935 setNameFn(getResult(i), str.getValue()); 936 } 937 938 //===----------------------------------------------------------------------===// 939 // RegionIfOp 940 //===----------------------------------------------------------------------===// 941 942 static void print(OpAsmPrinter &p, RegionIfOp op) { 943 p << RegionIfOp::getOperationName() << " "; 944 p.printOperands(op.getOperands()); 945 p << ": " << op.getOperandTypes(); 946 p.printArrowTypeList(op.getResultTypes()); 947 p << " then"; 948 p.printRegion(op.thenRegion(), 949 /*printEntryBlockArgs=*/true, 950 /*printBlockTerminators=*/true); 951 p << " else"; 952 p.printRegion(op.elseRegion(), 953 /*printEntryBlockArgs=*/true, 954 /*printBlockTerminators=*/true); 955 p << " join"; 956 p.printRegion(op.joinRegion(), 957 /*printEntryBlockArgs=*/true, 958 /*printBlockTerminators=*/true); 959 } 960 961 static ParseResult parseRegionIfOp(OpAsmParser &parser, 962 OperationState &result) { 963 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 964 SmallVector<Type, 2> operandTypes; 965 966 result.regions.reserve(3); 967 Region *thenRegion = result.addRegion(); 968 Region *elseRegion = result.addRegion(); 969 Region *joinRegion = result.addRegion(); 970 971 // Parse operand, type and arrow type lists. 972 if (parser.parseOperandList(operandInfos) || 973 parser.parseColonTypeList(operandTypes) || 974 parser.parseArrowTypeList(result.types)) 975 return failure(); 976 977 // Parse all attached regions. 
978 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 979 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 980 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 981 return failure(); 982 983 return parser.resolveOperands(operandInfos, operandTypes, 984 parser.getCurrentLocation(), result.operands); 985 } 986 987 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 988 assert(index < 2 && "invalid region index"); 989 return getOperands(); 990 } 991 992 void RegionIfOp::getSuccessorRegions( 993 Optional<unsigned> index, ArrayRef<Attribute> operands, 994 SmallVectorImpl<RegionSuccessor> ®ions) { 995 // We always branch to the join region. 996 if (index.hasValue()) { 997 if (index.getValue() < 2) 998 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 999 else 1000 regions.push_back(RegionSuccessor(getResults())); 1001 return; 1002 } 1003 1004 // The then and else regions are the entry regions of this op. 1005 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 1006 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 1007 } 1008 1009 #include "TestOpEnums.cpp.inc" 1010 #include "TestOpInterfaces.cpp.inc" 1011 #include "TestOpStructs.cpp.inc" 1012 #include "TestTypeInterfaces.cpp.inc" 1013 1014 #define GET_OP_CLASSES 1015 #include "TestOps.cpp.inc" 1016