1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestInterfaces.h" 12 #include "TestTypes.h" 13 #include "mlir/Dialect/DLTI/DLTI.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/DialectImplementation.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Reducer/ReductionPatternInterface.h" 21 #include "mlir/Transforms/FoldUtils.h" 22 #include "mlir/Transforms/InliningUtils.h" 23 #include "llvm/ADT/StringSwitch.h" 24 25 // Include this before the using namespace lines below to 26 // test that we don't have namespace dependencies. 27 #include "TestOpsDialect.cpp.inc" 28 29 using namespace mlir; 30 using namespace test; 31 32 void test::registerTestDialect(DialectRegistry ®istry) { 33 registry.insert<TestDialect>(); 34 } 35 36 //===----------------------------------------------------------------------===// 37 // TestDialect Interfaces 38 //===----------------------------------------------------------------------===// 39 40 namespace { 41 42 /// Testing the correctness of some traits. 43 static_assert( 44 llvm::is_detected<OpTrait::has_implicit_terminator_t, 45 SingleBlockImplicitTerminatorOp>::value, 46 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); 47 static_assert(OpTrait::hasSingleBlockImplicitTerminator< 48 SingleBlockImplicitTerminatorOp>::value, 49 "hasSingleBlockImplicitTerminator does not match " 50 "SingleBlockImplicitTerminatorOp"); 51 52 // Test support for interacting with the AsmPrinter. 53 struct TestOpAsmInterface : public OpAsmDialectInterface { 54 using OpAsmDialectInterface::OpAsmDialectInterface; 55 56 AliasResult getAlias(Attribute attr, raw_ostream &os) const final { 57 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 58 if (!strAttr) 59 return AliasResult::NoAlias; 60 61 // Check the contents of the string attribute to see what the test alias 62 // should be named. 63 Optional<StringRef> aliasName = 64 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 65 .Case("alias_test:dot_in_name", StringRef("test.alias")) 66 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 67 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 68 .Case("alias_test:sanitize_conflict_a", 69 StringRef("test_alias_conflict0")) 70 .Case("alias_test:sanitize_conflict_b", 71 StringRef("test_alias_conflict0_")) 72 .Case("alias_test:tensor_encoding", StringRef("test_encoding")) 73 .Default(llvm::None); 74 if (!aliasName) 75 return AliasResult::NoAlias; 76 77 os << *aliasName; 78 return AliasResult::FinalAlias; 79 } 80 81 AliasResult getAlias(Type type, raw_ostream &os) const final { 82 if (auto tupleType = type.dyn_cast<TupleType>()) { 83 if (tupleType.size() > 0 && 84 llvm::all_of(tupleType.getTypes(), [](Type elemType) { 85 return elemType.isa<SimpleAType>(); 86 })) { 87 os << "test_tuple"; 88 return AliasResult::FinalAlias; 89 } 90 } 91 return AliasResult::NoAlias; 92 } 93 94 void getAsmResultNames(Operation *op, 95 OpAsmSetValueNameFn setNameFn) const final { 96 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 97 setNameFn(asmOp, "result"); 98 } 99 100 void getAsmBlockArgumentNames(Block *block, 101 OpAsmSetValueNameFn setNameFn) const final { 102 auto op = block->getParentOp(); 103 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 104 if (!arrayAttr) 105 return; 106 auto args = block->getArguments(); 107 auto e = std::min(arrayAttr.size(), args.size()); 108 for (unsigned i = 0; i < e; ++i) { 109 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 110 setNameFn(args[i], strAttr.getValue()); 111 } 112 } 113 }; 114 115 struct TestDialectFoldInterface : public DialectFoldInterface { 116 using DialectFoldInterface::DialectFoldInterface; 117 118 /// Registered hook to check if the given region, which is attached to an 119 /// operation that is *not* isolated from above, should be used when 120 /// materializing constants. 121 bool shouldMaterializeInto(Region *region) const final { 122 // If this is a one region operation, then insert into it. 123 return isa<OneRegionOp>(region->getParentOp()); 124 } 125 }; 126 127 /// This class defines the interface for handling inlining with standard 128 /// operations. 129 struct TestInlinerInterface : public DialectInlinerInterface { 130 using DialectInlinerInterface::DialectInlinerInterface; 131 132 //===--------------------------------------------------------------------===// 133 // Analysis Hooks 134 //===--------------------------------------------------------------------===// 135 136 bool isLegalToInline(Operation *call, Operation *callable, 137 bool wouldBeCloned) const final { 138 // Don't allow inlining calls that are marked `noinline`. 139 return !call->hasAttr("noinline"); 140 } 141 bool isLegalToInline(Region *, Region *, bool, 142 BlockAndValueMapping &) const final { 143 // Inlining into test dialect regions is legal. 144 return true; 145 } 146 bool isLegalToInline(Operation *, Region *, bool, 147 BlockAndValueMapping &) const final { 148 return true; 149 } 150 151 bool shouldAnalyzeRecursively(Operation *op) const final { 152 // Analyze recursively if this is not a functional region operation, it 153 // froms a separate functional scope. 154 return !isa<FunctionalRegionOp>(op); 155 } 156 157 //===--------------------------------------------------------------------===// 158 // Transformation Hooks 159 //===--------------------------------------------------------------------===// 160 161 /// Handle the given inlined terminator by replacing it with a new operation 162 /// as necessary. 163 void handleTerminator(Operation *op, 164 ArrayRef<Value> valuesToRepl) const final { 165 // Only handle "test.return" here. 166 auto returnOp = dyn_cast<TestReturnOp>(op); 167 if (!returnOp) 168 return; 169 170 // Replace the values directly with the return operands. 171 assert(returnOp.getNumOperands() == valuesToRepl.size()); 172 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 173 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 174 } 175 176 /// Attempt to materialize a conversion for a type mismatch between a call 177 /// from this dialect, and a callable region. This method should generate an 178 /// operation that takes 'input' as the only operand, and produces a single 179 /// result of 'resultType'. If a conversion can not be generated, nullptr 180 /// should be returned. 181 Operation *materializeCallConversion(OpBuilder &builder, Value input, 182 Type resultType, 183 Location conversionLoc) const final { 184 // Only allow conversion for i16/i32 types. 185 if (!(resultType.isSignlessInteger(16) || 186 resultType.isSignlessInteger(32)) || 187 !(input.getType().isSignlessInteger(16) || 188 input.getType().isSignlessInteger(32))) 189 return nullptr; 190 return builder.create<TestCastOp>(conversionLoc, resultType, input); 191 } 192 193 void processInlinedCallBlocks( 194 Operation *call, 195 iterator_range<Region::iterator> inlinedBlocks) const final { 196 if (!isa<ConversionCallOp>(call)) 197 return; 198 199 // Set attributed on all ops in the inlined blocks. 200 for (Block &block : inlinedBlocks) { 201 block.walk([&](Operation *op) { 202 op->setAttr("inlined_conversion", UnitAttr::get(call->getContext())); 203 }); 204 } 205 } 206 }; 207 208 struct TestReductionPatternInterface : public DialectReductionPatternInterface { 209 public: 210 TestReductionPatternInterface(Dialect *dialect) 211 : DialectReductionPatternInterface(dialect) {} 212 213 virtual void 214 populateReductionPatterns(RewritePatternSet &patterns) const final { 215 populateTestReductionPatterns(patterns); 216 } 217 }; 218 219 } // end anonymous namespace 220 221 //===----------------------------------------------------------------------===// 222 // TestDialect 223 //===----------------------------------------------------------------------===// 224 225 static void testSideEffectOpGetEffect( 226 Operation *op, 227 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects); 228 229 // This is the implementation of a dialect fallback for `TestEffectOpInterface`. 230 struct TestOpEffectInterfaceFallback 231 : public TestEffectOpInterface::FallbackModel< 232 TestOpEffectInterfaceFallback> { 233 static bool classof(Operation *op) { 234 bool isSupportedOp = 235 op->getName().getStringRef() == "test.unregistered_side_effect_op"; 236 assert(isSupportedOp && "Unexpected dispatch"); 237 return isSupportedOp; 238 } 239 240 void 241 getEffects(Operation *op, 242 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 243 &effects) const { 244 testSideEffectOpGetEffect(op, effects); 245 } 246 }; 247 248 void TestDialect::initialize() { 249 registerAttributes(); 250 registerTypes(); 251 addOperations< 252 #define GET_OP_LIST 253 #include "TestOps.cpp.inc" 254 >(); 255 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 256 TestInlinerInterface, TestReductionPatternInterface>(); 257 allowUnknownOperations(); 258 259 // Instantiate our fallback op interface that we'll use on specific 260 // unregistered op. 261 fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback; 262 } 263 TestDialect::~TestDialect() { 264 delete static_cast<TestOpEffectInterfaceFallback *>( 265 fallbackEffectOpInterfaces); 266 } 267 268 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 269 Type type, Location loc) { 270 return builder.create<TestOpConstant>(loc, type, value); 271 } 272 273 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, 274 OperationName opName) { 275 if (opName.getIdentifier() == "test.unregistered_side_effect_op" && 276 typeID == TypeID::get<TestEffectOpInterface>()) 277 return fallbackEffectOpInterfaces; 278 return nullptr; 279 } 280 281 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 282 NamedAttribute namedAttr) { 283 if (namedAttr.first == "test.invalid_attr") 284 return op->emitError() << "invalid to use 'test.invalid_attr'"; 285 return success(); 286 } 287 288 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 289 unsigned regionIndex, 290 unsigned argIndex, 291 NamedAttribute namedAttr) { 292 if (namedAttr.first == "test.invalid_attr") 293 return op->emitError() << "invalid to use 'test.invalid_attr'"; 294 return success(); 295 } 296 297 LogicalResult 298 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 299 unsigned resultIndex, 300 NamedAttribute namedAttr) { 301 if (namedAttr.first == "test.invalid_attr") 302 return op->emitError() << "invalid to use 'test.invalid_attr'"; 303 return success(); 304 } 305 306 Optional<Dialect::ParseOpHook> 307 TestDialect::getParseOperationHook(StringRef opName) const { 308 if (opName == "test.dialect_custom_printer") { 309 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 310 return parser.parseKeyword("custom_format"); 311 }}; 312 } 313 return None; 314 } 315 316 llvm::unique_function<void(Operation *, OpAsmPrinter &)> 317 TestDialect::getOperationPrinter(Operation *op) const { 318 StringRef opName = op->getName().getStringRef(); 319 if (opName == "test.dialect_custom_printer") { 320 return [](Operation *op, OpAsmPrinter &printer) { 321 printer.getStream() << " custom_format"; 322 }; 323 } 324 return {}; 325 } 326 327 //===----------------------------------------------------------------------===// 328 // TestBranchOp 329 //===----------------------------------------------------------------------===// 330 331 Optional<MutableOperandRange> 332 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 333 assert(index == 0 && "invalid successor index"); 334 return targetOperandsMutable(); 335 } 336 337 //===----------------------------------------------------------------------===// 338 // TestDialectCanonicalizerOp 339 //===----------------------------------------------------------------------===// 340 341 static LogicalResult 342 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, 343 PatternRewriter &rewriter) { 344 rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(), 345 rewriter.getI32IntegerAttr(42)); 346 return success(); 347 } 348 349 void TestDialect::getCanonicalizationPatterns( 350 RewritePatternSet &results) const { 351 results.add(&dialectCanonicalizationPattern); 352 } 353 354 //===----------------------------------------------------------------------===// 355 // TestFoldToCallOp 356 //===----------------------------------------------------------------------===// 357 358 namespace { 359 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 360 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 361 362 LogicalResult matchAndRewrite(FoldToCallOp op, 363 PatternRewriter &rewriter) const override { 364 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(), 365 ValueRange()); 366 return success(); 367 } 368 }; 369 } // end anonymous namespace 370 371 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 372 MLIRContext *context) { 373 results.add<FoldToCallOpPattern>(context); 374 } 375 376 //===----------------------------------------------------------------------===// 377 // Test Format* operations 378 //===----------------------------------------------------------------------===// 379 380 //===----------------------------------------------------------------------===// 381 // Parsing 382 383 static ParseResult parseCustomDirectiveOperands( 384 OpAsmParser &parser, OpAsmParser::OperandType &operand, 385 Optional<OpAsmParser::OperandType> &optOperand, 386 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 387 if (parser.parseOperand(operand)) 388 return failure(); 389 if (succeeded(parser.parseOptionalComma())) { 390 optOperand.emplace(); 391 if (parser.parseOperand(*optOperand)) 392 return failure(); 393 } 394 if (parser.parseArrow() || parser.parseLParen() || 395 parser.parseOperandList(varOperands) || parser.parseRParen()) 396 return failure(); 397 return success(); 398 } 399 static ParseResult 400 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 401 Type &optOperandType, 402 SmallVectorImpl<Type> &varOperandTypes) { 403 if (parser.parseColon()) 404 return failure(); 405 406 if (parser.parseType(operandType)) 407 return failure(); 408 if (succeeded(parser.parseOptionalComma())) { 409 if (parser.parseType(optOperandType)) 410 return failure(); 411 } 412 if (parser.parseArrow() || parser.parseLParen() || 413 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 414 return failure(); 415 return success(); 416 } 417 static ParseResult 418 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 419 Type optOperandType, 420 const SmallVectorImpl<Type> &varOperandTypes) { 421 if (parser.parseKeyword("type_refs_capture")) 422 return failure(); 423 424 Type operandType2, optOperandType2; 425 SmallVector<Type, 1> varOperandTypes2; 426 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 427 varOperandTypes2)) 428 return failure(); 429 430 if (operandType != operandType2 || optOperandType != optOperandType2 || 431 varOperandTypes != varOperandTypes2) 432 return failure(); 433 434 return success(); 435 } 436 static ParseResult parseCustomDirectiveOperandsAndTypes( 437 OpAsmParser &parser, OpAsmParser::OperandType &operand, 438 Optional<OpAsmParser::OperandType> &optOperand, 439 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 440 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 441 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 442 parseCustomDirectiveResults(parser, operandType, optOperandType, 443 varOperandTypes)) 444 return failure(); 445 return success(); 446 } 447 static ParseResult parseCustomDirectiveRegions( 448 OpAsmParser &parser, Region ®ion, 449 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 450 if (parser.parseRegion(region)) 451 return failure(); 452 if (failed(parser.parseOptionalComma())) 453 return success(); 454 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 455 if (parser.parseRegion(*varRegion)) 456 return failure(); 457 varRegions.emplace_back(std::move(varRegion)); 458 return success(); 459 } 460 static ParseResult 461 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 462 SmallVectorImpl<Block *> &varSuccessors) { 463 if (parser.parseSuccessor(successor)) 464 return failure(); 465 if (failed(parser.parseOptionalComma())) 466 return success(); 467 Block *varSuccessor; 468 if (parser.parseSuccessor(varSuccessor)) 469 return failure(); 470 varSuccessors.append(2, varSuccessor); 471 return success(); 472 } 473 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 474 IntegerAttr &attr, 475 IntegerAttr &optAttr) { 476 if (parser.parseAttribute(attr)) 477 return failure(); 478 if (succeeded(parser.parseOptionalComma())) { 479 if (parser.parseAttribute(optAttr)) 480 return failure(); 481 } 482 return success(); 483 } 484 485 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 486 NamedAttrList &attrs) { 487 return parser.parseOptionalAttrDict(attrs); 488 } 489 static ParseResult parseCustomDirectiveOptionalOperandRef( 490 OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) { 491 int64_t operandCount = 0; 492 if (parser.parseInteger(operandCount)) 493 return failure(); 494 bool expectedOptionalOperand = operandCount == 0; 495 return success(expectedOptionalOperand != optOperand.hasValue()); 496 } 497 498 //===----------------------------------------------------------------------===// 499 // Printing 500 501 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 502 Value operand, Value optOperand, 503 OperandRange varOperands) { 504 printer << operand; 505 if (optOperand) 506 printer << ", " << optOperand; 507 printer << " -> (" << varOperands << ")"; 508 } 509 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 510 Type operandType, Type optOperandType, 511 TypeRange varOperandTypes) { 512 printer << " : " << operandType; 513 if (optOperandType) 514 printer << ", " << optOperandType; 515 printer << " -> (" << varOperandTypes << ")"; 516 } 517 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 518 Operation *op, Type operandType, 519 Type optOperandType, 520 TypeRange varOperandTypes) { 521 printer << " type_refs_capture "; 522 printCustomDirectiveResults(printer, op, operandType, optOperandType, 523 varOperandTypes); 524 } 525 static void printCustomDirectiveOperandsAndTypes( 526 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 527 OperandRange varOperands, Type operandType, Type optOperandType, 528 TypeRange varOperandTypes) { 529 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 530 printCustomDirectiveResults(printer, op, operandType, optOperandType, 531 varOperandTypes); 532 } 533 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 534 Region ®ion, 535 MutableArrayRef<Region> varRegions) { 536 printer.printRegion(region); 537 if (!varRegions.empty()) { 538 printer << ", "; 539 for (Region ®ion : varRegions) 540 printer.printRegion(region); 541 } 542 } 543 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 544 Block *successor, 545 SuccessorRange varSuccessors) { 546 printer << successor; 547 if (!varSuccessors.empty()) 548 printer << ", " << varSuccessors.front(); 549 } 550 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 551 Attribute attribute, 552 Attribute optAttribute) { 553 printer << attribute; 554 if (optAttribute) 555 printer << ", " << optAttribute; 556 } 557 558 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 559 DictionaryAttr attrs) { 560 printer.printOptionalAttrDict(attrs.getValue()); 561 } 562 563 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 564 Operation *op, 565 Value optOperand) { 566 printer << (optOperand ? "1" : "0"); 567 } 568 569 //===----------------------------------------------------------------------===// 570 // Test IsolatedRegionOp - parse passthrough region arguments. 571 //===----------------------------------------------------------------------===// 572 573 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 574 OperationState &result) { 575 OpAsmParser::OperandType argInfo; 576 Type argType = parser.getBuilder().getIndexType(); 577 578 // Parse the input operand. 579 if (parser.parseOperand(argInfo) || 580 parser.resolveOperand(argInfo, argType, result.operands)) 581 return failure(); 582 583 // Parse the body region, and reuse the operand info as the argument info. 584 Region *body = result.addRegion(); 585 return parser.parseRegion(*body, argInfo, argType, 586 /*enableNameShadowing=*/true); 587 } 588 589 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 590 p << "test.isolated_region "; 591 p.printOperand(op.getOperand()); 592 p.shadowRegionArgs(op.region(), op.getOperand()); 593 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 594 } 595 596 //===----------------------------------------------------------------------===// 597 // Test SSACFGRegionOp 598 //===----------------------------------------------------------------------===// 599 600 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 601 return RegionKind::SSACFG; 602 } 603 604 //===----------------------------------------------------------------------===// 605 // Test GraphRegionOp 606 //===----------------------------------------------------------------------===// 607 608 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 609 OperationState &result) { 610 // Parse the body region, and reuse the operand info as the argument info. 611 Region *body = result.addRegion(); 612 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 613 } 614 615 static void print(OpAsmPrinter &p, GraphRegionOp op) { 616 p << "test.graph_region "; 617 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 618 } 619 620 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 621 return RegionKind::Graph; 622 } 623 624 //===----------------------------------------------------------------------===// 625 // Test AffineScopeOp 626 //===----------------------------------------------------------------------===// 627 628 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 629 OperationState &result) { 630 // Parse the body region, and reuse the operand info as the argument info. 631 Region *body = result.addRegion(); 632 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 633 } 634 635 static void print(OpAsmPrinter &p, AffineScopeOp op) { 636 p << "test.affine_scope "; 637 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 638 } 639 640 //===----------------------------------------------------------------------===// 641 // Test parser. 642 //===----------------------------------------------------------------------===// 643 644 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, 645 OperationState &result) { 646 if (parser.parseOptionalColon()) 647 return success(); 648 uint64_t numResults; 649 if (parser.parseInteger(numResults)) 650 return failure(); 651 652 IndexType type = parser.getBuilder().getIndexType(); 653 for (unsigned i = 0; i < numResults; ++i) 654 result.addTypes(type); 655 return success(); 656 } 657 658 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { 659 if (unsigned numResults = op->getNumResults()) 660 p << " : " << numResults; 661 } 662 663 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, 664 OperationState &result) { 665 StringRef keyword; 666 if (parser.parseKeyword(&keyword)) 667 return failure(); 668 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 669 return success(); 670 } 671 672 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { 673 p << " " << op.keyword(); 674 } 675 676 //===----------------------------------------------------------------------===// 677 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 678 679 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 680 OperationState &result) { 681 if (parser.parseKeyword("wraps")) 682 return failure(); 683 684 // Parse the wrapped op in a region 685 Region &body = *result.addRegion(); 686 body.push_back(new Block); 687 Block &block = body.back(); 688 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 689 if (!wrapped_op) 690 return failure(); 691 692 // Create a return terminator in the inner region, pass as operand to the 693 // terminator the returned values from the wrapped operation. 694 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 695 OpBuilder builder(parser.getBuilder().getContext()); 696 builder.setInsertionPointToEnd(&block); 697 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 698 699 // Get the results type for the wrapping op from the terminator operands. 700 Operation &return_op = body.back().back(); 701 result.types.append(return_op.operand_type_begin(), 702 return_op.operand_type_end()); 703 704 // Use the location of the wrapped op for the "test.wrapping_region" op. 705 result.location = wrapped_op->getLoc(); 706 707 return success(); 708 } 709 710 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 711 p << " wraps "; 712 p.printGenericOp(&op.region().front().front()); 713 } 714 715 //===----------------------------------------------------------------------===// 716 // Test PolyForOp - parse list of region arguments. 717 //===----------------------------------------------------------------------===// 718 719 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 720 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 721 // Parse list of region arguments without a delimiter. 722 if (parser.parseRegionArgumentList(ivsInfo)) 723 return failure(); 724 725 // Parse the body region. 726 Region *body = result.addRegion(); 727 auto &builder = parser.getBuilder(); 728 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 729 return parser.parseRegion(*body, ivsInfo, argTypes); 730 } 731 732 //===----------------------------------------------------------------------===// 733 // Test removing op with inner ops. 734 //===----------------------------------------------------------------------===// 735 736 namespace { 737 struct TestRemoveOpWithInnerOps 738 : public OpRewritePattern<TestOpWithRegionPattern> { 739 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 740 741 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } 742 743 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 744 PatternRewriter &rewriter) const override { 745 rewriter.eraseOp(op); 746 return success(); 747 } 748 }; 749 } // end anonymous namespace 750 751 void TestOpWithRegionPattern::getCanonicalizationPatterns( 752 RewritePatternSet &results, MLIRContext *context) { 753 results.add<TestRemoveOpWithInnerOps>(context); 754 } 755 756 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 757 return operand(); 758 } 759 760 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 761 return getValue(); 762 } 763 764 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 765 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 766 for (Value input : this->operands()) { 767 results.push_back(input); 768 } 769 return success(); 770 } 771 772 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 773 assert(operands.size() == 1); 774 if (operands.front()) { 775 (*this)->setAttr("attr", operands.front()); 776 return getResult(); 777 } 778 return {}; 779 } 780 781 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 782 return getOperand(); 783 } 784 785 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 786 MLIRContext *, Optional<Location> location, ValueRange operands, 787 DictionaryAttr attributes, RegionRange regions, 788 SmallVectorImpl<Type> &inferredReturnTypes) { 789 if (operands[0].getType() != operands[1].getType()) { 790 return emitOptionalError(location, "operand type mismatch ", 791 operands[0].getType(), " vs ", 792 operands[1].getType()); 793 } 794 inferredReturnTypes.assign({operands[0].getType()}); 795 return success(); 796 } 797 798 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 799 MLIRContext *context, Optional<Location> location, ValueShapeRange operands, 800 DictionaryAttr attributes, RegionRange regions, 801 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 802 // Create return type consisting of the last element of the first operand. 803 auto operandType = operands.front().getType(); 804 auto sval = operandType.dyn_cast<ShapedType>(); 805 if (!sval) { 806 return emitOptionalError(location, "only shaped type operands allowed"); 807 } 808 int64_t dim = 809 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 810 auto type = IntegerType::get(context, 17); 811 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 812 return success(); 813 } 814 815 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 816 OpBuilder &builder, ValueRange operands, 817 llvm::SmallVectorImpl<Value> &shapes) { 818 shapes = SmallVector<Value, 1>{ 819 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; 820 return success(); 821 } 822 823 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( 824 OpBuilder &builder, ValueRange operands, 825 llvm::SmallVectorImpl<Value> &shapes) { 826 Location loc = getLoc(); 827 shapes.reserve(operands.size()); 828 for (Value operand : llvm::reverse(operands)) { 829 auto currShape = llvm::to_vector<4>(llvm::map_range( 830 llvm::seq<int64_t>( 831 0, operand.getType().cast<RankedTensorType>().getRank()), 832 [&](int64_t dim) -> Value { 833 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 834 })); 835 shapes.push_back(builder.create<tensor::FromElementsOp>( 836 getLoc(), builder.getIndexType(), currShape)); 837 } 838 return success(); 839 } 840 841 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( 842 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { 843 Location loc = getLoc(); 844 shapes.reserve(getNumOperands()); 845 for (Value operand : llvm::reverse(getOperands())) { 846 auto currShape = llvm::to_vector<4>(llvm::map_range( 847 llvm::seq<int64_t>( 848 0, operand.getType().cast<RankedTensorType>().getRank()), 849 [&](int64_t dim) -> Value { 850 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 851 })); 852 shapes.emplace_back(std::move(currShape)); 853 } 854 return success(); 855 } 856 857 //===----------------------------------------------------------------------===// 858 // Test SideEffect interfaces 859 //===----------------------------------------------------------------------===// 860 861 namespace { 862 /// A test resource for side effects. 863 struct TestResource : public SideEffects::Resource::Base<TestResource> { 864 StringRef getName() final { return "<Test>"; } 865 }; 866 } // end anonymous namespace 867 868 static void testSideEffectOpGetEffect( 869 Operation *op, 870 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 871 &effects) { 872 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); 873 if (!effectsAttr) 874 return; 875 876 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 877 } 878 879 void SideEffectOp::getEffects( 880 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 881 // Check for an effects attribute on the op instance. 882 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 883 if (!effectsAttr) 884 return; 885 886 // If there is one, it is an array of dictionary attributes that hold 887 // information on the effects of this operation. 888 for (Attribute element : effectsAttr) { 889 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 890 891 // Get the specific memory effect. 892 MemoryEffects::Effect *effect = 893 StringSwitch<MemoryEffects::Effect *>( 894 effectElement.get("effect").cast<StringAttr>().getValue()) 895 .Case("allocate", MemoryEffects::Allocate::get()) 896 .Case("free", MemoryEffects::Free::get()) 897 .Case("read", MemoryEffects::Read::get()) 898 .Case("write", MemoryEffects::Write::get()); 899 900 // Check for a non-default resource to use. 901 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 902 if (effectElement.get("test_resource")) 903 resource = TestResource::get(); 904 905 // Check for a result to affect. 906 if (effectElement.get("on_result")) 907 effects.emplace_back(effect, getResult(), resource); 908 else if (Attribute ref = effectElement.get("on_reference")) 909 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 910 else 911 effects.emplace_back(effect, resource); 912 } 913 } 914 915 void SideEffectOp::getEffects( 916 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 917 testSideEffectOpGetEffect(getOperation(), effects); 918 } 919 920 //===----------------------------------------------------------------------===// 921 // StringAttrPrettyNameOp 922 //===----------------------------------------------------------------------===// 923 924 // This op has fancy handling of its SSA result name. 925 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 926 OperationState &result) { 927 // Add the result types. 928 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 929 result.addTypes(parser.getBuilder().getIntegerType(32)); 930 931 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 932 return failure(); 933 934 // If the attribute dictionary contains no 'names' attribute, infer it from 935 // the SSA name (if specified). 936 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 937 return attr.first == "names"; 938 }); 939 940 // If there was no name specified, check to see if there was a useful name 941 // specified in the asm file. 942 if (hadNames || parser.getNumResults() == 0) 943 return success(); 944 945 SmallVector<StringRef, 4> names; 946 auto *context = result.getContext(); 947 948 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 949 auto resultName = parser.getResultName(i); 950 StringRef nameStr; 951 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 952 nameStr = resultName.first; 953 954 names.push_back(nameStr); 955 } 956 957 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 958 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 959 return success(); 960 } 961 962 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 963 // Note that we only need to print the "name" attribute if the asmprinter 964 // result name disagrees with it. This can happen in strange cases, e.g. 965 // when there are conflicts. 966 bool namesDisagree = op.names().size() != op.getNumResults(); 967 968 SmallString<32> resultNameStr; 969 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 970 resultNameStr.clear(); 971 llvm::raw_svector_ostream tmpStream(resultNameStr); 972 p.printOperand(op.getResult(i), tmpStream); 973 974 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 975 if (!expectedName || 976 tmpStream.str().drop_front() != expectedName.getValue()) { 977 namesDisagree = true; 978 } 979 } 980 981 if (namesDisagree) 982 p.printOptionalAttrDictWithKeyword(op->getAttrs()); 983 else 984 p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"}); 985 } 986 987 // We set the SSA name in the asm syntax to the contents of the name 988 // attribute. 989 void StringAttrPrettyNameOp::getAsmResultNames( 990 function_ref<void(Value, StringRef)> setNameFn) { 991 992 auto value = names(); 993 for (size_t i = 0, e = value.size(); i != e; ++i) 994 if (auto str = value[i].dyn_cast<StringAttr>()) 995 if (!str.getValue().empty()) 996 setNameFn(getResult(i), str.getValue()); 997 } 998 999 //===----------------------------------------------------------------------===// 1000 // RegionIfOp 1001 //===----------------------------------------------------------------------===// 1002 1003 static void print(OpAsmPrinter &p, RegionIfOp op) { 1004 p << " "; 1005 p.printOperands(op.getOperands()); 1006 p << ": " << op.getOperandTypes(); 1007 p.printArrowTypeList(op.getResultTypes()); 1008 p << " then"; 1009 p.printRegion(op.thenRegion(), 1010 /*printEntryBlockArgs=*/true, 1011 /*printBlockTerminators=*/true); 1012 p << " else"; 1013 p.printRegion(op.elseRegion(), 1014 /*printEntryBlockArgs=*/true, 1015 /*printBlockTerminators=*/true); 1016 p << " join"; 1017 p.printRegion(op.joinRegion(), 1018 /*printEntryBlockArgs=*/true, 1019 /*printBlockTerminators=*/true); 1020 } 1021 1022 static ParseResult parseRegionIfOp(OpAsmParser &parser, 1023 OperationState &result) { 1024 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 1025 SmallVector<Type, 2> operandTypes; 1026 1027 result.regions.reserve(3); 1028 Region *thenRegion = result.addRegion(); 1029 Region *elseRegion = result.addRegion(); 1030 Region *joinRegion = result.addRegion(); 1031 1032 // Parse operand, type and arrow type lists. 1033 if (parser.parseOperandList(operandInfos) || 1034 parser.parseColonTypeList(operandTypes) || 1035 parser.parseArrowTypeList(result.types)) 1036 return failure(); 1037 1038 // Parse all attached regions. 1039 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 1040 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 1041 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 1042 return failure(); 1043 1044 return parser.resolveOperands(operandInfos, operandTypes, 1045 parser.getCurrentLocation(), result.operands); 1046 } 1047 1048 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 1049 assert(index < 2 && "invalid region index"); 1050 return getOperands(); 1051 } 1052 1053 void RegionIfOp::getSuccessorRegions( 1054 Optional<unsigned> index, ArrayRef<Attribute> operands, 1055 SmallVectorImpl<RegionSuccessor> ®ions) { 1056 // We always branch to the join region. 1057 if (index.hasValue()) { 1058 if (index.getValue() < 2) 1059 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 1060 else 1061 regions.push_back(RegionSuccessor(getResults())); 1062 return; 1063 } 1064 1065 // The then and else regions are the entry regions of this op. 1066 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 1067 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 1068 } 1069 1070 //===----------------------------------------------------------------------===// 1071 // SingleNoTerminatorCustomAsmOp 1072 //===----------------------------------------------------------------------===// 1073 1074 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser, 1075 OperationState &state) { 1076 Region *body = state.addRegion(); 1077 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) 1078 return failure(); 1079 return success(); 1080 } 1081 1082 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) { 1083 printer.printRegion( 1084 op.getRegion(), /*printEntryBlockArgs=*/false, 1085 // This op has a single block without terminators. But explicitly mark 1086 // as not printing block terminators for testing. 1087 /*printBlockTerminators=*/false); 1088 } 1089 1090 #include "TestOpEnums.cpp.inc" 1091 #include "TestOpInterfaces.cpp.inc" 1092 #include "TestOpStructs.cpp.inc" 1093 #include "TestTypeInterfaces.cpp.inc" 1094 1095 #define GET_OP_CLASSES 1096 #include "TestOps.cpp.inc" 1097