1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestInterfaces.h" 12 #include "TestTypes.h" 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/DLTI/DLTI.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/DialectImplementation.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/IR/TypeUtilities.h" 21 #include "mlir/Reducer/ReductionPatternInterface.h" 22 #include "mlir/Transforms/FoldUtils.h" 23 #include "mlir/Transforms/InliningUtils.h" 24 #include "llvm/ADT/StringSwitch.h" 25 26 // Include this before the using namespace lines below to 27 // test that we don't have namespace dependencies. 28 #include "TestOpsDialect.cpp.inc" 29 30 using namespace mlir; 31 using namespace test; 32 33 void test::registerTestDialect(DialectRegistry ®istry) { 34 registry.insert<TestDialect>(); 35 } 36 37 //===----------------------------------------------------------------------===// 38 // TestDialect Interfaces 39 //===----------------------------------------------------------------------===// 40 41 namespace { 42 43 /// Testing the correctness of some traits. 44 static_assert( 45 llvm::is_detected<OpTrait::has_implicit_terminator_t, 46 SingleBlockImplicitTerminatorOp>::value, 47 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); 48 static_assert(OpTrait::hasSingleBlockImplicitTerminator< 49 SingleBlockImplicitTerminatorOp>::value, 50 "hasSingleBlockImplicitTerminator does not match " 51 "SingleBlockImplicitTerminatorOp"); 52 53 // Test support for interacting with the AsmPrinter. 54 struct TestOpAsmInterface : public OpAsmDialectInterface { 55 using OpAsmDialectInterface::OpAsmDialectInterface; 56 57 AliasResult getAlias(Attribute attr, raw_ostream &os) const final { 58 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 59 if (!strAttr) 60 return AliasResult::NoAlias; 61 62 // Check the contents of the string attribute to see what the test alias 63 // should be named. 64 Optional<StringRef> aliasName = 65 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 66 .Case("alias_test:dot_in_name", StringRef("test.alias")) 67 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 68 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 69 .Case("alias_test:sanitize_conflict_a", 70 StringRef("test_alias_conflict0")) 71 .Case("alias_test:sanitize_conflict_b", 72 StringRef("test_alias_conflict0_")) 73 .Case("alias_test:tensor_encoding", StringRef("test_encoding")) 74 .Default(llvm::None); 75 if (!aliasName) 76 return AliasResult::NoAlias; 77 78 os << *aliasName; 79 return AliasResult::FinalAlias; 80 } 81 82 AliasResult getAlias(Type type, raw_ostream &os) const final { 83 if (auto tupleType = type.dyn_cast<TupleType>()) { 84 if (tupleType.size() > 0 && 85 llvm::all_of(tupleType.getTypes(), [](Type elemType) { 86 return elemType.isa<SimpleAType>(); 87 })) { 88 os << "test_tuple"; 89 return AliasResult::FinalAlias; 90 } 91 } 92 if (auto intType = type.dyn_cast<TestIntegerType>()) { 93 if (intType.getSignedness() == 94 TestIntegerType::SignednessSemantics::Unsigned && 95 intType.getWidth() == 8) { 96 os << "test_ui8"; 97 return AliasResult::FinalAlias; 98 } 99 } 100 return AliasResult::NoAlias; 101 } 102 103 void getAsmResultNames(Operation *op, 104 OpAsmSetValueNameFn setNameFn) const final { 105 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 106 setNameFn(asmOp, "result"); 107 } 108 109 void getAsmBlockArgumentNames(Block *block, 110 OpAsmSetValueNameFn setNameFn) const final { 111 auto op = block->getParentOp(); 112 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 113 if (!arrayAttr) 114 return; 115 auto args = block->getArguments(); 116 auto e = std::min(arrayAttr.size(), args.size()); 117 for (unsigned i = 0; i < e; ++i) { 118 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 119 setNameFn(args[i], strAttr.getValue()); 120 } 121 } 122 }; 123 124 struct TestDialectFoldInterface : public DialectFoldInterface { 125 using DialectFoldInterface::DialectFoldInterface; 126 127 /// Registered hook to check if the given region, which is attached to an 128 /// operation that is *not* isolated from above, should be used when 129 /// materializing constants. 130 bool shouldMaterializeInto(Region *region) const final { 131 // If this is a one region operation, then insert into it. 132 return isa<OneRegionOp>(region->getParentOp()); 133 } 134 }; 135 136 /// This class defines the interface for handling inlining with standard 137 /// operations. 138 struct TestInlinerInterface : public DialectInlinerInterface { 139 using DialectInlinerInterface::DialectInlinerInterface; 140 141 //===--------------------------------------------------------------------===// 142 // Analysis Hooks 143 //===--------------------------------------------------------------------===// 144 145 bool isLegalToInline(Operation *call, Operation *callable, 146 bool wouldBeCloned) const final { 147 // Don't allow inlining calls that are marked `noinline`. 148 return !call->hasAttr("noinline"); 149 } 150 bool isLegalToInline(Region *, Region *, bool, 151 BlockAndValueMapping &) const final { 152 // Inlining into test dialect regions is legal. 153 return true; 154 } 155 bool isLegalToInline(Operation *, Region *, bool, 156 BlockAndValueMapping &) const final { 157 return true; 158 } 159 160 bool shouldAnalyzeRecursively(Operation *op) const final { 161 // Analyze recursively if this is not a functional region operation, it 162 // froms a separate functional scope. 163 return !isa<FunctionalRegionOp>(op); 164 } 165 166 //===--------------------------------------------------------------------===// 167 // Transformation Hooks 168 //===--------------------------------------------------------------------===// 169 170 /// Handle the given inlined terminator by replacing it with a new operation 171 /// as necessary. 172 void handleTerminator(Operation *op, 173 ArrayRef<Value> valuesToRepl) const final { 174 // Only handle "test.return" here. 175 auto returnOp = dyn_cast<TestReturnOp>(op); 176 if (!returnOp) 177 return; 178 179 // Replace the values directly with the return operands. 180 assert(returnOp.getNumOperands() == valuesToRepl.size()); 181 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 182 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 183 } 184 185 /// Attempt to materialize a conversion for a type mismatch between a call 186 /// from this dialect, and a callable region. This method should generate an 187 /// operation that takes 'input' as the only operand, and produces a single 188 /// result of 'resultType'. If a conversion can not be generated, nullptr 189 /// should be returned. 190 Operation *materializeCallConversion(OpBuilder &builder, Value input, 191 Type resultType, 192 Location conversionLoc) const final { 193 // Only allow conversion for i16/i32 types. 194 if (!(resultType.isSignlessInteger(16) || 195 resultType.isSignlessInteger(32)) || 196 !(input.getType().isSignlessInteger(16) || 197 input.getType().isSignlessInteger(32))) 198 return nullptr; 199 return builder.create<TestCastOp>(conversionLoc, resultType, input); 200 } 201 202 void processInlinedCallBlocks( 203 Operation *call, 204 iterator_range<Region::iterator> inlinedBlocks) const final { 205 if (!isa<ConversionCallOp>(call)) 206 return; 207 208 // Set attributed on all ops in the inlined blocks. 209 for (Block &block : inlinedBlocks) { 210 block.walk([&](Operation *op) { 211 op->setAttr("inlined_conversion", UnitAttr::get(call->getContext())); 212 }); 213 } 214 } 215 }; 216 217 struct TestReductionPatternInterface : public DialectReductionPatternInterface { 218 public: 219 TestReductionPatternInterface(Dialect *dialect) 220 : DialectReductionPatternInterface(dialect) {} 221 222 void populateReductionPatterns(RewritePatternSet &patterns) const final { 223 populateTestReductionPatterns(patterns); 224 } 225 }; 226 227 } // end anonymous namespace 228 229 //===----------------------------------------------------------------------===// 230 // TestDialect 231 //===----------------------------------------------------------------------===// 232 233 static void testSideEffectOpGetEffect( 234 Operation *op, 235 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects); 236 237 // This is the implementation of a dialect fallback for `TestEffectOpInterface`. 238 struct TestOpEffectInterfaceFallback 239 : public TestEffectOpInterface::FallbackModel< 240 TestOpEffectInterfaceFallback> { 241 static bool classof(Operation *op) { 242 bool isSupportedOp = 243 op->getName().getStringRef() == "test.unregistered_side_effect_op"; 244 assert(isSupportedOp && "Unexpected dispatch"); 245 return isSupportedOp; 246 } 247 248 void 249 getEffects(Operation *op, 250 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 251 &effects) const { 252 testSideEffectOpGetEffect(op, effects); 253 } 254 }; 255 256 void TestDialect::initialize() { 257 registerAttributes(); 258 registerTypes(); 259 addOperations< 260 #define GET_OP_LIST 261 #include "TestOps.cpp.inc" 262 >(); 263 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 264 TestInlinerInterface, TestReductionPatternInterface>(); 265 allowUnknownOperations(); 266 267 // Instantiate our fallback op interface that we'll use on specific 268 // unregistered op. 269 fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback; 270 } 271 TestDialect::~TestDialect() { 272 delete static_cast<TestOpEffectInterfaceFallback *>( 273 fallbackEffectOpInterfaces); 274 } 275 276 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 277 Type type, Location loc) { 278 return builder.create<TestOpConstant>(loc, type, value); 279 } 280 281 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, 282 OperationName opName) { 283 if (opName.getIdentifier() == "test.unregistered_side_effect_op" && 284 typeID == TypeID::get<TestEffectOpInterface>()) 285 return fallbackEffectOpInterfaces; 286 return nullptr; 287 } 288 289 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 290 NamedAttribute namedAttr) { 291 if (namedAttr.first == "test.invalid_attr") 292 return op->emitError() << "invalid to use 'test.invalid_attr'"; 293 return success(); 294 } 295 296 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 297 unsigned regionIndex, 298 unsigned argIndex, 299 NamedAttribute namedAttr) { 300 if (namedAttr.first == "test.invalid_attr") 301 return op->emitError() << "invalid to use 'test.invalid_attr'"; 302 return success(); 303 } 304 305 LogicalResult 306 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 307 unsigned resultIndex, 308 NamedAttribute namedAttr) { 309 if (namedAttr.first == "test.invalid_attr") 310 return op->emitError() << "invalid to use 'test.invalid_attr'"; 311 return success(); 312 } 313 314 Optional<Dialect::ParseOpHook> 315 TestDialect::getParseOperationHook(StringRef opName) const { 316 if (opName == "test.dialect_custom_printer") { 317 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 318 return parser.parseKeyword("custom_format"); 319 }}; 320 } 321 return None; 322 } 323 324 llvm::unique_function<void(Operation *, OpAsmPrinter &)> 325 TestDialect::getOperationPrinter(Operation *op) const { 326 StringRef opName = op->getName().getStringRef(); 327 if (opName == "test.dialect_custom_printer") { 328 return [](Operation *op, OpAsmPrinter &printer) { 329 printer.getStream() << " custom_format"; 330 }; 331 } 332 return {}; 333 } 334 335 //===----------------------------------------------------------------------===// 336 // TestBranchOp 337 //===----------------------------------------------------------------------===// 338 339 Optional<MutableOperandRange> 340 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 341 assert(index == 0 && "invalid successor index"); 342 return getTargetOperandsMutable(); 343 } 344 345 //===----------------------------------------------------------------------===// 346 // TestDialectCanonicalizerOp 347 //===----------------------------------------------------------------------===// 348 349 static LogicalResult 350 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, 351 PatternRewriter &rewriter) { 352 rewriter.replaceOpWithNewOp<arith::ConstantOp>( 353 op, rewriter.getI32IntegerAttr(42)); 354 return success(); 355 } 356 357 void TestDialect::getCanonicalizationPatterns( 358 RewritePatternSet &results) const { 359 results.add(&dialectCanonicalizationPattern); 360 } 361 362 //===----------------------------------------------------------------------===// 363 // TestFoldToCallOp 364 //===----------------------------------------------------------------------===// 365 366 namespace { 367 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 368 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 369 370 LogicalResult matchAndRewrite(FoldToCallOp op, 371 PatternRewriter &rewriter) const override { 372 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.getCalleeAttr(), 373 ValueRange()); 374 return success(); 375 } 376 }; 377 } // end anonymous namespace 378 379 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 380 MLIRContext *context) { 381 results.add<FoldToCallOpPattern>(context); 382 } 383 384 //===----------------------------------------------------------------------===// 385 // Test Format* operations 386 //===----------------------------------------------------------------------===// 387 388 //===----------------------------------------------------------------------===// 389 // Parsing 390 391 static ParseResult parseCustomDirectiveOperands( 392 OpAsmParser &parser, OpAsmParser::OperandType &operand, 393 Optional<OpAsmParser::OperandType> &optOperand, 394 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 395 if (parser.parseOperand(operand)) 396 return failure(); 397 if (succeeded(parser.parseOptionalComma())) { 398 optOperand.emplace(); 399 if (parser.parseOperand(*optOperand)) 400 return failure(); 401 } 402 if (parser.parseArrow() || parser.parseLParen() || 403 parser.parseOperandList(varOperands) || parser.parseRParen()) 404 return failure(); 405 return success(); 406 } 407 static ParseResult 408 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 409 Type &optOperandType, 410 SmallVectorImpl<Type> &varOperandTypes) { 411 if (parser.parseColon()) 412 return failure(); 413 414 if (parser.parseType(operandType)) 415 return failure(); 416 if (succeeded(parser.parseOptionalComma())) { 417 if (parser.parseType(optOperandType)) 418 return failure(); 419 } 420 if (parser.parseArrow() || parser.parseLParen() || 421 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 422 return failure(); 423 return success(); 424 } 425 static ParseResult 426 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 427 Type optOperandType, 428 const SmallVectorImpl<Type> &varOperandTypes) { 429 if (parser.parseKeyword("type_refs_capture")) 430 return failure(); 431 432 Type operandType2, optOperandType2; 433 SmallVector<Type, 1> varOperandTypes2; 434 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 435 varOperandTypes2)) 436 return failure(); 437 438 if (operandType != operandType2 || optOperandType != optOperandType2 || 439 varOperandTypes != varOperandTypes2) 440 return failure(); 441 442 return success(); 443 } 444 static ParseResult parseCustomDirectiveOperandsAndTypes( 445 OpAsmParser &parser, OpAsmParser::OperandType &operand, 446 Optional<OpAsmParser::OperandType> &optOperand, 447 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 448 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 449 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 450 parseCustomDirectiveResults(parser, operandType, optOperandType, 451 varOperandTypes)) 452 return failure(); 453 return success(); 454 } 455 static ParseResult parseCustomDirectiveRegions( 456 OpAsmParser &parser, Region ®ion, 457 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 458 if (parser.parseRegion(region)) 459 return failure(); 460 if (failed(parser.parseOptionalComma())) 461 return success(); 462 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 463 if (parser.parseRegion(*varRegion)) 464 return failure(); 465 varRegions.emplace_back(std::move(varRegion)); 466 return success(); 467 } 468 static ParseResult 469 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 470 SmallVectorImpl<Block *> &varSuccessors) { 471 if (parser.parseSuccessor(successor)) 472 return failure(); 473 if (failed(parser.parseOptionalComma())) 474 return success(); 475 Block *varSuccessor; 476 if (parser.parseSuccessor(varSuccessor)) 477 return failure(); 478 varSuccessors.append(2, varSuccessor); 479 return success(); 480 } 481 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 482 IntegerAttr &attr, 483 IntegerAttr &optAttr) { 484 if (parser.parseAttribute(attr)) 485 return failure(); 486 if (succeeded(parser.parseOptionalComma())) { 487 if (parser.parseAttribute(optAttr)) 488 return failure(); 489 } 490 return success(); 491 } 492 493 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 494 NamedAttrList &attrs) { 495 return parser.parseOptionalAttrDict(attrs); 496 } 497 static ParseResult parseCustomDirectiveOptionalOperandRef( 498 OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) { 499 int64_t operandCount = 0; 500 if (parser.parseInteger(operandCount)) 501 return failure(); 502 bool expectedOptionalOperand = operandCount == 0; 503 return success(expectedOptionalOperand != optOperand.hasValue()); 504 } 505 506 //===----------------------------------------------------------------------===// 507 // Printing 508 509 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 510 Value operand, Value optOperand, 511 OperandRange varOperands) { 512 printer << operand; 513 if (optOperand) 514 printer << ", " << optOperand; 515 printer << " -> (" << varOperands << ")"; 516 } 517 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 518 Type operandType, Type optOperandType, 519 TypeRange varOperandTypes) { 520 printer << " : " << operandType; 521 if (optOperandType) 522 printer << ", " << optOperandType; 523 printer << " -> (" << varOperandTypes << ")"; 524 } 525 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 526 Operation *op, Type operandType, 527 Type optOperandType, 528 TypeRange varOperandTypes) { 529 printer << " type_refs_capture "; 530 printCustomDirectiveResults(printer, op, operandType, optOperandType, 531 varOperandTypes); 532 } 533 static void printCustomDirectiveOperandsAndTypes( 534 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 535 OperandRange varOperands, Type operandType, Type optOperandType, 536 TypeRange varOperandTypes) { 537 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 538 printCustomDirectiveResults(printer, op, operandType, optOperandType, 539 varOperandTypes); 540 } 541 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 542 Region ®ion, 543 MutableArrayRef<Region> varRegions) { 544 printer.printRegion(region); 545 if (!varRegions.empty()) { 546 printer << ", "; 547 for (Region ®ion : varRegions) 548 printer.printRegion(region); 549 } 550 } 551 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 552 Block *successor, 553 SuccessorRange varSuccessors) { 554 printer << successor; 555 if (!varSuccessors.empty()) 556 printer << ", " << varSuccessors.front(); 557 } 558 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 559 Attribute attribute, 560 Attribute optAttribute) { 561 printer << attribute; 562 if (optAttribute) 563 printer << ", " << optAttribute; 564 } 565 566 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 567 DictionaryAttr attrs) { 568 printer.printOptionalAttrDict(attrs.getValue()); 569 } 570 571 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 572 Operation *op, 573 Value optOperand) { 574 printer << (optOperand ? "1" : "0"); 575 } 576 577 //===----------------------------------------------------------------------===// 578 // Test IsolatedRegionOp - parse passthrough region arguments. 579 //===----------------------------------------------------------------------===// 580 581 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 582 OperationState &result) { 583 OpAsmParser::OperandType argInfo; 584 Type argType = parser.getBuilder().getIndexType(); 585 586 // Parse the input operand. 587 if (parser.parseOperand(argInfo) || 588 parser.resolveOperand(argInfo, argType, result.operands)) 589 return failure(); 590 591 // Parse the body region, and reuse the operand info as the argument info. 592 Region *body = result.addRegion(); 593 return parser.parseRegion(*body, argInfo, argType, 594 /*enableNameShadowing=*/true); 595 } 596 597 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 598 p << "test.isolated_region "; 599 p.printOperand(op.getOperand()); 600 p.shadowRegionArgs(op.getRegion(), op.getOperand()); 601 p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false); 602 } 603 604 //===----------------------------------------------------------------------===// 605 // Test SSACFGRegionOp 606 //===----------------------------------------------------------------------===// 607 608 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 609 return RegionKind::SSACFG; 610 } 611 612 //===----------------------------------------------------------------------===// 613 // Test GraphRegionOp 614 //===----------------------------------------------------------------------===// 615 616 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 617 OperationState &result) { 618 // Parse the body region, and reuse the operand info as the argument info. 619 Region *body = result.addRegion(); 620 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 621 } 622 623 static void print(OpAsmPrinter &p, GraphRegionOp op) { 624 p << "test.graph_region "; 625 p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false); 626 } 627 628 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 629 return RegionKind::Graph; 630 } 631 632 //===----------------------------------------------------------------------===// 633 // Test AffineScopeOp 634 //===----------------------------------------------------------------------===// 635 636 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 637 OperationState &result) { 638 // Parse the body region, and reuse the operand info as the argument info. 639 Region *body = result.addRegion(); 640 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 641 } 642 643 static void print(OpAsmPrinter &p, AffineScopeOp op) { 644 p << "test.affine_scope "; 645 p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false); 646 } 647 648 //===----------------------------------------------------------------------===// 649 // Test parser. 650 //===----------------------------------------------------------------------===// 651 652 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, 653 OperationState &result) { 654 if (parser.parseOptionalColon()) 655 return success(); 656 uint64_t numResults; 657 if (parser.parseInteger(numResults)) 658 return failure(); 659 660 IndexType type = parser.getBuilder().getIndexType(); 661 for (unsigned i = 0; i < numResults; ++i) 662 result.addTypes(type); 663 return success(); 664 } 665 666 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { 667 if (unsigned numResults = op->getNumResults()) 668 p << " : " << numResults; 669 } 670 671 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, 672 OperationState &result) { 673 StringRef keyword; 674 if (parser.parseKeyword(&keyword)) 675 return failure(); 676 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 677 return success(); 678 } 679 680 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { 681 p << " " << op.getKeyword(); 682 } 683 684 //===----------------------------------------------------------------------===// 685 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 686 687 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 688 OperationState &result) { 689 if (parser.parseKeyword("wraps")) 690 return failure(); 691 692 // Parse the wrapped op in a region 693 Region &body = *result.addRegion(); 694 body.push_back(new Block); 695 Block &block = body.back(); 696 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 697 if (!wrapped_op) 698 return failure(); 699 700 // Create a return terminator in the inner region, pass as operand to the 701 // terminator the returned values from the wrapped operation. 702 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 703 OpBuilder builder(parser.getContext()); 704 builder.setInsertionPointToEnd(&block); 705 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 706 707 // Get the results type for the wrapping op from the terminator operands. 708 Operation &return_op = body.back().back(); 709 result.types.append(return_op.operand_type_begin(), 710 return_op.operand_type_end()); 711 712 // Use the location of the wrapped op for the "test.wrapping_region" op. 713 result.location = wrapped_op->getLoc(); 714 715 return success(); 716 } 717 718 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 719 p << " wraps "; 720 p.printGenericOp(&op.getRegion().front().front()); 721 } 722 723 //===----------------------------------------------------------------------===// 724 // Test PolyForOp - parse list of region arguments. 725 //===----------------------------------------------------------------------===// 726 727 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 728 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 729 // Parse list of region arguments without a delimiter. 730 if (parser.parseRegionArgumentList(ivsInfo)) 731 return failure(); 732 733 // Parse the body region. 734 Region *body = result.addRegion(); 735 auto &builder = parser.getBuilder(); 736 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 737 return parser.parseRegion(*body, ivsInfo, argTypes); 738 } 739 740 //===----------------------------------------------------------------------===// 741 // Test removing op with inner ops. 742 //===----------------------------------------------------------------------===// 743 744 namespace { 745 struct TestRemoveOpWithInnerOps 746 : public OpRewritePattern<TestOpWithRegionPattern> { 747 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 748 749 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } 750 751 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 752 PatternRewriter &rewriter) const override { 753 rewriter.eraseOp(op); 754 return success(); 755 } 756 }; 757 } // end anonymous namespace 758 759 void TestOpWithRegionPattern::getCanonicalizationPatterns( 760 RewritePatternSet &results, MLIRContext *context) { 761 results.add<TestRemoveOpWithInnerOps>(context); 762 } 763 764 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 765 return getOperand(); 766 } 767 768 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 769 return getValue(); 770 } 771 772 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 773 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 774 for (Value input : this->getOperands()) { 775 results.push_back(input); 776 } 777 return success(); 778 } 779 780 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 781 assert(operands.size() == 1); 782 if (operands.front()) { 783 (*this)->setAttr("attr", operands.front()); 784 return getResult(); 785 } 786 return {}; 787 } 788 789 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 790 return getOperand(); 791 } 792 793 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 794 MLIRContext *, Optional<Location> location, ValueRange operands, 795 DictionaryAttr attributes, RegionRange regions, 796 SmallVectorImpl<Type> &inferredReturnTypes) { 797 if (operands[0].getType() != operands[1].getType()) { 798 return emitOptionalError(location, "operand type mismatch ", 799 operands[0].getType(), " vs ", 800 operands[1].getType()); 801 } 802 inferredReturnTypes.assign({operands[0].getType()}); 803 return success(); 804 } 805 806 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 807 MLIRContext *context, Optional<Location> location, ValueShapeRange operands, 808 DictionaryAttr attributes, RegionRange regions, 809 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 810 // Create return type consisting of the last element of the first operand. 811 auto operandType = operands.front().getType(); 812 auto sval = operandType.dyn_cast<ShapedType>(); 813 if (!sval) { 814 return emitOptionalError(location, "only shaped type operands allowed"); 815 } 816 int64_t dim = 817 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 818 auto type = IntegerType::get(context, 17); 819 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 820 return success(); 821 } 822 823 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 824 OpBuilder &builder, ValueRange operands, 825 llvm::SmallVectorImpl<Value> &shapes) { 826 shapes = SmallVector<Value, 1>{ 827 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; 828 return success(); 829 } 830 831 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( 832 OpBuilder &builder, ValueRange operands, 833 llvm::SmallVectorImpl<Value> &shapes) { 834 Location loc = getLoc(); 835 shapes.reserve(operands.size()); 836 for (Value operand : llvm::reverse(operands)) { 837 auto currShape = llvm::to_vector<4>(llvm::map_range( 838 llvm::seq<int64_t>( 839 0, operand.getType().cast<RankedTensorType>().getRank()), 840 [&](int64_t dim) -> Value { 841 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 842 })); 843 shapes.push_back(builder.create<tensor::FromElementsOp>( 844 getLoc(), builder.getIndexType(), currShape)); 845 } 846 return success(); 847 } 848 849 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( 850 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { 851 Location loc = getLoc(); 852 shapes.reserve(getNumOperands()); 853 for (Value operand : llvm::reverse(getOperands())) { 854 auto currShape = llvm::to_vector<4>(llvm::map_range( 855 llvm::seq<int64_t>( 856 0, operand.getType().cast<RankedTensorType>().getRank()), 857 [&](int64_t dim) -> Value { 858 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 859 })); 860 shapes.emplace_back(std::move(currShape)); 861 } 862 return success(); 863 } 864 865 //===----------------------------------------------------------------------===// 866 // Test SideEffect interfaces 867 //===----------------------------------------------------------------------===// 868 869 namespace { 870 /// A test resource for side effects. 871 struct TestResource : public SideEffects::Resource::Base<TestResource> { 872 StringRef getName() final { return "<Test>"; } 873 }; 874 } // end anonymous namespace 875 876 static void testSideEffectOpGetEffect( 877 Operation *op, 878 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 879 &effects) { 880 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); 881 if (!effectsAttr) 882 return; 883 884 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 885 } 886 887 void SideEffectOp::getEffects( 888 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 889 // Check for an effects attribute on the op instance. 890 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 891 if (!effectsAttr) 892 return; 893 894 // If there is one, it is an array of dictionary attributes that hold 895 // information on the effects of this operation. 896 for (Attribute element : effectsAttr) { 897 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 898 899 // Get the specific memory effect. 900 MemoryEffects::Effect *effect = 901 StringSwitch<MemoryEffects::Effect *>( 902 effectElement.get("effect").cast<StringAttr>().getValue()) 903 .Case("allocate", MemoryEffects::Allocate::get()) 904 .Case("free", MemoryEffects::Free::get()) 905 .Case("read", MemoryEffects::Read::get()) 906 .Case("write", MemoryEffects::Write::get()); 907 908 // Check for a non-default resource to use. 909 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 910 if (effectElement.get("test_resource")) 911 resource = TestResource::get(); 912 913 // Check for a result to affect. 914 if (effectElement.get("on_result")) 915 effects.emplace_back(effect, getResult(), resource); 916 else if (Attribute ref = effectElement.get("on_reference")) 917 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 918 else 919 effects.emplace_back(effect, resource); 920 } 921 } 922 923 void SideEffectOp::getEffects( 924 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 925 testSideEffectOpGetEffect(getOperation(), effects); 926 } 927 928 //===----------------------------------------------------------------------===// 929 // StringAttrPrettyNameOp 930 //===----------------------------------------------------------------------===// 931 932 // This op has fancy handling of its SSA result name. 933 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 934 OperationState &result) { 935 // Add the result types. 936 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 937 result.addTypes(parser.getBuilder().getIntegerType(32)); 938 939 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 940 return failure(); 941 942 // If the attribute dictionary contains no 'names' attribute, infer it from 943 // the SSA name (if specified). 944 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 945 return attr.first == "names"; 946 }); 947 948 // If there was no name specified, check to see if there was a useful name 949 // specified in the asm file. 950 if (hadNames || parser.getNumResults() == 0) 951 return success(); 952 953 SmallVector<StringRef, 4> names; 954 auto *context = result.getContext(); 955 956 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 957 auto resultName = parser.getResultName(i); 958 StringRef nameStr; 959 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 960 nameStr = resultName.first; 961 962 names.push_back(nameStr); 963 } 964 965 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 966 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 967 return success(); 968 } 969 970 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 971 // Note that we only need to print the "name" attribute if the asmprinter 972 // result name disagrees with it. This can happen in strange cases, e.g. 973 // when there are conflicts. 974 bool namesDisagree = op.getNames().size() != op.getNumResults(); 975 976 SmallString<32> resultNameStr; 977 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 978 resultNameStr.clear(); 979 llvm::raw_svector_ostream tmpStream(resultNameStr); 980 p.printOperand(op.getResult(i), tmpStream); 981 982 auto expectedName = op.getNames()[i].dyn_cast<StringAttr>(); 983 if (!expectedName || 984 tmpStream.str().drop_front() != expectedName.getValue()) { 985 namesDisagree = true; 986 } 987 } 988 989 if (namesDisagree) 990 p.printOptionalAttrDictWithKeyword(op->getAttrs()); 991 else 992 p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"}); 993 } 994 995 // We set the SSA name in the asm syntax to the contents of the name 996 // attribute. 997 void StringAttrPrettyNameOp::getAsmResultNames( 998 function_ref<void(Value, StringRef)> setNameFn) { 999 1000 auto value = getNames(); 1001 for (size_t i = 0, e = value.size(); i != e; ++i) 1002 if (auto str = value[i].dyn_cast<StringAttr>()) 1003 if (!str.getValue().empty()) 1004 setNameFn(getResult(i), str.getValue()); 1005 } 1006 1007 //===----------------------------------------------------------------------===// 1008 // RegionIfOp 1009 //===----------------------------------------------------------------------===// 1010 1011 static void print(OpAsmPrinter &p, RegionIfOp op) { 1012 p << " "; 1013 p.printOperands(op.getOperands()); 1014 p << ": " << op.getOperandTypes(); 1015 p.printArrowTypeList(op.getResultTypes()); 1016 p << " then"; 1017 p.printRegion(op.getThenRegion(), 1018 /*printEntryBlockArgs=*/true, 1019 /*printBlockTerminators=*/true); 1020 p << " else"; 1021 p.printRegion(op.getElseRegion(), 1022 /*printEntryBlockArgs=*/true, 1023 /*printBlockTerminators=*/true); 1024 p << " join"; 1025 p.printRegion(op.getJoinRegion(), 1026 /*printEntryBlockArgs=*/true, 1027 /*printBlockTerminators=*/true); 1028 } 1029 1030 static ParseResult parseRegionIfOp(OpAsmParser &parser, 1031 OperationState &result) { 1032 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 1033 SmallVector<Type, 2> operandTypes; 1034 1035 result.regions.reserve(3); 1036 Region *thenRegion = result.addRegion(); 1037 Region *elseRegion = result.addRegion(); 1038 Region *joinRegion = result.addRegion(); 1039 1040 // Parse operand, type and arrow type lists. 1041 if (parser.parseOperandList(operandInfos) || 1042 parser.parseColonTypeList(operandTypes) || 1043 parser.parseArrowTypeList(result.types)) 1044 return failure(); 1045 1046 // Parse all attached regions. 1047 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 1048 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 1049 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 1050 return failure(); 1051 1052 return parser.resolveOperands(operandInfos, operandTypes, 1053 parser.getCurrentLocation(), result.operands); 1054 } 1055 1056 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 1057 assert(index < 2 && "invalid region index"); 1058 return getOperands(); 1059 } 1060 1061 void RegionIfOp::getSuccessorRegions( 1062 Optional<unsigned> index, ArrayRef<Attribute> operands, 1063 SmallVectorImpl<RegionSuccessor> ®ions) { 1064 // We always branch to the join region. 1065 if (index.hasValue()) { 1066 if (index.getValue() < 2) 1067 regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); 1068 else 1069 regions.push_back(RegionSuccessor(getResults())); 1070 return; 1071 } 1072 1073 // The then and else regions are the entry regions of this op. 1074 regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs())); 1075 regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs())); 1076 } 1077 1078 //===----------------------------------------------------------------------===// 1079 // SingleNoTerminatorCustomAsmOp 1080 //===----------------------------------------------------------------------===// 1081 1082 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser, 1083 OperationState &state) { 1084 Region *body = state.addRegion(); 1085 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) 1086 return failure(); 1087 return success(); 1088 } 1089 1090 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) { 1091 printer.printRegion( 1092 op.getRegion(), /*printEntryBlockArgs=*/false, 1093 // This op has a single block without terminators. But explicitly mark 1094 // as not printing block terminators for testing. 1095 /*printBlockTerminators=*/false); 1096 } 1097 1098 #include "TestOpEnums.cpp.inc" 1099 #include "TestOpInterfaces.cpp.inc" 1100 #include "TestOpStructs.cpp.inc" 1101 #include "TestTypeInterfaces.cpp.inc" 1102 1103 #define GET_OP_CLASSES 1104 #include "TestOps.cpp.inc" 1105