1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestTypes.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/IR/BuiltinOps.h" 13 #include "mlir/IR/DialectImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/TypeUtilities.h" 16 #include "mlir/Transforms/FoldUtils.h" 17 #include "mlir/Transforms/InliningUtils.h" 18 #include "llvm/ADT/SetVector.h" 19 #include "llvm/ADT/StringSwitch.h" 20 21 using namespace mlir; 22 using namespace mlir::test; 23 24 void mlir::test::registerTestDialect(DialectRegistry ®istry) { 25 registry.insert<TestDialect>(); 26 } 27 28 //===----------------------------------------------------------------------===// 29 // TestDialect Interfaces 30 //===----------------------------------------------------------------------===// 31 32 namespace { 33 34 // Test support for interacting with the AsmPrinter. 35 struct TestOpAsmInterface : public OpAsmDialectInterface { 36 using OpAsmDialectInterface::OpAsmDialectInterface; 37 38 LogicalResult getAlias(Attribute attr, raw_ostream &os) const final { 39 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 40 if (!strAttr) 41 return failure(); 42 43 // Check the contents of the string attribute to see what the test alias 44 // should be named. 45 Optional<StringRef> aliasName = 46 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 47 .Case("alias_test:dot_in_name", StringRef("test.alias")) 48 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 49 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 50 .Case("alias_test:sanitize_conflict_a", 51 StringRef("test_alias_conflict0")) 52 .Case("alias_test:sanitize_conflict_b", 53 StringRef("test_alias_conflict0_")) 54 .Default(llvm::None); 55 if (!aliasName) 56 return failure(); 57 58 os << *aliasName; 59 return success(); 60 } 61 62 void getAsmResultNames(Operation *op, 63 OpAsmSetValueNameFn setNameFn) const final { 64 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 65 setNameFn(asmOp, "result"); 66 } 67 68 void getAsmBlockArgumentNames(Block *block, 69 OpAsmSetValueNameFn setNameFn) const final { 70 auto op = block->getParentOp(); 71 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 72 if (!arrayAttr) 73 return; 74 auto args = block->getArguments(); 75 auto e = std::min(arrayAttr.size(), args.size()); 76 for (unsigned i = 0; i < e; ++i) { 77 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 78 setNameFn(args[i], strAttr.getValue()); 79 } 80 } 81 }; 82 83 struct TestDialectFoldInterface : public DialectFoldInterface { 84 using DialectFoldInterface::DialectFoldInterface; 85 86 /// Registered hook to check if the given region, which is attached to an 87 /// operation that is *not* isolated from above, should be used when 88 /// materializing constants. 89 bool shouldMaterializeInto(Region *region) const final { 90 // If this is a one region operation, then insert into it. 91 return isa<OneRegionOp>(region->getParentOp()); 92 } 93 }; 94 95 /// This class defines the interface for handling inlining with standard 96 /// operations. 97 struct TestInlinerInterface : public DialectInlinerInterface { 98 using DialectInlinerInterface::DialectInlinerInterface; 99 100 //===--------------------------------------------------------------------===// 101 // Analysis Hooks 102 //===--------------------------------------------------------------------===// 103 104 bool isLegalToInline(Operation *call, Operation *callable, 105 bool wouldBeCloned) const final { 106 // Don't allow inlining calls that are marked `noinline`. 107 return !call->hasAttr("noinline"); 108 } 109 bool isLegalToInline(Region *, Region *, bool, 110 BlockAndValueMapping &) const final { 111 // Inlining into test dialect regions is legal. 112 return true; 113 } 114 bool isLegalToInline(Operation *, Region *, bool, 115 BlockAndValueMapping &) const final { 116 return true; 117 } 118 119 bool shouldAnalyzeRecursively(Operation *op) const final { 120 // Analyze recursively if this is not a functional region operation, it 121 // froms a separate functional scope. 122 return !isa<FunctionalRegionOp>(op); 123 } 124 125 //===--------------------------------------------------------------------===// 126 // Transformation Hooks 127 //===--------------------------------------------------------------------===// 128 129 /// Handle the given inlined terminator by replacing it with a new operation 130 /// as necessary. 131 void handleTerminator(Operation *op, 132 ArrayRef<Value> valuesToRepl) const final { 133 // Only handle "test.return" here. 134 auto returnOp = dyn_cast<TestReturnOp>(op); 135 if (!returnOp) 136 return; 137 138 // Replace the values directly with the return operands. 139 assert(returnOp.getNumOperands() == valuesToRepl.size()); 140 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 141 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 142 } 143 144 /// Attempt to materialize a conversion for a type mismatch between a call 145 /// from this dialect, and a callable region. This method should generate an 146 /// operation that takes 'input' as the only operand, and produces a single 147 /// result of 'resultType'. If a conversion can not be generated, nullptr 148 /// should be returned. 149 Operation *materializeCallConversion(OpBuilder &builder, Value input, 150 Type resultType, 151 Location conversionLoc) const final { 152 // Only allow conversion for i16/i32 types. 153 if (!(resultType.isSignlessInteger(16) || 154 resultType.isSignlessInteger(32)) || 155 !(input.getType().isSignlessInteger(16) || 156 input.getType().isSignlessInteger(32))) 157 return nullptr; 158 return builder.create<TestCastOp>(conversionLoc, resultType, input); 159 } 160 }; 161 } // end anonymous namespace 162 163 //===----------------------------------------------------------------------===// 164 // TestDialect 165 //===----------------------------------------------------------------------===// 166 167 void TestDialect::initialize() { 168 addOperations< 169 #define GET_OP_LIST 170 #include "TestOps.cpp.inc" 171 >(); 172 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 173 TestInlinerInterface>(); 174 addTypes<TestType, TestRecursiveType, 175 #define GET_TYPEDEF_LIST 176 #include "TestTypeDefs.cpp.inc" 177 >(); 178 allowUnknownOperations(); 179 } 180 181 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 182 Type type, Location loc) { 183 return builder.create<TestOpConstant>(loc, type, value); 184 } 185 186 static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser, 187 llvm::SetVector<Type> &stack) { 188 StringRef typeTag; 189 if (failed(parser.parseKeyword(&typeTag))) 190 return Type(); 191 192 auto genType = generatedTypeParser(ctxt, parser, typeTag); 193 if (genType != Type()) 194 return genType; 195 196 if (typeTag == "test_type") 197 return TestType::get(parser.getBuilder().getContext()); 198 199 if (typeTag != "test_rec") 200 return Type(); 201 202 StringRef name; 203 if (parser.parseLess() || parser.parseKeyword(&name)) 204 return Type(); 205 auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name); 206 207 // If this type already has been parsed above in the stack, expect just the 208 // name. 209 if (stack.contains(rec)) { 210 if (failed(parser.parseGreater())) 211 return Type(); 212 return rec; 213 } 214 215 // Otherwise, parse the body and update the type. 216 if (failed(parser.parseComma())) 217 return Type(); 218 stack.insert(rec); 219 Type subtype = parseTestType(ctxt, parser, stack); 220 stack.pop_back(); 221 if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype))) 222 return Type(); 223 224 return rec; 225 } 226 227 Type TestDialect::parseType(DialectAsmParser &parser) const { 228 llvm::SetVector<Type> stack; 229 return parseTestType(getContext(), parser, stack); 230 } 231 232 static void printTestType(Type type, DialectAsmPrinter &printer, 233 llvm::SetVector<Type> &stack) { 234 if (succeeded(generatedTypePrinter(type, printer))) 235 return; 236 if (type.isa<TestType>()) { 237 printer << "test_type"; 238 return; 239 } 240 241 auto rec = type.cast<TestRecursiveType>(); 242 printer << "test_rec<" << rec.getName(); 243 if (!stack.contains(rec)) { 244 printer << ", "; 245 stack.insert(rec); 246 printTestType(rec.getBody(), printer, stack); 247 stack.pop_back(); 248 } 249 printer << ">"; 250 } 251 252 void TestDialect::printType(Type type, DialectAsmPrinter &printer) const { 253 llvm::SetVector<Type> stack; 254 printTestType(type, printer, stack); 255 } 256 257 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 258 NamedAttribute namedAttr) { 259 if (namedAttr.first == "test.invalid_attr") 260 return op->emitError() << "invalid to use 'test.invalid_attr'"; 261 return success(); 262 } 263 264 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 265 unsigned regionIndex, 266 unsigned argIndex, 267 NamedAttribute namedAttr) { 268 if (namedAttr.first == "test.invalid_attr") 269 return op->emitError() << "invalid to use 'test.invalid_attr'"; 270 return success(); 271 } 272 273 LogicalResult 274 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 275 unsigned resultIndex, 276 NamedAttribute namedAttr) { 277 if (namedAttr.first == "test.invalid_attr") 278 return op->emitError() << "invalid to use 'test.invalid_attr'"; 279 return success(); 280 } 281 282 //===----------------------------------------------------------------------===// 283 // TestBranchOp 284 //===----------------------------------------------------------------------===// 285 286 Optional<MutableOperandRange> 287 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 288 assert(index == 0 && "invalid successor index"); 289 return targetOperandsMutable(); 290 } 291 292 //===----------------------------------------------------------------------===// 293 // TestFoldToCallOp 294 //===----------------------------------------------------------------------===// 295 296 namespace { 297 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 298 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 299 300 LogicalResult matchAndRewrite(FoldToCallOp op, 301 PatternRewriter &rewriter) const override { 302 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(), 303 ValueRange()); 304 return success(); 305 } 306 }; 307 } // end anonymous namespace 308 309 void FoldToCallOp::getCanonicalizationPatterns( 310 OwningRewritePatternList &results, MLIRContext *context) { 311 results.insert<FoldToCallOpPattern>(context); 312 } 313 314 //===----------------------------------------------------------------------===// 315 // Test Format* operations 316 //===----------------------------------------------------------------------===// 317 318 //===----------------------------------------------------------------------===// 319 // Parsing 320 321 static ParseResult parseCustomDirectiveOperands( 322 OpAsmParser &parser, OpAsmParser::OperandType &operand, 323 Optional<OpAsmParser::OperandType> &optOperand, 324 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 325 if (parser.parseOperand(operand)) 326 return failure(); 327 if (succeeded(parser.parseOptionalComma())) { 328 optOperand.emplace(); 329 if (parser.parseOperand(*optOperand)) 330 return failure(); 331 } 332 if (parser.parseArrow() || parser.parseLParen() || 333 parser.parseOperandList(varOperands) || parser.parseRParen()) 334 return failure(); 335 return success(); 336 } 337 static ParseResult 338 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 339 Type &optOperandType, 340 SmallVectorImpl<Type> &varOperandTypes) { 341 if (parser.parseColon()) 342 return failure(); 343 344 if (parser.parseType(operandType)) 345 return failure(); 346 if (succeeded(parser.parseOptionalComma())) { 347 if (parser.parseType(optOperandType)) 348 return failure(); 349 } 350 if (parser.parseArrow() || parser.parseLParen() || 351 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 352 return failure(); 353 return success(); 354 } 355 static ParseResult 356 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 357 Type optOperandType, 358 const SmallVectorImpl<Type> &varOperandTypes) { 359 if (parser.parseKeyword("type_refs_capture")) 360 return failure(); 361 362 Type operandType2, optOperandType2; 363 SmallVector<Type, 1> varOperandTypes2; 364 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 365 varOperandTypes2)) 366 return failure(); 367 368 if (operandType != operandType2 || optOperandType != optOperandType2 || 369 varOperandTypes != varOperandTypes2) 370 return failure(); 371 372 return success(); 373 } 374 static ParseResult parseCustomDirectiveOperandsAndTypes( 375 OpAsmParser &parser, OpAsmParser::OperandType &operand, 376 Optional<OpAsmParser::OperandType> &optOperand, 377 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 378 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 379 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 380 parseCustomDirectiveResults(parser, operandType, optOperandType, 381 varOperandTypes)) 382 return failure(); 383 return success(); 384 } 385 static ParseResult parseCustomDirectiveRegions( 386 OpAsmParser &parser, Region ®ion, 387 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 388 if (parser.parseRegion(region)) 389 return failure(); 390 if (failed(parser.parseOptionalComma())) 391 return success(); 392 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 393 if (parser.parseRegion(*varRegion)) 394 return failure(); 395 varRegions.emplace_back(std::move(varRegion)); 396 return success(); 397 } 398 static ParseResult 399 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 400 SmallVectorImpl<Block *> &varSuccessors) { 401 if (parser.parseSuccessor(successor)) 402 return failure(); 403 if (failed(parser.parseOptionalComma())) 404 return success(); 405 Block *varSuccessor; 406 if (parser.parseSuccessor(varSuccessor)) 407 return failure(); 408 varSuccessors.append(2, varSuccessor); 409 return success(); 410 } 411 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 412 IntegerAttr &attr, 413 IntegerAttr &optAttr) { 414 if (parser.parseAttribute(attr)) 415 return failure(); 416 if (succeeded(parser.parseOptionalComma())) { 417 if (parser.parseAttribute(optAttr)) 418 return failure(); 419 } 420 return success(); 421 } 422 423 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 424 NamedAttrList &attrs) { 425 return parser.parseOptionalAttrDict(attrs); 426 } 427 428 //===----------------------------------------------------------------------===// 429 // Printing 430 431 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 432 Value operand, Value optOperand, 433 OperandRange varOperands) { 434 printer << operand; 435 if (optOperand) 436 printer << ", " << optOperand; 437 printer << " -> (" << varOperands << ")"; 438 } 439 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 440 Type operandType, Type optOperandType, 441 TypeRange varOperandTypes) { 442 printer << " : " << operandType; 443 if (optOperandType) 444 printer << ", " << optOperandType; 445 printer << " -> (" << varOperandTypes << ")"; 446 } 447 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 448 Operation *op, Type operandType, 449 Type optOperandType, 450 TypeRange varOperandTypes) { 451 printer << " type_refs_capture "; 452 printCustomDirectiveResults(printer, op, operandType, optOperandType, 453 varOperandTypes); 454 } 455 static void printCustomDirectiveOperandsAndTypes( 456 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 457 OperandRange varOperands, Type operandType, Type optOperandType, 458 TypeRange varOperandTypes) { 459 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 460 printCustomDirectiveResults(printer, op, operandType, optOperandType, 461 varOperandTypes); 462 } 463 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 464 Region ®ion, 465 MutableArrayRef<Region> varRegions) { 466 printer.printRegion(region); 467 if (!varRegions.empty()) { 468 printer << ", "; 469 for (Region ®ion : varRegions) 470 printer.printRegion(region); 471 } 472 } 473 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 474 Block *successor, 475 SuccessorRange varSuccessors) { 476 printer << successor; 477 if (!varSuccessors.empty()) 478 printer << ", " << varSuccessors.front(); 479 } 480 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 481 Attribute attribute, 482 Attribute optAttribute) { 483 printer << attribute; 484 if (optAttribute) 485 printer << ", " << optAttribute; 486 } 487 488 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 489 MutableDictionaryAttr attrs) { 490 printer.printOptionalAttrDict(attrs.getAttrs()); 491 } 492 //===----------------------------------------------------------------------===// 493 // Test IsolatedRegionOp - parse passthrough region arguments. 494 //===----------------------------------------------------------------------===// 495 496 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 497 OperationState &result) { 498 OpAsmParser::OperandType argInfo; 499 Type argType = parser.getBuilder().getIndexType(); 500 501 // Parse the input operand. 502 if (parser.parseOperand(argInfo) || 503 parser.resolveOperand(argInfo, argType, result.operands)) 504 return failure(); 505 506 // Parse the body region, and reuse the operand info as the argument info. 507 Region *body = result.addRegion(); 508 return parser.parseRegion(*body, argInfo, argType, 509 /*enableNameShadowing=*/true); 510 } 511 512 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 513 p << "test.isolated_region "; 514 p.printOperand(op.getOperand()); 515 p.shadowRegionArgs(op.region(), op.getOperand()); 516 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 517 } 518 519 //===----------------------------------------------------------------------===// 520 // Test SSACFGRegionOp 521 //===----------------------------------------------------------------------===// 522 523 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 524 return RegionKind::SSACFG; 525 } 526 527 //===----------------------------------------------------------------------===// 528 // Test GraphRegionOp 529 //===----------------------------------------------------------------------===// 530 531 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 532 OperationState &result) { 533 // Parse the body region, and reuse the operand info as the argument info. 534 Region *body = result.addRegion(); 535 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 536 } 537 538 static void print(OpAsmPrinter &p, GraphRegionOp op) { 539 p << "test.graph_region "; 540 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 541 } 542 543 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 544 return RegionKind::Graph; 545 } 546 547 //===----------------------------------------------------------------------===// 548 // Test AffineScopeOp 549 //===----------------------------------------------------------------------===// 550 551 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 552 OperationState &result) { 553 // Parse the body region, and reuse the operand info as the argument info. 554 Region *body = result.addRegion(); 555 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 556 } 557 558 static void print(OpAsmPrinter &p, AffineScopeOp op) { 559 p << "test.affine_scope "; 560 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 561 } 562 563 //===----------------------------------------------------------------------===// 564 // Test parser. 565 //===----------------------------------------------------------------------===// 566 567 static ParseResult parseWrappedKeywordOp(OpAsmParser &parser, 568 OperationState &result) { 569 StringRef keyword; 570 if (parser.parseKeyword(&keyword)) 571 return failure(); 572 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 573 return success(); 574 } 575 576 static void print(OpAsmPrinter &p, WrappedKeywordOp op) { 577 p << WrappedKeywordOp::getOperationName() << " " << op.keyword(); 578 } 579 580 //===----------------------------------------------------------------------===// 581 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 582 583 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 584 OperationState &result) { 585 if (parser.parseKeyword("wraps")) 586 return failure(); 587 588 // Parse the wrapped op in a region 589 Region &body = *result.addRegion(); 590 body.push_back(new Block); 591 Block &block = body.back(); 592 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 593 if (!wrapped_op) 594 return failure(); 595 596 // Create a return terminator in the inner region, pass as operand to the 597 // terminator the returned values from the wrapped operation. 598 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 599 OpBuilder builder(parser.getBuilder().getContext()); 600 builder.setInsertionPointToEnd(&block); 601 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 602 603 // Get the results type for the wrapping op from the terminator operands. 604 Operation &return_op = body.back().back(); 605 result.types.append(return_op.operand_type_begin(), 606 return_op.operand_type_end()); 607 608 // Use the location of the wrapped op for the "test.wrapping_region" op. 609 result.location = wrapped_op->getLoc(); 610 611 return success(); 612 } 613 614 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 615 p << op.getOperationName() << " wraps "; 616 p.printGenericOp(&op.region().front().front()); 617 } 618 619 //===----------------------------------------------------------------------===// 620 // Test PolyForOp - parse list of region arguments. 621 //===----------------------------------------------------------------------===// 622 623 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 624 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 625 // Parse list of region arguments without a delimiter. 626 if (parser.parseRegionArgumentList(ivsInfo)) 627 return failure(); 628 629 // Parse the body region. 630 Region *body = result.addRegion(); 631 auto &builder = parser.getBuilder(); 632 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 633 return parser.parseRegion(*body, ivsInfo, argTypes); 634 } 635 636 //===----------------------------------------------------------------------===// 637 // Test removing op with inner ops. 638 //===----------------------------------------------------------------------===// 639 640 namespace { 641 struct TestRemoveOpWithInnerOps 642 : public OpRewritePattern<TestOpWithRegionPattern> { 643 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 644 645 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 646 PatternRewriter &rewriter) const override { 647 rewriter.eraseOp(op); 648 return success(); 649 } 650 }; 651 } // end anonymous namespace 652 653 void TestOpWithRegionPattern::getCanonicalizationPatterns( 654 OwningRewritePatternList &results, MLIRContext *context) { 655 results.insert<TestRemoveOpWithInnerOps>(context); 656 } 657 658 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 659 return operand(); 660 } 661 662 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 663 return getValue(); 664 } 665 666 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 667 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 668 for (Value input : this->operands()) { 669 results.push_back(input); 670 } 671 return success(); 672 } 673 674 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 675 assert(operands.size() == 1); 676 if (operands.front()) { 677 setAttr("attr", operands.front()); 678 return getResult(); 679 } 680 return {}; 681 } 682 683 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 684 MLIRContext *, Optional<Location> location, ValueRange operands, 685 DictionaryAttr attributes, RegionRange regions, 686 SmallVectorImpl<Type> &inferredReturnTypes) { 687 if (operands[0].getType() != operands[1].getType()) { 688 return emitOptionalError(location, "operand type mismatch ", 689 operands[0].getType(), " vs ", 690 operands[1].getType()); 691 } 692 inferredReturnTypes.assign({operands[0].getType()}); 693 return success(); 694 } 695 696 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 697 MLIRContext *context, Optional<Location> location, ValueRange operands, 698 DictionaryAttr attributes, RegionRange regions, 699 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 700 // Create return type consisting of the last element of the first operand. 701 auto operandType = *operands.getTypes().begin(); 702 auto sval = operandType.dyn_cast<ShapedType>(); 703 if (!sval) { 704 return emitOptionalError(location, "only shaped type operands allowed"); 705 } 706 int64_t dim = 707 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 708 auto type = IntegerType::get(17, context); 709 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 710 return success(); 711 } 712 713 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 714 OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) { 715 shapes = SmallVector<Value, 1>{ 716 builder.createOrFold<DimOp>(getLoc(), getOperand(0), 0)}; 717 return success(); 718 } 719 720 //===----------------------------------------------------------------------===// 721 // Test SideEffect interfaces 722 //===----------------------------------------------------------------------===// 723 724 namespace { 725 /// A test resource for side effects. 726 struct TestResource : public SideEffects::Resource::Base<TestResource> { 727 StringRef getName() final { return "<Test>"; } 728 }; 729 } // end anonymous namespace 730 731 void SideEffectOp::getEffects( 732 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 733 // Check for an effects attribute on the op instance. 734 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 735 if (!effectsAttr) 736 return; 737 738 // If there is one, it is an array of dictionary attributes that hold 739 // information on the effects of this operation. 740 for (Attribute element : effectsAttr) { 741 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 742 743 // Get the specific memory effect. 744 MemoryEffects::Effect *effect = 745 StringSwitch<MemoryEffects::Effect *>( 746 effectElement.get("effect").cast<StringAttr>().getValue()) 747 .Case("allocate", MemoryEffects::Allocate::get()) 748 .Case("free", MemoryEffects::Free::get()) 749 .Case("read", MemoryEffects::Read::get()) 750 .Case("write", MemoryEffects::Write::get()); 751 752 // Check for a non-default resource to use. 753 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 754 if (effectElement.get("test_resource")) 755 resource = TestResource::get(); 756 757 // Check for a result to affect. 758 if (effectElement.get("on_result")) 759 effects.emplace_back(effect, getResult(), resource); 760 else if (Attribute ref = effectElement.get("on_reference")) 761 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 762 else 763 effects.emplace_back(effect, resource); 764 } 765 } 766 767 void SideEffectOp::getEffects( 768 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 769 auto effectsAttr = (*this)->getAttrOfType<AffineMapAttr>("effect_parameter"); 770 if (!effectsAttr) 771 return; 772 773 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 774 } 775 776 //===----------------------------------------------------------------------===// 777 // StringAttrPrettyNameOp 778 //===----------------------------------------------------------------------===// 779 780 // This op has fancy handling of its SSA result name. 781 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 782 OperationState &result) { 783 // Add the result types. 784 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 785 result.addTypes(parser.getBuilder().getIntegerType(32)); 786 787 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 788 return failure(); 789 790 // If the attribute dictionary contains no 'names' attribute, infer it from 791 // the SSA name (if specified). 792 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 793 return attr.first == "names"; 794 }); 795 796 // If there was no name specified, check to see if there was a useful name 797 // specified in the asm file. 798 if (hadNames || parser.getNumResults() == 0) 799 return success(); 800 801 SmallVector<StringRef, 4> names; 802 auto *context = result.getContext(); 803 804 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 805 auto resultName = parser.getResultName(i); 806 StringRef nameStr; 807 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 808 nameStr = resultName.first; 809 810 names.push_back(nameStr); 811 } 812 813 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 814 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 815 return success(); 816 } 817 818 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 819 p << "test.string_attr_pretty_name"; 820 821 // Note that we only need to print the "name" attribute if the asmprinter 822 // result name disagrees with it. This can happen in strange cases, e.g. 823 // when there are conflicts. 824 bool namesDisagree = op.names().size() != op.getNumResults(); 825 826 SmallString<32> resultNameStr; 827 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 828 resultNameStr.clear(); 829 llvm::raw_svector_ostream tmpStream(resultNameStr); 830 p.printOperand(op.getResult(i), tmpStream); 831 832 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 833 if (!expectedName || 834 tmpStream.str().drop_front() != expectedName.getValue()) { 835 namesDisagree = true; 836 } 837 } 838 839 if (namesDisagree) 840 p.printOptionalAttrDictWithKeyword(op.getAttrs()); 841 else 842 p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"}); 843 } 844 845 // We set the SSA name in the asm syntax to the contents of the name 846 // attribute. 847 void StringAttrPrettyNameOp::getAsmResultNames( 848 function_ref<void(Value, StringRef)> setNameFn) { 849 850 auto value = names(); 851 for (size_t i = 0, e = value.size(); i != e; ++i) 852 if (auto str = value[i].dyn_cast<StringAttr>()) 853 if (!str.getValue().empty()) 854 setNameFn(getResult(i), str.getValue()); 855 } 856 857 //===----------------------------------------------------------------------===// 858 // RegionIfOp 859 //===----------------------------------------------------------------------===// 860 861 static void print(OpAsmPrinter &p, RegionIfOp op) { 862 p << RegionIfOp::getOperationName() << " "; 863 p.printOperands(op.getOperands()); 864 p << ": " << op.getOperandTypes(); 865 p.printArrowTypeList(op.getResultTypes()); 866 p << " then"; 867 p.printRegion(op.thenRegion(), 868 /*printEntryBlockArgs=*/true, 869 /*printBlockTerminators=*/true); 870 p << " else"; 871 p.printRegion(op.elseRegion(), 872 /*printEntryBlockArgs=*/true, 873 /*printBlockTerminators=*/true); 874 p << " join"; 875 p.printRegion(op.joinRegion(), 876 /*printEntryBlockArgs=*/true, 877 /*printBlockTerminators=*/true); 878 } 879 880 static ParseResult parseRegionIfOp(OpAsmParser &parser, 881 OperationState &result) { 882 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 883 SmallVector<Type, 2> operandTypes; 884 885 result.regions.reserve(3); 886 Region *thenRegion = result.addRegion(); 887 Region *elseRegion = result.addRegion(); 888 Region *joinRegion = result.addRegion(); 889 890 // Parse operand, type and arrow type lists. 891 if (parser.parseOperandList(operandInfos) || 892 parser.parseColonTypeList(operandTypes) || 893 parser.parseArrowTypeList(result.types)) 894 return failure(); 895 896 // Parse all attached regions. 897 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 898 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 899 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 900 return failure(); 901 902 return parser.resolveOperands(operandInfos, operandTypes, 903 parser.getCurrentLocation(), result.operands); 904 } 905 906 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 907 assert(index < 2 && "invalid region index"); 908 return getOperands(); 909 } 910 911 void RegionIfOp::getSuccessorRegions( 912 Optional<unsigned> index, ArrayRef<Attribute> operands, 913 SmallVectorImpl<RegionSuccessor> ®ions) { 914 // We always branch to the join region. 915 if (index.hasValue()) { 916 if (index.getValue() < 2) 917 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 918 else 919 regions.push_back(RegionSuccessor(getResults())); 920 return; 921 } 922 923 // The then and else regions are the entry regions of this op. 924 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 925 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 926 } 927 928 #include "TestOpEnums.cpp.inc" 929 #include "TestOpInterfaces.cpp.inc" 930 #include "TestOpStructs.cpp.inc" 931 #include "TestTypeInterfaces.cpp.inc" 932 933 #define GET_OP_CLASSES 934 #include "TestOps.cpp.inc" 935