1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestInterfaces.h" 12 #include "TestTypes.h" 13 #include "mlir/Dialect/DLTI/DLTI.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/DialectImplementation.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Reducer/ReductionPatternInterface.h" 21 #include "mlir/Transforms/FoldUtils.h" 22 #include "mlir/Transforms/InliningUtils.h" 23 #include "llvm/ADT/StringSwitch.h" 24 25 using namespace mlir; 26 using namespace mlir::test; 27 28 #include "TestOpsDialect.cpp.inc" 29 30 void mlir::test::registerTestDialect(DialectRegistry ®istry) { 31 registry.insert<TestDialect>(); 32 } 33 34 //===----------------------------------------------------------------------===// 35 // TestDialect Interfaces 36 //===----------------------------------------------------------------------===// 37 38 namespace { 39 40 /// Testing the correctness of some traits. 41 static_assert( 42 llvm::is_detected<OpTrait::has_implicit_terminator_t, 43 SingleBlockImplicitTerminatorOp>::value, 44 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); 45 static_assert(OpTrait::hasSingleBlockImplicitTerminator< 46 SingleBlockImplicitTerminatorOp>::value, 47 "hasSingleBlockImplicitTerminator does not match " 48 "SingleBlockImplicitTerminatorOp"); 49 50 // Test support for interacting with the AsmPrinter. 51 struct TestOpAsmInterface : public OpAsmDialectInterface { 52 using OpAsmDialectInterface::OpAsmDialectInterface; 53 54 AliasResult getAlias(Attribute attr, raw_ostream &os) const final { 55 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 56 if (!strAttr) 57 return AliasResult::NoAlias; 58 59 // Check the contents of the string attribute to see what the test alias 60 // should be named. 61 Optional<StringRef> aliasName = 62 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 63 .Case("alias_test:dot_in_name", StringRef("test.alias")) 64 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 65 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 66 .Case("alias_test:sanitize_conflict_a", 67 StringRef("test_alias_conflict0")) 68 .Case("alias_test:sanitize_conflict_b", 69 StringRef("test_alias_conflict0_")) 70 .Case("alias_test:tensor_encoding", StringRef("test_encoding")) 71 .Default(llvm::None); 72 if (!aliasName) 73 return AliasResult::NoAlias; 74 75 os << *aliasName; 76 return AliasResult::FinalAlias; 77 } 78 79 AliasResult getAlias(Type type, raw_ostream &os) const final { 80 if (auto tupleType = type.dyn_cast<TupleType>()) { 81 if (tupleType.size() > 0 && 82 llvm::all_of(tupleType.getTypes(), [](Type elemType) { 83 return elemType.isa<SimpleAType>(); 84 })) { 85 os << "test_tuple"; 86 return AliasResult::FinalAlias; 87 } 88 } 89 return AliasResult::NoAlias; 90 } 91 92 void getAsmResultNames(Operation *op, 93 OpAsmSetValueNameFn setNameFn) const final { 94 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 95 setNameFn(asmOp, "result"); 96 } 97 98 void getAsmBlockArgumentNames(Block *block, 99 OpAsmSetValueNameFn setNameFn) const final { 100 auto op = block->getParentOp(); 101 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 102 if (!arrayAttr) 103 return; 104 auto args = block->getArguments(); 105 auto e = std::min(arrayAttr.size(), args.size()); 106 for (unsigned i = 0; i < e; ++i) { 107 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 108 setNameFn(args[i], strAttr.getValue()); 109 } 110 } 111 }; 112 113 struct TestDialectFoldInterface : public DialectFoldInterface { 114 using DialectFoldInterface::DialectFoldInterface; 115 116 /// Registered hook to check if the given region, which is attached to an 117 /// operation that is *not* isolated from above, should be used when 118 /// materializing constants. 119 bool shouldMaterializeInto(Region *region) const final { 120 // If this is a one region operation, then insert into it. 121 return isa<OneRegionOp>(region->getParentOp()); 122 } 123 }; 124 125 /// This class defines the interface for handling inlining with standard 126 /// operations. 127 struct TestInlinerInterface : public DialectInlinerInterface { 128 using DialectInlinerInterface::DialectInlinerInterface; 129 130 //===--------------------------------------------------------------------===// 131 // Analysis Hooks 132 //===--------------------------------------------------------------------===// 133 134 bool isLegalToInline(Operation *call, Operation *callable, 135 bool wouldBeCloned) const final { 136 // Don't allow inlining calls that are marked `noinline`. 137 return !call->hasAttr("noinline"); 138 } 139 bool isLegalToInline(Region *, Region *, bool, 140 BlockAndValueMapping &) const final { 141 // Inlining into test dialect regions is legal. 142 return true; 143 } 144 bool isLegalToInline(Operation *, Region *, bool, 145 BlockAndValueMapping &) const final { 146 return true; 147 } 148 149 bool shouldAnalyzeRecursively(Operation *op) const final { 150 // Analyze recursively if this is not a functional region operation, it 151 // froms a separate functional scope. 152 return !isa<FunctionalRegionOp>(op); 153 } 154 155 //===--------------------------------------------------------------------===// 156 // Transformation Hooks 157 //===--------------------------------------------------------------------===// 158 159 /// Handle the given inlined terminator by replacing it with a new operation 160 /// as necessary. 161 void handleTerminator(Operation *op, 162 ArrayRef<Value> valuesToRepl) const final { 163 // Only handle "test.return" here. 164 auto returnOp = dyn_cast<TestReturnOp>(op); 165 if (!returnOp) 166 return; 167 168 // Replace the values directly with the return operands. 169 assert(returnOp.getNumOperands() == valuesToRepl.size()); 170 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 171 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 172 } 173 174 /// Attempt to materialize a conversion for a type mismatch between a call 175 /// from this dialect, and a callable region. This method should generate an 176 /// operation that takes 'input' as the only operand, and produces a single 177 /// result of 'resultType'. If a conversion can not be generated, nullptr 178 /// should be returned. 179 Operation *materializeCallConversion(OpBuilder &builder, Value input, 180 Type resultType, 181 Location conversionLoc) const final { 182 // Only allow conversion for i16/i32 types. 183 if (!(resultType.isSignlessInteger(16) || 184 resultType.isSignlessInteger(32)) || 185 !(input.getType().isSignlessInteger(16) || 186 input.getType().isSignlessInteger(32))) 187 return nullptr; 188 return builder.create<TestCastOp>(conversionLoc, resultType, input); 189 } 190 191 void processInlinedCallBlocks( 192 Operation *call, 193 iterator_range<Region::iterator> inlinedBlocks) const final { 194 if (!isa<ConversionCallOp>(call)) 195 return; 196 197 // Set attributed on all ops in the inlined blocks. 198 for (Block &block : inlinedBlocks) { 199 block.walk([&](Operation *op) { 200 op->setAttr("inlined_conversion", UnitAttr::get(call->getContext())); 201 }); 202 } 203 } 204 }; 205 206 struct TestReductionPatternInterface : public DialectReductionPatternInterface { 207 public: 208 TestReductionPatternInterface(Dialect *dialect) 209 : DialectReductionPatternInterface(dialect) {} 210 211 virtual void 212 populateReductionPatterns(RewritePatternSet &patterns) const final { 213 populateTestReductionPatterns(patterns); 214 } 215 }; 216 217 } // end anonymous namespace 218 219 //===----------------------------------------------------------------------===// 220 // TestDialect 221 //===----------------------------------------------------------------------===// 222 223 static void testSideEffectOpGetEffect( 224 Operation *op, 225 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects); 226 227 // This is the implementation of a dialect fallback for `TestEffectOpInterface`. 228 struct TestOpEffectInterfaceFallback 229 : public TestEffectOpInterface::FallbackModel< 230 TestOpEffectInterfaceFallback> { 231 static bool classof(Operation *op) { 232 bool isSupportedOp = 233 op->getName().getStringRef() == "test.unregistered_side_effect_op"; 234 assert(isSupportedOp && "Unexpected dispatch"); 235 return isSupportedOp; 236 } 237 238 void 239 getEffects(Operation *op, 240 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 241 &effects) const { 242 testSideEffectOpGetEffect(op, effects); 243 } 244 }; 245 246 void TestDialect::initialize() { 247 registerAttributes(); 248 registerTypes(); 249 addOperations< 250 #define GET_OP_LIST 251 #include "TestOps.cpp.inc" 252 >(); 253 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 254 TestInlinerInterface, TestReductionPatternInterface>(); 255 allowUnknownOperations(); 256 257 // Instantiate our fallback op interface that we'll use on specific 258 // unregistered op. 259 fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback; 260 } 261 TestDialect::~TestDialect() { 262 delete static_cast<TestOpEffectInterfaceFallback *>( 263 fallbackEffectOpInterfaces); 264 } 265 266 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 267 Type type, Location loc) { 268 return builder.create<TestOpConstant>(loc, type, value); 269 } 270 271 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, 272 OperationName opName) { 273 if (opName.getIdentifier() == "test.unregistered_side_effect_op" && 274 typeID == TypeID::get<TestEffectOpInterface>()) 275 return fallbackEffectOpInterfaces; 276 return nullptr; 277 } 278 279 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 280 NamedAttribute namedAttr) { 281 if (namedAttr.first == "test.invalid_attr") 282 return op->emitError() << "invalid to use 'test.invalid_attr'"; 283 return success(); 284 } 285 286 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 287 unsigned regionIndex, 288 unsigned argIndex, 289 NamedAttribute namedAttr) { 290 if (namedAttr.first == "test.invalid_attr") 291 return op->emitError() << "invalid to use 'test.invalid_attr'"; 292 return success(); 293 } 294 295 LogicalResult 296 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 297 unsigned resultIndex, 298 NamedAttribute namedAttr) { 299 if (namedAttr.first == "test.invalid_attr") 300 return op->emitError() << "invalid to use 'test.invalid_attr'"; 301 return success(); 302 } 303 304 Optional<Dialect::ParseOpHook> 305 TestDialect::getParseOperationHook(StringRef opName) const { 306 if (opName == "test.dialect_custom_printer") { 307 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 308 return parser.parseKeyword("custom_format"); 309 }}; 310 } 311 return None; 312 } 313 314 LogicalResult TestDialect::printOperation(Operation *op, 315 OpAsmPrinter &printer) const { 316 StringRef opName = op->getName().getStringRef(); 317 if (opName == "test.dialect_custom_printer") { 318 printer.getStream() << opName << " custom_format"; 319 return success(); 320 } 321 return failure(); 322 } 323 324 //===----------------------------------------------------------------------===// 325 // TestBranchOp 326 //===----------------------------------------------------------------------===// 327 328 Optional<MutableOperandRange> 329 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 330 assert(index == 0 && "invalid successor index"); 331 return targetOperandsMutable(); 332 } 333 334 //===----------------------------------------------------------------------===// 335 // TestDialectCanonicalizerOp 336 //===----------------------------------------------------------------------===// 337 338 static LogicalResult 339 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, 340 PatternRewriter &rewriter) { 341 rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(), 342 rewriter.getI32IntegerAttr(42)); 343 return success(); 344 } 345 346 void TestDialect::getCanonicalizationPatterns( 347 RewritePatternSet &results) const { 348 results.add(&dialectCanonicalizationPattern); 349 } 350 351 //===----------------------------------------------------------------------===// 352 // TestFoldToCallOp 353 //===----------------------------------------------------------------------===// 354 355 namespace { 356 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 357 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 358 359 LogicalResult matchAndRewrite(FoldToCallOp op, 360 PatternRewriter &rewriter) const override { 361 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(), 362 ValueRange()); 363 return success(); 364 } 365 }; 366 } // end anonymous namespace 367 368 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 369 MLIRContext *context) { 370 results.add<FoldToCallOpPattern>(context); 371 } 372 373 //===----------------------------------------------------------------------===// 374 // Test Format* operations 375 //===----------------------------------------------------------------------===// 376 377 //===----------------------------------------------------------------------===// 378 // Parsing 379 380 static ParseResult parseCustomDirectiveOperands( 381 OpAsmParser &parser, OpAsmParser::OperandType &operand, 382 Optional<OpAsmParser::OperandType> &optOperand, 383 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 384 if (parser.parseOperand(operand)) 385 return failure(); 386 if (succeeded(parser.parseOptionalComma())) { 387 optOperand.emplace(); 388 if (parser.parseOperand(*optOperand)) 389 return failure(); 390 } 391 if (parser.parseArrow() || parser.parseLParen() || 392 parser.parseOperandList(varOperands) || parser.parseRParen()) 393 return failure(); 394 return success(); 395 } 396 static ParseResult 397 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 398 Type &optOperandType, 399 SmallVectorImpl<Type> &varOperandTypes) { 400 if (parser.parseColon()) 401 return failure(); 402 403 if (parser.parseType(operandType)) 404 return failure(); 405 if (succeeded(parser.parseOptionalComma())) { 406 if (parser.parseType(optOperandType)) 407 return failure(); 408 } 409 if (parser.parseArrow() || parser.parseLParen() || 410 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 411 return failure(); 412 return success(); 413 } 414 static ParseResult 415 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 416 Type optOperandType, 417 const SmallVectorImpl<Type> &varOperandTypes) { 418 if (parser.parseKeyword("type_refs_capture")) 419 return failure(); 420 421 Type operandType2, optOperandType2; 422 SmallVector<Type, 1> varOperandTypes2; 423 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 424 varOperandTypes2)) 425 return failure(); 426 427 if (operandType != operandType2 || optOperandType != optOperandType2 || 428 varOperandTypes != varOperandTypes2) 429 return failure(); 430 431 return success(); 432 } 433 static ParseResult parseCustomDirectiveOperandsAndTypes( 434 OpAsmParser &parser, OpAsmParser::OperandType &operand, 435 Optional<OpAsmParser::OperandType> &optOperand, 436 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 437 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 438 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 439 parseCustomDirectiveResults(parser, operandType, optOperandType, 440 varOperandTypes)) 441 return failure(); 442 return success(); 443 } 444 static ParseResult parseCustomDirectiveRegions( 445 OpAsmParser &parser, Region ®ion, 446 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 447 if (parser.parseRegion(region)) 448 return failure(); 449 if (failed(parser.parseOptionalComma())) 450 return success(); 451 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 452 if (parser.parseRegion(*varRegion)) 453 return failure(); 454 varRegions.emplace_back(std::move(varRegion)); 455 return success(); 456 } 457 static ParseResult 458 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 459 SmallVectorImpl<Block *> &varSuccessors) { 460 if (parser.parseSuccessor(successor)) 461 return failure(); 462 if (failed(parser.parseOptionalComma())) 463 return success(); 464 Block *varSuccessor; 465 if (parser.parseSuccessor(varSuccessor)) 466 return failure(); 467 varSuccessors.append(2, varSuccessor); 468 return success(); 469 } 470 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 471 IntegerAttr &attr, 472 IntegerAttr &optAttr) { 473 if (parser.parseAttribute(attr)) 474 return failure(); 475 if (succeeded(parser.parseOptionalComma())) { 476 if (parser.parseAttribute(optAttr)) 477 return failure(); 478 } 479 return success(); 480 } 481 482 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 483 NamedAttrList &attrs) { 484 return parser.parseOptionalAttrDict(attrs); 485 } 486 static ParseResult parseCustomDirectiveOptionalOperandRef( 487 OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) { 488 int64_t operandCount = 0; 489 if (parser.parseInteger(operandCount)) 490 return failure(); 491 bool expectedOptionalOperand = operandCount == 0; 492 return success(expectedOptionalOperand != optOperand.hasValue()); 493 } 494 495 //===----------------------------------------------------------------------===// 496 // Printing 497 498 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 499 Value operand, Value optOperand, 500 OperandRange varOperands) { 501 printer << operand; 502 if (optOperand) 503 printer << ", " << optOperand; 504 printer << " -> (" << varOperands << ")"; 505 } 506 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 507 Type operandType, Type optOperandType, 508 TypeRange varOperandTypes) { 509 printer << " : " << operandType; 510 if (optOperandType) 511 printer << ", " << optOperandType; 512 printer << " -> (" << varOperandTypes << ")"; 513 } 514 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 515 Operation *op, Type operandType, 516 Type optOperandType, 517 TypeRange varOperandTypes) { 518 printer << " type_refs_capture "; 519 printCustomDirectiveResults(printer, op, operandType, optOperandType, 520 varOperandTypes); 521 } 522 static void printCustomDirectiveOperandsAndTypes( 523 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 524 OperandRange varOperands, Type operandType, Type optOperandType, 525 TypeRange varOperandTypes) { 526 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 527 printCustomDirectiveResults(printer, op, operandType, optOperandType, 528 varOperandTypes); 529 } 530 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 531 Region ®ion, 532 MutableArrayRef<Region> varRegions) { 533 printer.printRegion(region); 534 if (!varRegions.empty()) { 535 printer << ", "; 536 for (Region ®ion : varRegions) 537 printer.printRegion(region); 538 } 539 } 540 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 541 Block *successor, 542 SuccessorRange varSuccessors) { 543 printer << successor; 544 if (!varSuccessors.empty()) 545 printer << ", " << varSuccessors.front(); 546 } 547 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 548 Attribute attribute, 549 Attribute optAttribute) { 550 printer << attribute; 551 if (optAttribute) 552 printer << ", " << optAttribute; 553 } 554 555 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 556 DictionaryAttr attrs) { 557 printer.printOptionalAttrDict(attrs.getValue()); 558 } 559 560 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 561 Operation *op, 562 Value optOperand) { 563 printer << (optOperand ? "1" : "0"); 564 } 565 566 //===----------------------------------------------------------------------===// 567 // Test IsolatedRegionOp - parse passthrough region arguments. 568 //===----------------------------------------------------------------------===// 569 570 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 571 OperationState &result) { 572 OpAsmParser::OperandType argInfo; 573 Type argType = parser.getBuilder().getIndexType(); 574 575 // Parse the input operand. 576 if (parser.parseOperand(argInfo) || 577 parser.resolveOperand(argInfo, argType, result.operands)) 578 return failure(); 579 580 // Parse the body region, and reuse the operand info as the argument info. 581 Region *body = result.addRegion(); 582 return parser.parseRegion(*body, argInfo, argType, 583 /*enableNameShadowing=*/true); 584 } 585 586 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 587 p << "test.isolated_region "; 588 p.printOperand(op.getOperand()); 589 p.shadowRegionArgs(op.region(), op.getOperand()); 590 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 591 } 592 593 //===----------------------------------------------------------------------===// 594 // Test SSACFGRegionOp 595 //===----------------------------------------------------------------------===// 596 597 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 598 return RegionKind::SSACFG; 599 } 600 601 //===----------------------------------------------------------------------===// 602 // Test GraphRegionOp 603 //===----------------------------------------------------------------------===// 604 605 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 606 OperationState &result) { 607 // Parse the body region, and reuse the operand info as the argument info. 608 Region *body = result.addRegion(); 609 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 610 } 611 612 static void print(OpAsmPrinter &p, GraphRegionOp op) { 613 p << "test.graph_region "; 614 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 615 } 616 617 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 618 return RegionKind::Graph; 619 } 620 621 //===----------------------------------------------------------------------===// 622 // Test AffineScopeOp 623 //===----------------------------------------------------------------------===// 624 625 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 626 OperationState &result) { 627 // Parse the body region, and reuse the operand info as the argument info. 628 Region *body = result.addRegion(); 629 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 630 } 631 632 static void print(OpAsmPrinter &p, AffineScopeOp op) { 633 p << "test.affine_scope "; 634 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 635 } 636 637 //===----------------------------------------------------------------------===// 638 // Test parser. 639 //===----------------------------------------------------------------------===// 640 641 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, 642 OperationState &result) { 643 if (parser.parseOptionalColon()) 644 return success(); 645 uint64_t numResults; 646 if (parser.parseInteger(numResults)) 647 return failure(); 648 649 IndexType type = parser.getBuilder().getIndexType(); 650 for (unsigned i = 0; i < numResults; ++i) 651 result.addTypes(type); 652 return success(); 653 } 654 655 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { 656 p << ParseIntegerLiteralOp::getOperationName(); 657 if (unsigned numResults = op->getNumResults()) 658 p << " : " << numResults; 659 } 660 661 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, 662 OperationState &result) { 663 StringRef keyword; 664 if (parser.parseKeyword(&keyword)) 665 return failure(); 666 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 667 return success(); 668 } 669 670 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { 671 p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword(); 672 } 673 674 //===----------------------------------------------------------------------===// 675 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 676 677 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 678 OperationState &result) { 679 if (parser.parseKeyword("wraps")) 680 return failure(); 681 682 // Parse the wrapped op in a region 683 Region &body = *result.addRegion(); 684 body.push_back(new Block); 685 Block &block = body.back(); 686 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 687 if (!wrapped_op) 688 return failure(); 689 690 // Create a return terminator in the inner region, pass as operand to the 691 // terminator the returned values from the wrapped operation. 692 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 693 OpBuilder builder(parser.getBuilder().getContext()); 694 builder.setInsertionPointToEnd(&block); 695 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 696 697 // Get the results type for the wrapping op from the terminator operands. 698 Operation &return_op = body.back().back(); 699 result.types.append(return_op.operand_type_begin(), 700 return_op.operand_type_end()); 701 702 // Use the location of the wrapped op for the "test.wrapping_region" op. 703 result.location = wrapped_op->getLoc(); 704 705 return success(); 706 } 707 708 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 709 p << op.getOperationName() << " wraps "; 710 p.printGenericOp(&op.region().front().front()); 711 } 712 713 //===----------------------------------------------------------------------===// 714 // Test PolyForOp - parse list of region arguments. 715 //===----------------------------------------------------------------------===// 716 717 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 718 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 719 // Parse list of region arguments without a delimiter. 720 if (parser.parseRegionArgumentList(ivsInfo)) 721 return failure(); 722 723 // Parse the body region. 724 Region *body = result.addRegion(); 725 auto &builder = parser.getBuilder(); 726 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 727 return parser.parseRegion(*body, ivsInfo, argTypes); 728 } 729 730 //===----------------------------------------------------------------------===// 731 // Test removing op with inner ops. 732 //===----------------------------------------------------------------------===// 733 734 namespace { 735 struct TestRemoveOpWithInnerOps 736 : public OpRewritePattern<TestOpWithRegionPattern> { 737 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 738 739 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } 740 741 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 742 PatternRewriter &rewriter) const override { 743 rewriter.eraseOp(op); 744 return success(); 745 } 746 }; 747 } // end anonymous namespace 748 749 void TestOpWithRegionPattern::getCanonicalizationPatterns( 750 RewritePatternSet &results, MLIRContext *context) { 751 results.add<TestRemoveOpWithInnerOps>(context); 752 } 753 754 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 755 return operand(); 756 } 757 758 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 759 return getValue(); 760 } 761 762 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 763 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 764 for (Value input : this->operands()) { 765 results.push_back(input); 766 } 767 return success(); 768 } 769 770 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 771 assert(operands.size() == 1); 772 if (operands.front()) { 773 (*this)->setAttr("attr", operands.front()); 774 return getResult(); 775 } 776 return {}; 777 } 778 779 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 780 return getOperand(); 781 } 782 783 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 784 MLIRContext *, Optional<Location> location, ValueRange operands, 785 DictionaryAttr attributes, RegionRange regions, 786 SmallVectorImpl<Type> &inferredReturnTypes) { 787 if (operands[0].getType() != operands[1].getType()) { 788 return emitOptionalError(location, "operand type mismatch ", 789 operands[0].getType(), " vs ", 790 operands[1].getType()); 791 } 792 inferredReturnTypes.assign({operands[0].getType()}); 793 return success(); 794 } 795 796 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 797 MLIRContext *context, Optional<Location> location, ValueShapeRange operands, 798 DictionaryAttr attributes, RegionRange regions, 799 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 800 // Create return type consisting of the last element of the first operand. 801 auto operandType = operands.front().getType(); 802 auto sval = operandType.dyn_cast<ShapedType>(); 803 if (!sval) { 804 return emitOptionalError(location, "only shaped type operands allowed"); 805 } 806 int64_t dim = 807 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 808 auto type = IntegerType::get(context, 17); 809 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 810 return success(); 811 } 812 813 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 814 OpBuilder &builder, ValueRange operands, 815 llvm::SmallVectorImpl<Value> &shapes) { 816 shapes = SmallVector<Value, 1>{ 817 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; 818 return success(); 819 } 820 821 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( 822 OpBuilder &builder, ValueRange operands, 823 llvm::SmallVectorImpl<Value> &shapes) { 824 Location loc = getLoc(); 825 shapes.reserve(operands.size()); 826 for (Value operand : llvm::reverse(operands)) { 827 auto currShape = llvm::to_vector<4>(llvm::map_range( 828 llvm::seq<int64_t>( 829 0, operand.getType().cast<RankedTensorType>().getRank()), 830 [&](int64_t dim) -> Value { 831 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 832 })); 833 shapes.push_back(builder.create<tensor::FromElementsOp>( 834 getLoc(), builder.getIndexType(), currShape)); 835 } 836 return success(); 837 } 838 839 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( 840 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { 841 Location loc = getLoc(); 842 shapes.reserve(getNumOperands()); 843 for (Value operand : llvm::reverse(getOperands())) { 844 auto currShape = llvm::to_vector<4>(llvm::map_range( 845 llvm::seq<int64_t>( 846 0, operand.getType().cast<RankedTensorType>().getRank()), 847 [&](int64_t dim) -> Value { 848 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 849 })); 850 shapes.emplace_back(std::move(currShape)); 851 } 852 return success(); 853 } 854 855 //===----------------------------------------------------------------------===// 856 // Test SideEffect interfaces 857 //===----------------------------------------------------------------------===// 858 859 namespace { 860 /// A test resource for side effects. 861 struct TestResource : public SideEffects::Resource::Base<TestResource> { 862 StringRef getName() final { return "<Test>"; } 863 }; 864 } // end anonymous namespace 865 866 static void testSideEffectOpGetEffect( 867 Operation *op, 868 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 869 &effects) { 870 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); 871 if (!effectsAttr) 872 return; 873 874 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 875 } 876 877 void SideEffectOp::getEffects( 878 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 879 // Check for an effects attribute on the op instance. 880 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 881 if (!effectsAttr) 882 return; 883 884 // If there is one, it is an array of dictionary attributes that hold 885 // information on the effects of this operation. 886 for (Attribute element : effectsAttr) { 887 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 888 889 // Get the specific memory effect. 890 MemoryEffects::Effect *effect = 891 StringSwitch<MemoryEffects::Effect *>( 892 effectElement.get("effect").cast<StringAttr>().getValue()) 893 .Case("allocate", MemoryEffects::Allocate::get()) 894 .Case("free", MemoryEffects::Free::get()) 895 .Case("read", MemoryEffects::Read::get()) 896 .Case("write", MemoryEffects::Write::get()); 897 898 // Check for a non-default resource to use. 899 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 900 if (effectElement.get("test_resource")) 901 resource = TestResource::get(); 902 903 // Check for a result to affect. 904 if (effectElement.get("on_result")) 905 effects.emplace_back(effect, getResult(), resource); 906 else if (Attribute ref = effectElement.get("on_reference")) 907 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 908 else 909 effects.emplace_back(effect, resource); 910 } 911 } 912 913 void SideEffectOp::getEffects( 914 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 915 testSideEffectOpGetEffect(getOperation(), effects); 916 } 917 918 //===----------------------------------------------------------------------===// 919 // StringAttrPrettyNameOp 920 //===----------------------------------------------------------------------===// 921 922 // This op has fancy handling of its SSA result name. 923 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 924 OperationState &result) { 925 // Add the result types. 926 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 927 result.addTypes(parser.getBuilder().getIntegerType(32)); 928 929 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 930 return failure(); 931 932 // If the attribute dictionary contains no 'names' attribute, infer it from 933 // the SSA name (if specified). 934 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 935 return attr.first == "names"; 936 }); 937 938 // If there was no name specified, check to see if there was a useful name 939 // specified in the asm file. 940 if (hadNames || parser.getNumResults() == 0) 941 return success(); 942 943 SmallVector<StringRef, 4> names; 944 auto *context = result.getContext(); 945 946 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 947 auto resultName = parser.getResultName(i); 948 StringRef nameStr; 949 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 950 nameStr = resultName.first; 951 952 names.push_back(nameStr); 953 } 954 955 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 956 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 957 return success(); 958 } 959 960 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 961 p << "test.string_attr_pretty_name"; 962 963 // Note that we only need to print the "name" attribute if the asmprinter 964 // result name disagrees with it. This can happen in strange cases, e.g. 965 // when there are conflicts. 966 bool namesDisagree = op.names().size() != op.getNumResults(); 967 968 SmallString<32> resultNameStr; 969 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 970 resultNameStr.clear(); 971 llvm::raw_svector_ostream tmpStream(resultNameStr); 972 p.printOperand(op.getResult(i), tmpStream); 973 974 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 975 if (!expectedName || 976 tmpStream.str().drop_front() != expectedName.getValue()) { 977 namesDisagree = true; 978 } 979 } 980 981 if (namesDisagree) 982 p.printOptionalAttrDictWithKeyword(op->getAttrs()); 983 else 984 p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"}); 985 } 986 987 // We set the SSA name in the asm syntax to the contents of the name 988 // attribute. 989 void StringAttrPrettyNameOp::getAsmResultNames( 990 function_ref<void(Value, StringRef)> setNameFn) { 991 992 auto value = names(); 993 for (size_t i = 0, e = value.size(); i != e; ++i) 994 if (auto str = value[i].dyn_cast<StringAttr>()) 995 if (!str.getValue().empty()) 996 setNameFn(getResult(i), str.getValue()); 997 } 998 999 //===----------------------------------------------------------------------===// 1000 // RegionIfOp 1001 //===----------------------------------------------------------------------===// 1002 1003 static void print(OpAsmPrinter &p, RegionIfOp op) { 1004 p << RegionIfOp::getOperationName() << " "; 1005 p.printOperands(op.getOperands()); 1006 p << ": " << op.getOperandTypes(); 1007 p.printArrowTypeList(op.getResultTypes()); 1008 p << " then"; 1009 p.printRegion(op.thenRegion(), 1010 /*printEntryBlockArgs=*/true, 1011 /*printBlockTerminators=*/true); 1012 p << " else"; 1013 p.printRegion(op.elseRegion(), 1014 /*printEntryBlockArgs=*/true, 1015 /*printBlockTerminators=*/true); 1016 p << " join"; 1017 p.printRegion(op.joinRegion(), 1018 /*printEntryBlockArgs=*/true, 1019 /*printBlockTerminators=*/true); 1020 } 1021 1022 static ParseResult parseRegionIfOp(OpAsmParser &parser, 1023 OperationState &result) { 1024 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 1025 SmallVector<Type, 2> operandTypes; 1026 1027 result.regions.reserve(3); 1028 Region *thenRegion = result.addRegion(); 1029 Region *elseRegion = result.addRegion(); 1030 Region *joinRegion = result.addRegion(); 1031 1032 // Parse operand, type and arrow type lists. 1033 if (parser.parseOperandList(operandInfos) || 1034 parser.parseColonTypeList(operandTypes) || 1035 parser.parseArrowTypeList(result.types)) 1036 return failure(); 1037 1038 // Parse all attached regions. 1039 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 1040 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 1041 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 1042 return failure(); 1043 1044 return parser.resolveOperands(operandInfos, operandTypes, 1045 parser.getCurrentLocation(), result.operands); 1046 } 1047 1048 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 1049 assert(index < 2 && "invalid region index"); 1050 return getOperands(); 1051 } 1052 1053 void RegionIfOp::getSuccessorRegions( 1054 Optional<unsigned> index, ArrayRef<Attribute> operands, 1055 SmallVectorImpl<RegionSuccessor> ®ions) { 1056 // We always branch to the join region. 1057 if (index.hasValue()) { 1058 if (index.getValue() < 2) 1059 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 1060 else 1061 regions.push_back(RegionSuccessor(getResults())); 1062 return; 1063 } 1064 1065 // The then and else regions are the entry regions of this op. 1066 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 1067 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 1068 } 1069 1070 //===----------------------------------------------------------------------===// 1071 // SingleNoTerminatorCustomAsmOp 1072 //===----------------------------------------------------------------------===// 1073 1074 static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser, 1075 OperationState &state) { 1076 Region *body = state.addRegion(); 1077 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) 1078 return failure(); 1079 return success(); 1080 } 1081 1082 static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) { 1083 printer << op.getOperationName(); 1084 printer.printRegion( 1085 op.getRegion(), /*printEntryBlockArgs=*/false, 1086 // This op has a single block without terminators. But explicitly mark 1087 // as not printing block terminators for testing. 1088 /*printBlockTerminators=*/false); 1089 } 1090 1091 #include "TestOpEnums.cpp.inc" 1092 #include "TestOpInterfaces.cpp.inc" 1093 #include "TestOpStructs.cpp.inc" 1094 #include "TestTypeInterfaces.cpp.inc" 1095 1096 #define GET_OP_CLASSES 1097 #include "TestOps.cpp.inc" 1098