1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestInterfaces.h" 12 #include "TestTypes.h" 13 #include "mlir/Dialect/DLTI/DLTI.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/DialectImplementation.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Reducer/ReductionPatternInterface.h" 21 #include "mlir/Transforms/FoldUtils.h" 22 #include "mlir/Transforms/InliningUtils.h" 23 #include "llvm/ADT/StringSwitch.h" 24 25 // Include this before the using namespace lines below to 26 // test that we don't have namespace dependencies. 27 #include "TestOpsDialect.cpp.inc" 28 29 using namespace mlir; 30 using namespace test; 31 32 void test::registerTestDialect(DialectRegistry ®istry) { 33 registry.insert<TestDialect>(); 34 } 35 36 //===----------------------------------------------------------------------===// 37 // TestDialect Interfaces 38 //===----------------------------------------------------------------------===// 39 40 namespace { 41 42 /// Testing the correctness of some traits. 43 static_assert( 44 llvm::is_detected<OpTrait::has_implicit_terminator_t, 45 SingleBlockImplicitTerminatorOp>::value, 46 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); 47 static_assert(OpTrait::hasSingleBlockImplicitTerminator< 48 SingleBlockImplicitTerminatorOp>::value, 49 "hasSingleBlockImplicitTerminator does not match " 50 "SingleBlockImplicitTerminatorOp"); 51 52 // Test support for interacting with the AsmPrinter. 53 struct TestOpAsmInterface : public OpAsmDialectInterface { 54 using OpAsmDialectInterface::OpAsmDialectInterface; 55 56 AliasResult getAlias(Attribute attr, raw_ostream &os) const final { 57 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 58 if (!strAttr) 59 return AliasResult::NoAlias; 60 61 // Check the contents of the string attribute to see what the test alias 62 // should be named. 63 Optional<StringRef> aliasName = 64 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 65 .Case("alias_test:dot_in_name", StringRef("test.alias")) 66 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 67 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 68 .Case("alias_test:sanitize_conflict_a", 69 StringRef("test_alias_conflict0")) 70 .Case("alias_test:sanitize_conflict_b", 71 StringRef("test_alias_conflict0_")) 72 .Case("alias_test:tensor_encoding", StringRef("test_encoding")) 73 .Default(llvm::None); 74 if (!aliasName) 75 return AliasResult::NoAlias; 76 77 os << *aliasName; 78 return AliasResult::FinalAlias; 79 } 80 81 AliasResult getAlias(Type type, raw_ostream &os) const final { 82 if (auto tupleType = type.dyn_cast<TupleType>()) { 83 if (tupleType.size() > 0 && 84 llvm::all_of(tupleType.getTypes(), [](Type elemType) { 85 return elemType.isa<SimpleAType>(); 86 })) { 87 os << "test_tuple"; 88 return AliasResult::FinalAlias; 89 } 90 } 91 return AliasResult::NoAlias; 92 } 93 94 void getAsmResultNames(Operation *op, 95 OpAsmSetValueNameFn setNameFn) const final { 96 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 97 setNameFn(asmOp, "result"); 98 } 99 100 void getAsmBlockArgumentNames(Block *block, 101 OpAsmSetValueNameFn setNameFn) const final { 102 auto op = block->getParentOp(); 103 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 104 if (!arrayAttr) 105 return; 106 auto args = block->getArguments(); 107 auto e = std::min(arrayAttr.size(), args.size()); 108 for (unsigned i = 0; i < e; ++i) { 109 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 110 setNameFn(args[i], strAttr.getValue()); 111 } 112 } 113 }; 114 115 struct TestDialectFoldInterface : public DialectFoldInterface { 116 using DialectFoldInterface::DialectFoldInterface; 117 118 /// Registered hook to check if the given region, which is attached to an 119 /// operation that is *not* isolated from above, should be used when 120 /// materializing constants. 121 bool shouldMaterializeInto(Region *region) const final { 122 // If this is a one region operation, then insert into it. 123 return isa<OneRegionOp>(region->getParentOp()); 124 } 125 }; 126 127 /// This class defines the interface for handling inlining with standard 128 /// operations. 129 struct TestInlinerInterface : public DialectInlinerInterface { 130 using DialectInlinerInterface::DialectInlinerInterface; 131 132 //===--------------------------------------------------------------------===// 133 // Analysis Hooks 134 //===--------------------------------------------------------------------===// 135 136 bool isLegalToInline(Operation *call, Operation *callable, 137 bool wouldBeCloned) const final { 138 // Don't allow inlining calls that are marked `noinline`. 139 return !call->hasAttr("noinline"); 140 } 141 bool isLegalToInline(Region *, Region *, bool, 142 BlockAndValueMapping &) const final { 143 // Inlining into test dialect regions is legal. 144 return true; 145 } 146 bool isLegalToInline(Operation *, Region *, bool, 147 BlockAndValueMapping &) const final { 148 return true; 149 } 150 151 bool shouldAnalyzeRecursively(Operation *op) const final { 152 // Analyze recursively if this is not a functional region operation, it 153 // froms a separate functional scope. 154 return !isa<FunctionalRegionOp>(op); 155 } 156 157 //===--------------------------------------------------------------------===// 158 // Transformation Hooks 159 //===--------------------------------------------------------------------===// 160 161 /// Handle the given inlined terminator by replacing it with a new operation 162 /// as necessary. 163 void handleTerminator(Operation *op, 164 ArrayRef<Value> valuesToRepl) const final { 165 // Only handle "test.return" here. 166 auto returnOp = dyn_cast<TestReturnOp>(op); 167 if (!returnOp) 168 return; 169 170 // Replace the values directly with the return operands. 171 assert(returnOp.getNumOperands() == valuesToRepl.size()); 172 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 173 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 174 } 175 176 /// Attempt to materialize a conversion for a type mismatch between a call 177 /// from this dialect, and a callable region. This method should generate an 178 /// operation that takes 'input' as the only operand, and produces a single 179 /// result of 'resultType'. If a conversion can not be generated, nullptr 180 /// should be returned. 181 Operation *materializeCallConversion(OpBuilder &builder, Value input, 182 Type resultType, 183 Location conversionLoc) const final { 184 // Only allow conversion for i16/i32 types. 185 if (!(resultType.isSignlessInteger(16) || 186 resultType.isSignlessInteger(32)) || 187 !(input.getType().isSignlessInteger(16) || 188 input.getType().isSignlessInteger(32))) 189 return nullptr; 190 return builder.create<TestCastOp>(conversionLoc, resultType, input); 191 } 192 193 void processInlinedCallBlocks( 194 Operation *call, 195 iterator_range<Region::iterator> inlinedBlocks) const final { 196 if (!isa<ConversionCallOp>(call)) 197 return; 198 199 // Set attributed on all ops in the inlined blocks. 200 for (Block &block : inlinedBlocks) { 201 block.walk([&](Operation *op) { 202 op->setAttr("inlined_conversion", UnitAttr::get(call->getContext())); 203 }); 204 } 205 } 206 }; 207 208 struct TestReductionPatternInterface : public DialectReductionPatternInterface { 209 public: 210 TestReductionPatternInterface(Dialect *dialect) 211 : DialectReductionPatternInterface(dialect) {} 212 213 virtual void 214 populateReductionPatterns(RewritePatternSet &patterns) const final { 215 populateTestReductionPatterns(patterns); 216 } 217 }; 218 219 } // end anonymous namespace 220 221 //===----------------------------------------------------------------------===// 222 // TestDialect 223 //===----------------------------------------------------------------------===// 224 225 static void testSideEffectOpGetEffect( 226 Operation *op, 227 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects); 228 229 // This is the implementation of a dialect fallback for `TestEffectOpInterface`. 230 struct TestOpEffectInterfaceFallback 231 : public TestEffectOpInterface::FallbackModel< 232 TestOpEffectInterfaceFallback> { 233 static bool classof(Operation *op) { 234 bool isSupportedOp = 235 op->getName().getStringRef() == "test.unregistered_side_effect_op"; 236 assert(isSupportedOp && "Unexpected dispatch"); 237 return isSupportedOp; 238 } 239 240 void 241 getEffects(Operation *op, 242 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 243 &effects) const { 244 testSideEffectOpGetEffect(op, effects); 245 } 246 }; 247 248 void TestDialect::initialize() { 249 registerAttributes(); 250 registerTypes(); 251 addOperations< 252 #define GET_OP_LIST 253 #include "TestOps.cpp.inc" 254 >(); 255 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 256 TestInlinerInterface, TestReductionPatternInterface>(); 257 allowUnknownOperations(); 258 259 // Instantiate our fallback op interface that we'll use on specific 260 // unregistered op. 261 fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback; 262 } 263 TestDialect::~TestDialect() { 264 delete static_cast<TestOpEffectInterfaceFallback *>( 265 fallbackEffectOpInterfaces); 266 } 267 268 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 269 Type type, Location loc) { 270 return builder.create<TestOpConstant>(loc, type, value); 271 } 272 273 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, 274 OperationName opName) { 275 if (opName.getIdentifier() == "test.unregistered_side_effect_op" && 276 typeID == TypeID::get<TestEffectOpInterface>()) 277 return fallbackEffectOpInterfaces; 278 return nullptr; 279 } 280 281 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 282 NamedAttribute namedAttr) { 283 if (namedAttr.first == "test.invalid_attr") 284 return op->emitError() << "invalid to use 'test.invalid_attr'"; 285 return success(); 286 } 287 288 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 289 unsigned regionIndex, 290 unsigned argIndex, 291 NamedAttribute namedAttr) { 292 if (namedAttr.first == "test.invalid_attr") 293 return op->emitError() << "invalid to use 'test.invalid_attr'"; 294 return success(); 295 } 296 297 LogicalResult 298 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 299 unsigned resultIndex, 300 NamedAttribute namedAttr) { 301 if (namedAttr.first == "test.invalid_attr") 302 return op->emitError() << "invalid to use 'test.invalid_attr'"; 303 return success(); 304 } 305 306 Optional<Dialect::ParseOpHook> 307 TestDialect::getParseOperationHook(StringRef opName) const { 308 if (opName == "test.dialect_custom_printer") { 309 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 310 return parser.parseKeyword("custom_format"); 311 }}; 312 } 313 return None; 314 } 315 316 LogicalResult TestDialect::printOperation(Operation *op, 317 OpAsmPrinter &printer) const { 318 StringRef opName = op->getName().getStringRef(); 319 if (opName == "test.dialect_custom_printer") { 320 printer.getStream() << opName << " custom_format"; 321 return success(); 322 } 323 return failure(); 324 } 325 326 //===----------------------------------------------------------------------===// 327 // TestBranchOp 328 //===----------------------------------------------------------------------===// 329 330 Optional<MutableOperandRange> 331 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 332 assert(index == 0 && "invalid successor index"); 333 return targetOperandsMutable(); 334 } 335 336 //===----------------------------------------------------------------------===// 337 // TestDialectCanonicalizerOp 338 //===----------------------------------------------------------------------===// 339 340 static LogicalResult 341 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, 342 PatternRewriter &rewriter) { 343 rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(), 344 rewriter.getI32IntegerAttr(42)); 345 return success(); 346 } 347 348 void TestDialect::getCanonicalizationPatterns( 349 RewritePatternSet &results) const { 350 results.add(&dialectCanonicalizationPattern); 351 } 352 353 //===----------------------------------------------------------------------===// 354 // TestFoldToCallOp 355 //===----------------------------------------------------------------------===// 356 357 namespace { 358 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 359 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 360 361 LogicalResult matchAndRewrite(FoldToCallOp op, 362 PatternRewriter &rewriter) const override { 363 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(), 364 ValueRange()); 365 return success(); 366 } 367 }; 368 } // end anonymous namespace 369 370 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 371 MLIRContext *context) { 372 results.add<FoldToCallOpPattern>(context); 373 } 374 375 //===----------------------------------------------------------------------===// 376 // Test Format* operations 377 //===----------------------------------------------------------------------===// 378 379 //===----------------------------------------------------------------------===// 380 // Parsing 381 382 static ParseResult parseCustomDirectiveOperands( 383 OpAsmParser &parser, OpAsmParser::OperandType &operand, 384 Optional<OpAsmParser::OperandType> &optOperand, 385 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 386 if (parser.parseOperand(operand)) 387 return failure(); 388 if (succeeded(parser.parseOptionalComma())) { 389 optOperand.emplace(); 390 if (parser.parseOperand(*optOperand)) 391 return failure(); 392 } 393 if (parser.parseArrow() || parser.parseLParen() || 394 parser.parseOperandList(varOperands) || parser.parseRParen()) 395 return failure(); 396 return success(); 397 } 398 static ParseResult 399 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 400 Type &optOperandType, 401 SmallVectorImpl<Type> &varOperandTypes) { 402 if (parser.parseColon()) 403 return failure(); 404 405 if (parser.parseType(operandType)) 406 return failure(); 407 if (succeeded(parser.parseOptionalComma())) { 408 if (parser.parseType(optOperandType)) 409 return failure(); 410 } 411 if (parser.parseArrow() || parser.parseLParen() || 412 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 413 return failure(); 414 return success(); 415 } 416 static ParseResult 417 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 418 Type optOperandType, 419 const SmallVectorImpl<Type> &varOperandTypes) { 420 if (parser.parseKeyword("type_refs_capture")) 421 return failure(); 422 423 Type operandType2, optOperandType2; 424 SmallVector<Type, 1> varOperandTypes2; 425 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 426 varOperandTypes2)) 427 return failure(); 428 429 if (operandType != operandType2 || optOperandType != optOperandType2 || 430 varOperandTypes != varOperandTypes2) 431 return failure(); 432 433 return success(); 434 } 435 static ParseResult parseCustomDirectiveOperandsAndTypes( 436 OpAsmParser &parser, OpAsmParser::OperandType &operand, 437 Optional<OpAsmParser::OperandType> &optOperand, 438 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 439 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 440 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 441 parseCustomDirectiveResults(parser, operandType, optOperandType, 442 varOperandTypes)) 443 return failure(); 444 return success(); 445 } 446 static ParseResult parseCustomDirectiveRegions( 447 OpAsmParser &parser, Region ®ion, 448 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 449 if (parser.parseRegion(region)) 450 return failure(); 451 if (failed(parser.parseOptionalComma())) 452 return success(); 453 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 454 if (parser.parseRegion(*varRegion)) 455 return failure(); 456 varRegions.emplace_back(std::move(varRegion)); 457 return success(); 458 } 459 static ParseResult 460 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 461 SmallVectorImpl<Block *> &varSuccessors) { 462 if (parser.parseSuccessor(successor)) 463 return failure(); 464 if (failed(parser.parseOptionalComma())) 465 return success(); 466 Block *varSuccessor; 467 if (parser.parseSuccessor(varSuccessor)) 468 return failure(); 469 varSuccessors.append(2, varSuccessor); 470 return success(); 471 } 472 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 473 IntegerAttr &attr, 474 IntegerAttr &optAttr) { 475 if (parser.parseAttribute(attr)) 476 return failure(); 477 if (succeeded(parser.parseOptionalComma())) { 478 if (parser.parseAttribute(optAttr)) 479 return failure(); 480 } 481 return success(); 482 } 483 484 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 485 NamedAttrList &attrs) { 486 return parser.parseOptionalAttrDict(attrs); 487 } 488 static ParseResult parseCustomDirectiveOptionalOperandRef( 489 OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) { 490 int64_t operandCount = 0; 491 if (parser.parseInteger(operandCount)) 492 return failure(); 493 bool expectedOptionalOperand = operandCount == 0; 494 return success(expectedOptionalOperand != optOperand.hasValue()); 495 } 496 497 //===----------------------------------------------------------------------===// 498 // Printing 499 500 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 501 Value operand, Value optOperand, 502 OperandRange varOperands) { 503 printer << operand; 504 if (optOperand) 505 printer << ", " << optOperand; 506 printer << " -> (" << varOperands << ")"; 507 } 508 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 509 Type operandType, Type optOperandType, 510 TypeRange varOperandTypes) { 511 printer << " : " << operandType; 512 if (optOperandType) 513 printer << ", " << optOperandType; 514 printer << " -> (" << varOperandTypes << ")"; 515 } 516 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 517 Operation *op, Type operandType, 518 Type optOperandType, 519 TypeRange varOperandTypes) { 520 printer << " type_refs_capture "; 521 printCustomDirectiveResults(printer, op, operandType, optOperandType, 522 varOperandTypes); 523 } 524 static void printCustomDirectiveOperandsAndTypes( 525 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 526 OperandRange varOperands, Type operandType, Type optOperandType, 527 TypeRange varOperandTypes) { 528 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 529 printCustomDirectiveResults(printer, op, operandType, optOperandType, 530 varOperandTypes); 531 } 532 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 533 Region ®ion, 534 MutableArrayRef<Region> varRegions) { 535 printer.printRegion(region); 536 if (!varRegions.empty()) { 537 printer << ", "; 538 for (Region ®ion : varRegions) 539 printer.printRegion(region); 540 } 541 } 542 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 543 Block *successor, 544 SuccessorRange varSuccessors) { 545 printer << successor; 546 if (!varSuccessors.empty()) 547 printer << ", " << varSuccessors.front(); 548 } 549 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 550 Attribute attribute, 551 Attribute optAttribute) { 552 printer << attribute; 553 if (optAttribute) 554 printer << ", " << optAttribute; 555 } 556 557 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 558 DictionaryAttr attrs) { 559 printer.printOptionalAttrDict(attrs.getValue()); 560 } 561 562 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 563 Operation *op, 564 Value optOperand) { 565 printer << (optOperand ? "1" : "0"); 566 } 567 568 //===----------------------------------------------------------------------===// 569 // Test IsolatedRegionOp - parse passthrough region arguments. 570 //===----------------------------------------------------------------------===// 571 572 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 573 OperationState &result) { 574 OpAsmParser::OperandType argInfo; 575 Type argType = parser.getBuilder().getIndexType(); 576 577 // Parse the input operand. 578 if (parser.parseOperand(argInfo) || 579 parser.resolveOperand(argInfo, argType, result.operands)) 580 return failure(); 581 582 // Parse the body region, and reuse the operand info as the argument info. 583 Region *body = result.addRegion(); 584 return parser.parseRegion(*body, argInfo, argType, 585 /*enableNameShadowing=*/true); 586 } 587 588 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 589 p << "test.isolated_region "; 590 p.printOperand(op.getOperand()); 591 p.shadowRegionArgs(op.region(), op.getOperand()); 592 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 593 } 594 595 //===----------------------------------------------------------------------===// 596 // Test SSACFGRegionOp 597 //===----------------------------------------------------------------------===// 598 599 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 600 return RegionKind::SSACFG; 601 } 602 603 //===----------------------------------------------------------------------===// 604 // Test GraphRegionOp 605 //===----------------------------------------------------------------------===// 606 607 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 608 OperationState &result) { 609 // Parse the body region, and reuse the operand info as the argument info. 610 Region *body = result.addRegion(); 611 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 612 } 613 614 static void print(OpAsmPrinter &p, GraphRegionOp op) { 615 p << "test.graph_region "; 616 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 617 } 618 619 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 620 return RegionKind::Graph; 621 } 622 623 //===----------------------------------------------------------------------===// 624 // Test AffineScopeOp 625 //===----------------------------------------------------------------------===// 626 627 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 628 OperationState &result) { 629 // Parse the body region, and reuse the operand info as the argument info. 630 Region *body = result.addRegion(); 631 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 632 } 633 634 static void print(OpAsmPrinter &p, AffineScopeOp op) { 635 p << "test.affine_scope "; 636 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 637 } 638 639 //===----------------------------------------------------------------------===// 640 // Test parser. 641 //===----------------------------------------------------------------------===// 642 643 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, 644 OperationState &result) { 645 if (parser.parseOptionalColon()) 646 return success(); 647 uint64_t numResults; 648 if (parser.parseInteger(numResults)) 649 return failure(); 650 651 IndexType type = parser.getBuilder().getIndexType(); 652 for (unsigned i = 0; i < numResults; ++i) 653 result.addTypes(type); 654 return success(); 655 } 656 657 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { 658 p << ParseIntegerLiteralOp::getOperationName(); 659 if (unsigned numResults = op->getNumResults()) 660 p << " : " << numResults; 661 } 662 663 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, 664 OperationState &result) { 665 StringRef keyword; 666 if (parser.parseKeyword(&keyword)) 667 return failure(); 668 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 669 return success(); 670 } 671 672 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { 673 p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword(); 674 } 675 676 //===----------------------------------------------------------------------===// 677 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 678 679 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 680 OperationState &result) { 681 if (parser.parseKeyword("wraps")) 682 return failure(); 683 684 // Parse the wrapped op in a region 685 Region &body = *result.addRegion(); 686 body.push_back(new Block); 687 Block &block = body.back(); 688 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 689 if (!wrapped_op) 690 return failure(); 691 692 // Create a return terminator in the inner region, pass as operand to the 693 // terminator the returned values from the wrapped operation. 694 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 695 OpBuilder builder(parser.getBuilder().getContext()); 696 builder.setInsertionPointToEnd(&block); 697 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 698 699 // Get the results type for the wrapping op from the terminator operands. 700 Operation &return_op = body.back().back(); 701 result.types.append(return_op.operand_type_begin(), 702 return_op.operand_type_end()); 703 704 // Use the location of the wrapped op for the "test.wrapping_region" op. 705 result.location = wrapped_op->getLoc(); 706 707 return success(); 708 } 709 710 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 711 p << op.getOperationName() << " wraps "; 712 p.printGenericOp(&op.region().front().front()); 713 } 714 715 //===----------------------------------------------------------------------===// 716 // Test PolyForOp - parse list of region arguments. 717 //===----------------------------------------------------------------------===// 718 719 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 720 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 721 // Parse list of region arguments without a delimiter. 722 if (parser.parseRegionArgumentList(ivsInfo)) 723 return failure(); 724 725 // Parse the body region. 726 Region *body = result.addRegion(); 727 auto &builder = parser.getBuilder(); 728 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 729 return parser.parseRegion(*body, ivsInfo, argTypes); 730 } 731 732 //===----------------------------------------------------------------------===// 733 // Test removing op with inner ops. 734 //===----------------------------------------------------------------------===// 735 736 namespace { 737 struct TestRemoveOpWithInnerOps 738 : public OpRewritePattern<TestOpWithRegionPattern> { 739 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 740 741 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } 742 743 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 744 PatternRewriter &rewriter) const override { 745 rewriter.eraseOp(op); 746 return success(); 747 } 748 }; 749 } // end anonymous namespace 750 751 void TestOpWithRegionPattern::getCanonicalizationPatterns( 752 RewritePatternSet &results, MLIRContext *context) { 753 results.add<TestRemoveOpWithInnerOps>(context); 754 } 755 756 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 757 return operand(); 758 } 759 760 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 761 return getValue(); 762 } 763 764 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 765 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 766 for (Value input : this->operands()) { 767 results.push_back(input); 768 } 769 return success(); 770 } 771 772 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 773 assert(operands.size() == 1); 774 if (operands.front()) { 775 (*this)->setAttr("attr", operands.front()); 776 return getResult(); 777 } 778 return {}; 779 } 780 781 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 782 return getOperand(); 783 } 784 785 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 786 MLIRContext *, Optional<Location> location, ValueRange operands, 787 DictionaryAttr attributes, RegionRange regions, 788 SmallVectorImpl<Type> &inferredReturnTypes) { 789 if (operands[0].getType() != operands[1].getType()) { 790 return emitOptionalError(location, "operand type mismatch ", 791 operands[0].getType(), " vs ", 792 operands[1].getType()); 793 } 794 inferredReturnTypes.assign({operands[0].getType()}); 795 return success(); 796 } 797 798 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 799 MLIRContext *context, Optional<Location> location, ValueShapeRange operands, 800 DictionaryAttr attributes, RegionRange regions, 801 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 802 // Create return type consisting of the last element of the first operand. 803 auto operandType = operands.front().getType(); 804 auto sval = operandType.dyn_cast<ShapedType>(); 805 if (!sval) { 806 return emitOptionalError(location, "only shaped type operands allowed"); 807 } 808 int64_t dim = 809 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 810 auto type = IntegerType::get(context, 17); 811 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 812 return success(); 813 } 814 815 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 816 OpBuilder &builder, ValueRange operands, 817 llvm::SmallVectorImpl<Value> &shapes) { 818 shapes = SmallVector<Value, 1>{ 819 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; 820 return success(); 821 } 822 823 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( 824 OpBuilder &builder, ValueRange operands, 825 llvm::SmallVectorImpl<Value> &shapes) { 826 Location loc = getLoc(); 827 shapes.reserve(operands.size()); 828 for (Value operand : llvm::reverse(operands)) { 829 auto currShape = llvm::to_vector<4>(llvm::map_range( 830 llvm::seq<int64_t>( 831 0, operand.getType().cast<RankedTensorType>().getRank()), 832 [&](int64_t dim) -> Value { 833 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 834 })); 835 shapes.push_back(builder.create<tensor::FromElementsOp>( 836 getLoc(), builder.getIndexType(), currShape)); 837 } 838 return success(); 839 } 840 841 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( 842 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { 843 Location loc = getLoc(); 844 shapes.reserve(getNumOperands()); 845 for (Value operand : llvm::reverse(getOperands())) { 846 auto currShape = llvm::to_vector<4>(llvm::map_range( 847 llvm::seq<int64_t>( 848 0, operand.getType().cast<RankedTensorType>().getRank()), 849 [&](int64_t dim) -> Value { 850 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 851 })); 852 shapes.emplace_back(std::move(currShape)); 853 } 854 return success(); 855 } 856 857 //===----------------------------------------------------------------------===// 858 // Test SideEffect interfaces 859 //===----------------------------------------------------------------------===// 860 861 namespace { 862 /// A test resource for side effects. 863 struct TestResource : public SideEffects::Resource::Base<TestResource> { 864 StringRef getName() final { return "<Test>"; } 865 }; 866 } // end anonymous namespace 867 868 static void testSideEffectOpGetEffect( 869 Operation *op, 870 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 871 &effects) { 872 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); 873 if (!effectsAttr) 874 return; 875 876 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 877 } 878 879 void SideEffectOp::getEffects( 880 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 881 // Check for an effects attribute on the op instance. 882 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 883 if (!effectsAttr) 884 return; 885 886 // If there is one, it is an array of dictionary attributes that hold 887 // information on the effects of this operation. 888 for (Attribute element : effectsAttr) { 889 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 890 891 // Get the specific memory effect. 892 MemoryEffects::Effect *effect = 893 StringSwitch<MemoryEffects::Effect *>( 894 effectElement.get("effect").cast<StringAttr>().getValue()) 895 .Case("allocate", MemoryEffects::Allocate::get()) 896 .Case("free", MemoryEffects::Free::get()) 897 .Case("read", MemoryEffects::Read::get()) 898 .Case("write", MemoryEffects::Write::get()); 899 900 // Check for a non-default resource to use. 901 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 902 if (effectElement.get("test_resource")) 903 resource = TestResource::get(); 904 905 // Check for a result to affect. 906 if (effectElement.get("on_result")) 907 effects.emplace_back(effect, getResult(), resource); 908 else if (Attribute ref = effectElement.get("on_reference")) 909 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 910 else 911 effects.emplace_back(effect, resource); 912 } 913 } 914 915 void SideEffectOp::getEffects( 916 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 917 testSideEffectOpGetEffect(getOperation(), effects); 918 } 919 920 //===----------------------------------------------------------------------===// 921 // StringAttrPrettyNameOp 922 //===----------------------------------------------------------------------===// 923 924 // This op has fancy handling of its SSA result name. 925 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 926 OperationState &result) { 927 // Add the result types. 928 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 929 result.addTypes(parser.getBuilder().getIntegerType(32)); 930 931 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 932 return failure(); 933 934 // If the attribute dictionary contains no 'names' attribute, infer it from 935 // the SSA name (if specified). 936 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 937 return attr.first == "names"; 938 }); 939 940 // If there was no name specified, check to see if there was a useful name 941 // specified in the asm file. 942 if (hadNames || parser.getNumResults() == 0) 943 return success(); 944 945 SmallVector<StringRef, 4> names; 946 auto *context = result.getContext(); 947 948 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 949 auto resultName = parser.getResultName(i); 950 StringRef nameStr; 951 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 952 nameStr = resultName.first; 953 954 names.push_back(nameStr); 955 } 956 957 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 958 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 959 return success(); 960 } 961 962 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 963 p << "test.string_attr_pretty_name"; 964 965 // Note that we only need to print the "name" attribute if the asmprinter 966 // result name disagrees with it. This can happen in strange cases, e.g. 967 // when there are conflicts. 968 bool namesDisagree = op.names().size() != op.getNumResults(); 969 970 SmallString<32> resultNameStr; 971 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 972 resultNameStr.clear(); 973 llvm::raw_svector_ostream tmpStream(resultNameStr); 974 p.printOperand(op.getResult(i), tmpStream); 975 976 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 977 if (!expectedName || 978 tmpStream.str().drop_front() != expectedName.getValue()) { 979 namesDisagree = true; 980 } 981 } 982 983 if (namesDisagree) 984 p.printOptionalAttrDictWithKeyword(op->getAttrs()); 985 else 986 p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"}); 987 } 988 989 // We set the SSA name in the asm syntax to the contents of the name 990 // attribute. 991 void StringAttrPrettyNameOp::getAsmResultNames( 992 function_ref<void(Value, StringRef)> setNameFn) { 993 994 auto value = names(); 995 for (size_t i = 0, e = value.size(); i != e; ++i) 996 if (auto str = value[i].dyn_cast<StringAttr>()) 997 if (!str.getValue().empty()) 998 setNameFn(getResult(i), str.getValue()); 999 } 1000 1001 //===----------------------------------------------------------------------===// 1002 // RegionIfOp 1003 //===----------------------------------------------------------------------===// 1004 1005 static void print(OpAsmPrinter &p, RegionIfOp op) { 1006 p << RegionIfOp::getOperationName() << " "; 1007 p.printOperands(op.getOperands()); 1008 p << ": " << op.getOperandTypes(); 1009 p.printArrowTypeList(op.getResultTypes()); 1010 p << " then"; 1011 p.printRegion(op.thenRegion(), 1012 /*printEntryBlockArgs=*/true, 1013 /*printBlockTerminators=*/true); 1014 p << " else"; 1015 p.printRegion(op.elseRegion(), 1016 /*printEntryBlockArgs=*/true, 1017 /*printBlockTerminators=*/true); 1018 p << " join"; 1019 p.printRegion(op.joinRegion(), 1020 /*printEntryBlockArgs=*/true, 1021 /*printBlockTerminators=*/true); 1022 } 1023 1024 static ParseResult parseRegionIfOp(OpAsmParser &parser, 1025 OperationState &result) { 1026 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 1027 SmallVector<Type, 2> operandTypes; 1028 1029 result.regions.reserve(3); 1030 Region *thenRegion = result.addRegion(); 1031 Region *elseRegion = result.addRegion(); 1032 Region *joinRegion = result.addRegion(); 1033 1034 // Parse operand, type and arrow type lists. 1035 if (parser.parseOperandList(operandInfos) || 1036 parser.parseColonTypeList(operandTypes) || 1037 parser.parseArrowTypeList(result.types)) 1038 return failure(); 1039 1040 // Parse all attached regions. 1041 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 1042 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 1043 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 1044 return failure(); 1045 1046 return parser.resolveOperands(operandInfos, operandTypes, 1047 parser.getCurrentLocation(), result.operands); 1048 } 1049 1050 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 1051 assert(index < 2 && "invalid region index"); 1052 return getOperands(); 1053 } 1054 1055 void RegionIfOp::getSuccessorRegions( 1056 Optional<unsigned> index, ArrayRef<Attribute> operands, 1057 SmallVectorImpl<RegionSuccessor> ®ions) { 1058 // We always branch to the join region. 1059 if (index.hasValue()) { 1060 if (index.getValue() < 2) 1061 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 1062 else 1063 regions.push_back(RegionSuccessor(getResults())); 1064 return; 1065 } 1066 1067 // The then and else regions are the entry regions of this op. 1068 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 1069 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 1070 } 1071 1072 //===----------------------------------------------------------------------===// 1073 // SingleNoTerminatorCustomAsmOp 1074 //===----------------------------------------------------------------------===// 1075 1076 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser, 1077 OperationState &state) { 1078 Region *body = state.addRegion(); 1079 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) 1080 return failure(); 1081 return success(); 1082 } 1083 1084 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) { 1085 printer << op.getOperationName(); 1086 printer.printRegion( 1087 op.getRegion(), /*printEntryBlockArgs=*/false, 1088 // This op has a single block without terminators. But explicitly mark 1089 // as not printing block terminators for testing. 1090 /*printBlockTerminators=*/false); 1091 } 1092 1093 #include "TestOpEnums.cpp.inc" 1094 #include "TestOpInterfaces.cpp.inc" 1095 #include "TestOpStructs.cpp.inc" 1096 #include "TestTypeInterfaces.cpp.inc" 1097 1098 #define GET_OP_CLASSES 1099 #include "TestOps.cpp.inc" 1100