1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestInterfaces.h" 12 #include "TestTypes.h" 13 #include "mlir/Dialect/DLTI/DLTI.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/DialectImplementation.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Reducer/ReductionPatternInterface.h" 21 #include "mlir/Transforms/FoldUtils.h" 22 #include "mlir/Transforms/InliningUtils.h" 23 #include "llvm/ADT/StringSwitch.h" 24 25 // Include this before the using namespace lines below to 26 // test that we don't have namespace dependencies. 27 #include "TestOpsDialect.cpp.inc" 28 29 using namespace mlir; 30 using namespace test; 31 32 void test::registerTestDialect(DialectRegistry ®istry) { 33 registry.insert<TestDialect>(); 34 } 35 36 //===----------------------------------------------------------------------===// 37 // TestDialect Interfaces 38 //===----------------------------------------------------------------------===// 39 40 namespace { 41 42 /// Testing the correctness of some traits. 43 static_assert( 44 llvm::is_detected<OpTrait::has_implicit_terminator_t, 45 SingleBlockImplicitTerminatorOp>::value, 46 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); 47 static_assert(OpTrait::hasSingleBlockImplicitTerminator< 48 SingleBlockImplicitTerminatorOp>::value, 49 "hasSingleBlockImplicitTerminator does not match " 50 "SingleBlockImplicitTerminatorOp"); 51 52 // Test support for interacting with the AsmPrinter. 53 struct TestOpAsmInterface : public OpAsmDialectInterface { 54 using OpAsmDialectInterface::OpAsmDialectInterface; 55 56 AliasResult getAlias(Attribute attr, raw_ostream &os) const final { 57 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 58 if (!strAttr) 59 return AliasResult::NoAlias; 60 61 // Check the contents of the string attribute to see what the test alias 62 // should be named. 63 Optional<StringRef> aliasName = 64 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 65 .Case("alias_test:dot_in_name", StringRef("test.alias")) 66 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 67 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 68 .Case("alias_test:sanitize_conflict_a", 69 StringRef("test_alias_conflict0")) 70 .Case("alias_test:sanitize_conflict_b", 71 StringRef("test_alias_conflict0_")) 72 .Case("alias_test:tensor_encoding", StringRef("test_encoding")) 73 .Default(llvm::None); 74 if (!aliasName) 75 return AliasResult::NoAlias; 76 77 os << *aliasName; 78 return AliasResult::FinalAlias; 79 } 80 81 AliasResult getAlias(Type type, raw_ostream &os) const final { 82 if (auto tupleType = type.dyn_cast<TupleType>()) { 83 if (tupleType.size() > 0 && 84 llvm::all_of(tupleType.getTypes(), [](Type elemType) { 85 return elemType.isa<SimpleAType>(); 86 })) { 87 os << "test_tuple"; 88 return AliasResult::FinalAlias; 89 } 90 } 91 return AliasResult::NoAlias; 92 } 93 94 void getAsmResultNames(Operation *op, 95 OpAsmSetValueNameFn setNameFn) const final { 96 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 97 setNameFn(asmOp, "result"); 98 } 99 100 void getAsmBlockArgumentNames(Block *block, 101 OpAsmSetValueNameFn setNameFn) const final { 102 auto op = block->getParentOp(); 103 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 104 if (!arrayAttr) 105 return; 106 auto args = block->getArguments(); 107 auto e = std::min(arrayAttr.size(), args.size()); 108 for (unsigned i = 0; i < e; ++i) { 109 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 110 setNameFn(args[i], strAttr.getValue()); 111 } 112 } 113 }; 114 115 struct TestDialectFoldInterface : public DialectFoldInterface { 116 using DialectFoldInterface::DialectFoldInterface; 117 118 /// Registered hook to check if the given region, which is attached to an 119 /// operation that is *not* isolated from above, should be used when 120 /// materializing constants. 121 bool shouldMaterializeInto(Region *region) const final { 122 // If this is a one region operation, then insert into it. 123 return isa<OneRegionOp>(region->getParentOp()); 124 } 125 }; 126 127 /// This class defines the interface for handling inlining with standard 128 /// operations. 129 struct TestInlinerInterface : public DialectInlinerInterface { 130 using DialectInlinerInterface::DialectInlinerInterface; 131 132 //===--------------------------------------------------------------------===// 133 // Analysis Hooks 134 //===--------------------------------------------------------------------===// 135 136 bool isLegalToInline(Operation *call, Operation *callable, 137 bool wouldBeCloned) const final { 138 // Don't allow inlining calls that are marked `noinline`. 139 return !call->hasAttr("noinline"); 140 } 141 bool isLegalToInline(Region *, Region *, bool, 142 BlockAndValueMapping &) const final { 143 // Inlining into test dialect regions is legal. 144 return true; 145 } 146 bool isLegalToInline(Operation *, Region *, bool, 147 BlockAndValueMapping &) const final { 148 return true; 149 } 150 151 bool shouldAnalyzeRecursively(Operation *op) const final { 152 // Analyze recursively if this is not a functional region operation, it 153 // froms a separate functional scope. 154 return !isa<FunctionalRegionOp>(op); 155 } 156 157 //===--------------------------------------------------------------------===// 158 // Transformation Hooks 159 //===--------------------------------------------------------------------===// 160 161 /// Handle the given inlined terminator by replacing it with a new operation 162 /// as necessary. 163 void handleTerminator(Operation *op, 164 ArrayRef<Value> valuesToRepl) const final { 165 // Only handle "test.return" here. 166 auto returnOp = dyn_cast<TestReturnOp>(op); 167 if (!returnOp) 168 return; 169 170 // Replace the values directly with the return operands. 171 assert(returnOp.getNumOperands() == valuesToRepl.size()); 172 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 173 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 174 } 175 176 /// Attempt to materialize a conversion for a type mismatch between a call 177 /// from this dialect, and a callable region. This method should generate an 178 /// operation that takes 'input' as the only operand, and produces a single 179 /// result of 'resultType'. If a conversion can not be generated, nullptr 180 /// should be returned. 181 Operation *materializeCallConversion(OpBuilder &builder, Value input, 182 Type resultType, 183 Location conversionLoc) const final { 184 // Only allow conversion for i16/i32 types. 185 if (!(resultType.isSignlessInteger(16) || 186 resultType.isSignlessInteger(32)) || 187 !(input.getType().isSignlessInteger(16) || 188 input.getType().isSignlessInteger(32))) 189 return nullptr; 190 return builder.create<TestCastOp>(conversionLoc, resultType, input); 191 } 192 193 void processInlinedCallBlocks( 194 Operation *call, 195 iterator_range<Region::iterator> inlinedBlocks) const final { 196 if (!isa<ConversionCallOp>(call)) 197 return; 198 199 // Set attributed on all ops in the inlined blocks. 200 for (Block &block : inlinedBlocks) { 201 block.walk([&](Operation *op) { 202 op->setAttr("inlined_conversion", UnitAttr::get(call->getContext())); 203 }); 204 } 205 } 206 }; 207 208 struct TestReductionPatternInterface : public DialectReductionPatternInterface { 209 public: 210 TestReductionPatternInterface(Dialect *dialect) 211 : DialectReductionPatternInterface(dialect) {} 212 213 void populateReductionPatterns(RewritePatternSet &patterns) const final { 214 populateTestReductionPatterns(patterns); 215 } 216 }; 217 218 } // end anonymous namespace 219 220 //===----------------------------------------------------------------------===// 221 // TestDialect 222 //===----------------------------------------------------------------------===// 223 224 static void testSideEffectOpGetEffect( 225 Operation *op, 226 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects); 227 228 // This is the implementation of a dialect fallback for `TestEffectOpInterface`. 229 struct TestOpEffectInterfaceFallback 230 : public TestEffectOpInterface::FallbackModel< 231 TestOpEffectInterfaceFallback> { 232 static bool classof(Operation *op) { 233 bool isSupportedOp = 234 op->getName().getStringRef() == "test.unregistered_side_effect_op"; 235 assert(isSupportedOp && "Unexpected dispatch"); 236 return isSupportedOp; 237 } 238 239 void 240 getEffects(Operation *op, 241 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 242 &effects) const { 243 testSideEffectOpGetEffect(op, effects); 244 } 245 }; 246 247 void TestDialect::initialize() { 248 registerAttributes(); 249 registerTypes(); 250 addOperations< 251 #define GET_OP_LIST 252 #include "TestOps.cpp.inc" 253 >(); 254 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 255 TestInlinerInterface, TestReductionPatternInterface>(); 256 allowUnknownOperations(); 257 258 // Instantiate our fallback op interface that we'll use on specific 259 // unregistered op. 260 fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback; 261 } 262 TestDialect::~TestDialect() { 263 delete static_cast<TestOpEffectInterfaceFallback *>( 264 fallbackEffectOpInterfaces); 265 } 266 267 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 268 Type type, Location loc) { 269 return builder.create<TestOpConstant>(loc, type, value); 270 } 271 272 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, 273 OperationName opName) { 274 if (opName.getIdentifier() == "test.unregistered_side_effect_op" && 275 typeID == TypeID::get<TestEffectOpInterface>()) 276 return fallbackEffectOpInterfaces; 277 return nullptr; 278 } 279 280 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 281 NamedAttribute namedAttr) { 282 if (namedAttr.first == "test.invalid_attr") 283 return op->emitError() << "invalid to use 'test.invalid_attr'"; 284 return success(); 285 } 286 287 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 288 unsigned regionIndex, 289 unsigned argIndex, 290 NamedAttribute namedAttr) { 291 if (namedAttr.first == "test.invalid_attr") 292 return op->emitError() << "invalid to use 'test.invalid_attr'"; 293 return success(); 294 } 295 296 LogicalResult 297 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 298 unsigned resultIndex, 299 NamedAttribute namedAttr) { 300 if (namedAttr.first == "test.invalid_attr") 301 return op->emitError() << "invalid to use 'test.invalid_attr'"; 302 return success(); 303 } 304 305 Optional<Dialect::ParseOpHook> 306 TestDialect::getParseOperationHook(StringRef opName) const { 307 if (opName == "test.dialect_custom_printer") { 308 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 309 return parser.parseKeyword("custom_format"); 310 }}; 311 } 312 return None; 313 } 314 315 llvm::unique_function<void(Operation *, OpAsmPrinter &)> 316 TestDialect::getOperationPrinter(Operation *op) const { 317 StringRef opName = op->getName().getStringRef(); 318 if (opName == "test.dialect_custom_printer") { 319 return [](Operation *op, OpAsmPrinter &printer) { 320 printer.getStream() << " custom_format"; 321 }; 322 } 323 return {}; 324 } 325 326 //===----------------------------------------------------------------------===// 327 // TestBranchOp 328 //===----------------------------------------------------------------------===// 329 330 Optional<MutableOperandRange> 331 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 332 assert(index == 0 && "invalid successor index"); 333 return targetOperandsMutable(); 334 } 335 336 //===----------------------------------------------------------------------===// 337 // TestDialectCanonicalizerOp 338 //===----------------------------------------------------------------------===// 339 340 static LogicalResult 341 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, 342 PatternRewriter &rewriter) { 343 rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(), 344 rewriter.getI32IntegerAttr(42)); 345 return success(); 346 } 347 348 void TestDialect::getCanonicalizationPatterns( 349 RewritePatternSet &results) const { 350 results.add(&dialectCanonicalizationPattern); 351 } 352 353 //===----------------------------------------------------------------------===// 354 // TestFoldToCallOp 355 //===----------------------------------------------------------------------===// 356 357 namespace { 358 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 359 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 360 361 LogicalResult matchAndRewrite(FoldToCallOp op, 362 PatternRewriter &rewriter) const override { 363 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(), 364 ValueRange()); 365 return success(); 366 } 367 }; 368 } // end anonymous namespace 369 370 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 371 MLIRContext *context) { 372 results.add<FoldToCallOpPattern>(context); 373 } 374 375 //===----------------------------------------------------------------------===// 376 // Test Format* operations 377 //===----------------------------------------------------------------------===// 378 379 //===----------------------------------------------------------------------===// 380 // Parsing 381 382 static ParseResult parseCustomDirectiveOperands( 383 OpAsmParser &parser, OpAsmParser::OperandType &operand, 384 Optional<OpAsmParser::OperandType> &optOperand, 385 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 386 if (parser.parseOperand(operand)) 387 return failure(); 388 if (succeeded(parser.parseOptionalComma())) { 389 optOperand.emplace(); 390 if (parser.parseOperand(*optOperand)) 391 return failure(); 392 } 393 if (parser.parseArrow() || parser.parseLParen() || 394 parser.parseOperandList(varOperands) || parser.parseRParen()) 395 return failure(); 396 return success(); 397 } 398 static ParseResult 399 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 400 Type &optOperandType, 401 SmallVectorImpl<Type> &varOperandTypes) { 402 if (parser.parseColon()) 403 return failure(); 404 405 if (parser.parseType(operandType)) 406 return failure(); 407 if (succeeded(parser.parseOptionalComma())) { 408 if (parser.parseType(optOperandType)) 409 return failure(); 410 } 411 if (parser.parseArrow() || parser.parseLParen() || 412 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 413 return failure(); 414 return success(); 415 } 416 static ParseResult 417 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 418 Type optOperandType, 419 const SmallVectorImpl<Type> &varOperandTypes) { 420 if (parser.parseKeyword("type_refs_capture")) 421 return failure(); 422 423 Type operandType2, optOperandType2; 424 SmallVector<Type, 1> varOperandTypes2; 425 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 426 varOperandTypes2)) 427 return failure(); 428 429 if (operandType != operandType2 || optOperandType != optOperandType2 || 430 varOperandTypes != varOperandTypes2) 431 return failure(); 432 433 return success(); 434 } 435 static ParseResult parseCustomDirectiveOperandsAndTypes( 436 OpAsmParser &parser, OpAsmParser::OperandType &operand, 437 Optional<OpAsmParser::OperandType> &optOperand, 438 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 439 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 440 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 441 parseCustomDirectiveResults(parser, operandType, optOperandType, 442 varOperandTypes)) 443 return failure(); 444 return success(); 445 } 446 static ParseResult parseCustomDirectiveRegions( 447 OpAsmParser &parser, Region ®ion, 448 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 449 if (parser.parseRegion(region)) 450 return failure(); 451 if (failed(parser.parseOptionalComma())) 452 return success(); 453 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 454 if (parser.parseRegion(*varRegion)) 455 return failure(); 456 varRegions.emplace_back(std::move(varRegion)); 457 return success(); 458 } 459 static ParseResult 460 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 461 SmallVectorImpl<Block *> &varSuccessors) { 462 if (parser.parseSuccessor(successor)) 463 return failure(); 464 if (failed(parser.parseOptionalComma())) 465 return success(); 466 Block *varSuccessor; 467 if (parser.parseSuccessor(varSuccessor)) 468 return failure(); 469 varSuccessors.append(2, varSuccessor); 470 return success(); 471 } 472 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 473 IntegerAttr &attr, 474 IntegerAttr &optAttr) { 475 if (parser.parseAttribute(attr)) 476 return failure(); 477 if (succeeded(parser.parseOptionalComma())) { 478 if (parser.parseAttribute(optAttr)) 479 return failure(); 480 } 481 return success(); 482 } 483 484 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 485 NamedAttrList &attrs) { 486 return parser.parseOptionalAttrDict(attrs); 487 } 488 static ParseResult parseCustomDirectiveOptionalOperandRef( 489 OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) { 490 int64_t operandCount = 0; 491 if (parser.parseInteger(operandCount)) 492 return failure(); 493 bool expectedOptionalOperand = operandCount == 0; 494 return success(expectedOptionalOperand != optOperand.hasValue()); 495 } 496 497 //===----------------------------------------------------------------------===// 498 // Printing 499 500 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 501 Value operand, Value optOperand, 502 OperandRange varOperands) { 503 printer << operand; 504 if (optOperand) 505 printer << ", " << optOperand; 506 printer << " -> (" << varOperands << ")"; 507 } 508 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 509 Type operandType, Type optOperandType, 510 TypeRange varOperandTypes) { 511 printer << " : " << operandType; 512 if (optOperandType) 513 printer << ", " << optOperandType; 514 printer << " -> (" << varOperandTypes << ")"; 515 } 516 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 517 Operation *op, Type operandType, 518 Type optOperandType, 519 TypeRange varOperandTypes) { 520 printer << " type_refs_capture "; 521 printCustomDirectiveResults(printer, op, operandType, optOperandType, 522 varOperandTypes); 523 } 524 static void printCustomDirectiveOperandsAndTypes( 525 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 526 OperandRange varOperands, Type operandType, Type optOperandType, 527 TypeRange varOperandTypes) { 528 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 529 printCustomDirectiveResults(printer, op, operandType, optOperandType, 530 varOperandTypes); 531 } 532 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 533 Region ®ion, 534 MutableArrayRef<Region> varRegions) { 535 printer.printRegion(region); 536 if (!varRegions.empty()) { 537 printer << ", "; 538 for (Region ®ion : varRegions) 539 printer.printRegion(region); 540 } 541 } 542 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 543 Block *successor, 544 SuccessorRange varSuccessors) { 545 printer << successor; 546 if (!varSuccessors.empty()) 547 printer << ", " << varSuccessors.front(); 548 } 549 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 550 Attribute attribute, 551 Attribute optAttribute) { 552 printer << attribute; 553 if (optAttribute) 554 printer << ", " << optAttribute; 555 } 556 557 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 558 DictionaryAttr attrs) { 559 printer.printOptionalAttrDict(attrs.getValue()); 560 } 561 562 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 563 Operation *op, 564 Value optOperand) { 565 printer << (optOperand ? "1" : "0"); 566 } 567 568 //===----------------------------------------------------------------------===// 569 // Test IsolatedRegionOp - parse passthrough region arguments. 570 //===----------------------------------------------------------------------===// 571 572 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 573 OperationState &result) { 574 OpAsmParser::OperandType argInfo; 575 Type argType = parser.getBuilder().getIndexType(); 576 577 // Parse the input operand. 578 if (parser.parseOperand(argInfo) || 579 parser.resolveOperand(argInfo, argType, result.operands)) 580 return failure(); 581 582 // Parse the body region, and reuse the operand info as the argument info. 583 Region *body = result.addRegion(); 584 return parser.parseRegion(*body, argInfo, argType, 585 /*enableNameShadowing=*/true); 586 } 587 588 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 589 p << "test.isolated_region "; 590 p.printOperand(op.getOperand()); 591 p.shadowRegionArgs(op.region(), op.getOperand()); 592 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 593 } 594 595 //===----------------------------------------------------------------------===// 596 // Test SSACFGRegionOp 597 //===----------------------------------------------------------------------===// 598 599 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 600 return RegionKind::SSACFG; 601 } 602 603 //===----------------------------------------------------------------------===// 604 // Test GraphRegionOp 605 //===----------------------------------------------------------------------===// 606 607 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 608 OperationState &result) { 609 // Parse the body region, and reuse the operand info as the argument info. 610 Region *body = result.addRegion(); 611 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 612 } 613 614 static void print(OpAsmPrinter &p, GraphRegionOp op) { 615 p << "test.graph_region "; 616 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 617 } 618 619 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 620 return RegionKind::Graph; 621 } 622 623 //===----------------------------------------------------------------------===// 624 // Test AffineScopeOp 625 //===----------------------------------------------------------------------===// 626 627 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 628 OperationState &result) { 629 // Parse the body region, and reuse the operand info as the argument info. 630 Region *body = result.addRegion(); 631 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 632 } 633 634 static void print(OpAsmPrinter &p, AffineScopeOp op) { 635 p << "test.affine_scope "; 636 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 637 } 638 639 //===----------------------------------------------------------------------===// 640 // Test parser. 641 //===----------------------------------------------------------------------===// 642 643 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, 644 OperationState &result) { 645 if (parser.parseOptionalColon()) 646 return success(); 647 uint64_t numResults; 648 if (parser.parseInteger(numResults)) 649 return failure(); 650 651 IndexType type = parser.getBuilder().getIndexType(); 652 for (unsigned i = 0; i < numResults; ++i) 653 result.addTypes(type); 654 return success(); 655 } 656 657 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { 658 if (unsigned numResults = op->getNumResults()) 659 p << " : " << numResults; 660 } 661 662 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, 663 OperationState &result) { 664 StringRef keyword; 665 if (parser.parseKeyword(&keyword)) 666 return failure(); 667 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 668 return success(); 669 } 670 671 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { 672 p << " " << op.keyword(); 673 } 674 675 //===----------------------------------------------------------------------===// 676 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 677 678 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 679 OperationState &result) { 680 if (parser.parseKeyword("wraps")) 681 return failure(); 682 683 // Parse the wrapped op in a region 684 Region &body = *result.addRegion(); 685 body.push_back(new Block); 686 Block &block = body.back(); 687 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 688 if (!wrapped_op) 689 return failure(); 690 691 // Create a return terminator in the inner region, pass as operand to the 692 // terminator the returned values from the wrapped operation. 693 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 694 OpBuilder builder(parser.getContext()); 695 builder.setInsertionPointToEnd(&block); 696 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 697 698 // Get the results type for the wrapping op from the terminator operands. 699 Operation &return_op = body.back().back(); 700 result.types.append(return_op.operand_type_begin(), 701 return_op.operand_type_end()); 702 703 // Use the location of the wrapped op for the "test.wrapping_region" op. 704 result.location = wrapped_op->getLoc(); 705 706 return success(); 707 } 708 709 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 710 p << " wraps "; 711 p.printGenericOp(&op.region().front().front()); 712 } 713 714 //===----------------------------------------------------------------------===// 715 // Test PolyForOp - parse list of region arguments. 716 //===----------------------------------------------------------------------===// 717 718 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 719 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 720 // Parse list of region arguments without a delimiter. 721 if (parser.parseRegionArgumentList(ivsInfo)) 722 return failure(); 723 724 // Parse the body region. 725 Region *body = result.addRegion(); 726 auto &builder = parser.getBuilder(); 727 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 728 return parser.parseRegion(*body, ivsInfo, argTypes); 729 } 730 731 //===----------------------------------------------------------------------===// 732 // Test removing op with inner ops. 733 //===----------------------------------------------------------------------===// 734 735 namespace { 736 struct TestRemoveOpWithInnerOps 737 : public OpRewritePattern<TestOpWithRegionPattern> { 738 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 739 740 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } 741 742 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 743 PatternRewriter &rewriter) const override { 744 rewriter.eraseOp(op); 745 return success(); 746 } 747 }; 748 } // end anonymous namespace 749 750 void TestOpWithRegionPattern::getCanonicalizationPatterns( 751 RewritePatternSet &results, MLIRContext *context) { 752 results.add<TestRemoveOpWithInnerOps>(context); 753 } 754 755 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 756 return operand(); 757 } 758 759 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 760 return getValue(); 761 } 762 763 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 764 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 765 for (Value input : this->operands()) { 766 results.push_back(input); 767 } 768 return success(); 769 } 770 771 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 772 assert(operands.size() == 1); 773 if (operands.front()) { 774 (*this)->setAttr("attr", operands.front()); 775 return getResult(); 776 } 777 return {}; 778 } 779 780 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 781 return getOperand(); 782 } 783 784 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 785 MLIRContext *, Optional<Location> location, ValueRange operands, 786 DictionaryAttr attributes, RegionRange regions, 787 SmallVectorImpl<Type> &inferredReturnTypes) { 788 if (operands[0].getType() != operands[1].getType()) { 789 return emitOptionalError(location, "operand type mismatch ", 790 operands[0].getType(), " vs ", 791 operands[1].getType()); 792 } 793 inferredReturnTypes.assign({operands[0].getType()}); 794 return success(); 795 } 796 797 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 798 MLIRContext *context, Optional<Location> location, ValueShapeRange operands, 799 DictionaryAttr attributes, RegionRange regions, 800 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 801 // Create return type consisting of the last element of the first operand. 802 auto operandType = operands.front().getType(); 803 auto sval = operandType.dyn_cast<ShapedType>(); 804 if (!sval) { 805 return emitOptionalError(location, "only shaped type operands allowed"); 806 } 807 int64_t dim = 808 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 809 auto type = IntegerType::get(context, 17); 810 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 811 return success(); 812 } 813 814 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 815 OpBuilder &builder, ValueRange operands, 816 llvm::SmallVectorImpl<Value> &shapes) { 817 shapes = SmallVector<Value, 1>{ 818 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; 819 return success(); 820 } 821 822 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( 823 OpBuilder &builder, ValueRange operands, 824 llvm::SmallVectorImpl<Value> &shapes) { 825 Location loc = getLoc(); 826 shapes.reserve(operands.size()); 827 for (Value operand : llvm::reverse(operands)) { 828 auto currShape = llvm::to_vector<4>(llvm::map_range( 829 llvm::seq<int64_t>( 830 0, operand.getType().cast<RankedTensorType>().getRank()), 831 [&](int64_t dim) -> Value { 832 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 833 })); 834 shapes.push_back(builder.create<tensor::FromElementsOp>( 835 getLoc(), builder.getIndexType(), currShape)); 836 } 837 return success(); 838 } 839 840 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( 841 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { 842 Location loc = getLoc(); 843 shapes.reserve(getNumOperands()); 844 for (Value operand : llvm::reverse(getOperands())) { 845 auto currShape = llvm::to_vector<4>(llvm::map_range( 846 llvm::seq<int64_t>( 847 0, operand.getType().cast<RankedTensorType>().getRank()), 848 [&](int64_t dim) -> Value { 849 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 850 })); 851 shapes.emplace_back(std::move(currShape)); 852 } 853 return success(); 854 } 855 856 //===----------------------------------------------------------------------===// 857 // Test SideEffect interfaces 858 //===----------------------------------------------------------------------===// 859 860 namespace { 861 /// A test resource for side effects. 862 struct TestResource : public SideEffects::Resource::Base<TestResource> { 863 StringRef getName() final { return "<Test>"; } 864 }; 865 } // end anonymous namespace 866 867 static void testSideEffectOpGetEffect( 868 Operation *op, 869 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 870 &effects) { 871 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); 872 if (!effectsAttr) 873 return; 874 875 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 876 } 877 878 void SideEffectOp::getEffects( 879 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 880 // Check for an effects attribute on the op instance. 881 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 882 if (!effectsAttr) 883 return; 884 885 // If there is one, it is an array of dictionary attributes that hold 886 // information on the effects of this operation. 887 for (Attribute element : effectsAttr) { 888 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 889 890 // Get the specific memory effect. 891 MemoryEffects::Effect *effect = 892 StringSwitch<MemoryEffects::Effect *>( 893 effectElement.get("effect").cast<StringAttr>().getValue()) 894 .Case("allocate", MemoryEffects::Allocate::get()) 895 .Case("free", MemoryEffects::Free::get()) 896 .Case("read", MemoryEffects::Read::get()) 897 .Case("write", MemoryEffects::Write::get()); 898 899 // Check for a non-default resource to use. 900 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 901 if (effectElement.get("test_resource")) 902 resource = TestResource::get(); 903 904 // Check for a result to affect. 905 if (effectElement.get("on_result")) 906 effects.emplace_back(effect, getResult(), resource); 907 else if (Attribute ref = effectElement.get("on_reference")) 908 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 909 else 910 effects.emplace_back(effect, resource); 911 } 912 } 913 914 void SideEffectOp::getEffects( 915 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 916 testSideEffectOpGetEffect(getOperation(), effects); 917 } 918 919 //===----------------------------------------------------------------------===// 920 // StringAttrPrettyNameOp 921 //===----------------------------------------------------------------------===// 922 923 // This op has fancy handling of its SSA result name. 924 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 925 OperationState &result) { 926 // Add the result types. 927 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 928 result.addTypes(parser.getBuilder().getIntegerType(32)); 929 930 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 931 return failure(); 932 933 // If the attribute dictionary contains no 'names' attribute, infer it from 934 // the SSA name (if specified). 935 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 936 return attr.first == "names"; 937 }); 938 939 // If there was no name specified, check to see if there was a useful name 940 // specified in the asm file. 941 if (hadNames || parser.getNumResults() == 0) 942 return success(); 943 944 SmallVector<StringRef, 4> names; 945 auto *context = result.getContext(); 946 947 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 948 auto resultName = parser.getResultName(i); 949 StringRef nameStr; 950 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 951 nameStr = resultName.first; 952 953 names.push_back(nameStr); 954 } 955 956 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 957 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 958 return success(); 959 } 960 961 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 962 // Note that we only need to print the "name" attribute if the asmprinter 963 // result name disagrees with it. This can happen in strange cases, e.g. 964 // when there are conflicts. 965 bool namesDisagree = op.names().size() != op.getNumResults(); 966 967 SmallString<32> resultNameStr; 968 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 969 resultNameStr.clear(); 970 llvm::raw_svector_ostream tmpStream(resultNameStr); 971 p.printOperand(op.getResult(i), tmpStream); 972 973 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 974 if (!expectedName || 975 tmpStream.str().drop_front() != expectedName.getValue()) { 976 namesDisagree = true; 977 } 978 } 979 980 if (namesDisagree) 981 p.printOptionalAttrDictWithKeyword(op->getAttrs()); 982 else 983 p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"}); 984 } 985 986 // We set the SSA name in the asm syntax to the contents of the name 987 // attribute. 988 void StringAttrPrettyNameOp::getAsmResultNames( 989 function_ref<void(Value, StringRef)> setNameFn) { 990 991 auto value = names(); 992 for (size_t i = 0, e = value.size(); i != e; ++i) 993 if (auto str = value[i].dyn_cast<StringAttr>()) 994 if (!str.getValue().empty()) 995 setNameFn(getResult(i), str.getValue()); 996 } 997 998 //===----------------------------------------------------------------------===// 999 // RegionIfOp 1000 //===----------------------------------------------------------------------===// 1001 1002 static void print(OpAsmPrinter &p, RegionIfOp op) { 1003 p << " "; 1004 p.printOperands(op.getOperands()); 1005 p << ": " << op.getOperandTypes(); 1006 p.printArrowTypeList(op.getResultTypes()); 1007 p << " then"; 1008 p.printRegion(op.thenRegion(), 1009 /*printEntryBlockArgs=*/true, 1010 /*printBlockTerminators=*/true); 1011 p << " else"; 1012 p.printRegion(op.elseRegion(), 1013 /*printEntryBlockArgs=*/true, 1014 /*printBlockTerminators=*/true); 1015 p << " join"; 1016 p.printRegion(op.joinRegion(), 1017 /*printEntryBlockArgs=*/true, 1018 /*printBlockTerminators=*/true); 1019 } 1020 1021 static ParseResult parseRegionIfOp(OpAsmParser &parser, 1022 OperationState &result) { 1023 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 1024 SmallVector<Type, 2> operandTypes; 1025 1026 result.regions.reserve(3); 1027 Region *thenRegion = result.addRegion(); 1028 Region *elseRegion = result.addRegion(); 1029 Region *joinRegion = result.addRegion(); 1030 1031 // Parse operand, type and arrow type lists. 1032 if (parser.parseOperandList(operandInfos) || 1033 parser.parseColonTypeList(operandTypes) || 1034 parser.parseArrowTypeList(result.types)) 1035 return failure(); 1036 1037 // Parse all attached regions. 1038 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 1039 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 1040 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 1041 return failure(); 1042 1043 return parser.resolveOperands(operandInfos, operandTypes, 1044 parser.getCurrentLocation(), result.operands); 1045 } 1046 1047 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 1048 assert(index < 2 && "invalid region index"); 1049 return getOperands(); 1050 } 1051 1052 void RegionIfOp::getSuccessorRegions( 1053 Optional<unsigned> index, ArrayRef<Attribute> operands, 1054 SmallVectorImpl<RegionSuccessor> ®ions) { 1055 // We always branch to the join region. 1056 if (index.hasValue()) { 1057 if (index.getValue() < 2) 1058 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 1059 else 1060 regions.push_back(RegionSuccessor(getResults())); 1061 return; 1062 } 1063 1064 // The then and else regions are the entry regions of this op. 1065 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 1066 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 1067 } 1068 1069 //===----------------------------------------------------------------------===// 1070 // SingleNoTerminatorCustomAsmOp 1071 //===----------------------------------------------------------------------===// 1072 1073 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser, 1074 OperationState &state) { 1075 Region *body = state.addRegion(); 1076 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) 1077 return failure(); 1078 return success(); 1079 } 1080 1081 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) { 1082 printer.printRegion( 1083 op.getRegion(), /*printEntryBlockArgs=*/false, 1084 // This op has a single block without terminators. But explicitly mark 1085 // as not printing block terminators for testing. 1086 /*printBlockTerminators=*/false); 1087 } 1088 1089 #include "TestOpEnums.cpp.inc" 1090 #include "TestOpInterfaces.cpp.inc" 1091 #include "TestOpStructs.cpp.inc" 1092 #include "TestTypeInterfaces.cpp.inc" 1093 1094 #define GET_OP_CLASSES 1095 #include "TestOps.cpp.inc" 1096