1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestTypes.h" 12 #include "mlir/Dialect/StandardOps/IR/Ops.h" 13 #include "mlir/IR/BuiltinOps.h" 14 #include "mlir/IR/DialectImplementation.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/IR/TypeUtilities.h" 17 #include "mlir/Transforms/FoldUtils.h" 18 #include "mlir/Transforms/InliningUtils.h" 19 #include "llvm/ADT/StringSwitch.h" 20 21 using namespace mlir; 22 using namespace mlir::test; 23 24 void mlir::test::registerTestDialect(DialectRegistry ®istry) { 25 registry.insert<TestDialect>(); 26 } 27 28 //===----------------------------------------------------------------------===// 29 // TestDialect Interfaces 30 //===----------------------------------------------------------------------===// 31 32 namespace { 33 34 // Test support for interacting with the AsmPrinter. 35 struct TestOpAsmInterface : public OpAsmDialectInterface { 36 using OpAsmDialectInterface::OpAsmDialectInterface; 37 38 LogicalResult getAlias(Attribute attr, raw_ostream &os) const final { 39 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 40 if (!strAttr) 41 return failure(); 42 43 // Check the contents of the string attribute to see what the test alias 44 // should be named. 45 Optional<StringRef> aliasName = 46 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 47 .Case("alias_test:dot_in_name", StringRef("test.alias")) 48 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 49 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 50 .Case("alias_test:sanitize_conflict_a", 51 StringRef("test_alias_conflict0")) 52 .Case("alias_test:sanitize_conflict_b", 53 StringRef("test_alias_conflict0_")) 54 .Default(llvm::None); 55 if (!aliasName) 56 return failure(); 57 58 os << *aliasName; 59 return success(); 60 } 61 62 void getAsmResultNames(Operation *op, 63 OpAsmSetValueNameFn setNameFn) const final { 64 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 65 setNameFn(asmOp, "result"); 66 } 67 68 void getAsmBlockArgumentNames(Block *block, 69 OpAsmSetValueNameFn setNameFn) const final { 70 auto op = block->getParentOp(); 71 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 72 if (!arrayAttr) 73 return; 74 auto args = block->getArguments(); 75 auto e = std::min(arrayAttr.size(), args.size()); 76 for (unsigned i = 0; i < e; ++i) { 77 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 78 setNameFn(args[i], strAttr.getValue()); 79 } 80 } 81 }; 82 83 struct TestDialectFoldInterface : public DialectFoldInterface { 84 using DialectFoldInterface::DialectFoldInterface; 85 86 /// Registered hook to check if the given region, which is attached to an 87 /// operation that is *not* isolated from above, should be used when 88 /// materializing constants. 89 bool shouldMaterializeInto(Region *region) const final { 90 // If this is a one region operation, then insert into it. 91 return isa<OneRegionOp>(region->getParentOp()); 92 } 93 }; 94 95 /// This class defines the interface for handling inlining with standard 96 /// operations. 97 struct TestInlinerInterface : public DialectInlinerInterface { 98 using DialectInlinerInterface::DialectInlinerInterface; 99 100 //===--------------------------------------------------------------------===// 101 // Analysis Hooks 102 //===--------------------------------------------------------------------===// 103 104 bool isLegalToInline(Operation *call, Operation *callable, 105 bool wouldBeCloned) const final { 106 // Don't allow inlining calls that are marked `noinline`. 107 return !call->hasAttr("noinline"); 108 } 109 bool isLegalToInline(Region *, Region *, bool, 110 BlockAndValueMapping &) const final { 111 // Inlining into test dialect regions is legal. 112 return true; 113 } 114 bool isLegalToInline(Operation *, Region *, bool, 115 BlockAndValueMapping &) const final { 116 return true; 117 } 118 119 bool shouldAnalyzeRecursively(Operation *op) const final { 120 // Analyze recursively if this is not a functional region operation, it 121 // froms a separate functional scope. 122 return !isa<FunctionalRegionOp>(op); 123 } 124 125 //===--------------------------------------------------------------------===// 126 // Transformation Hooks 127 //===--------------------------------------------------------------------===// 128 129 /// Handle the given inlined terminator by replacing it with a new operation 130 /// as necessary. 131 void handleTerminator(Operation *op, 132 ArrayRef<Value> valuesToRepl) const final { 133 // Only handle "test.return" here. 134 auto returnOp = dyn_cast<TestReturnOp>(op); 135 if (!returnOp) 136 return; 137 138 // Replace the values directly with the return operands. 139 assert(returnOp.getNumOperands() == valuesToRepl.size()); 140 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 141 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 142 } 143 144 /// Attempt to materialize a conversion for a type mismatch between a call 145 /// from this dialect, and a callable region. This method should generate an 146 /// operation that takes 'input' as the only operand, and produces a single 147 /// result of 'resultType'. If a conversion can not be generated, nullptr 148 /// should be returned. 149 Operation *materializeCallConversion(OpBuilder &builder, Value input, 150 Type resultType, 151 Location conversionLoc) const final { 152 // Only allow conversion for i16/i32 types. 153 if (!(resultType.isSignlessInteger(16) || 154 resultType.isSignlessInteger(32)) || 155 !(input.getType().isSignlessInteger(16) || 156 input.getType().isSignlessInteger(32))) 157 return nullptr; 158 return builder.create<TestCastOp>(conversionLoc, resultType, input); 159 } 160 }; 161 } // end anonymous namespace 162 163 //===----------------------------------------------------------------------===// 164 // TestDialect 165 //===----------------------------------------------------------------------===// 166 167 void TestDialect::initialize() { 168 addOperations< 169 #define GET_OP_LIST 170 #include "TestOps.cpp.inc" 171 >(); 172 addAttributes< 173 #define GET_ATTRDEF_LIST 174 #include "TestAttrDefs.cpp.inc" 175 >(); 176 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 177 TestInlinerInterface>(); 178 addTypes<TestType, TestRecursiveType, 179 #define GET_TYPEDEF_LIST 180 #include "TestTypeDefs.cpp.inc" 181 >(); 182 allowUnknownOperations(); 183 } 184 185 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 186 Type type, Location loc) { 187 return builder.create<TestOpConstant>(loc, type, value); 188 } 189 190 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 191 NamedAttribute namedAttr) { 192 if (namedAttr.first == "test.invalid_attr") 193 return op->emitError() << "invalid to use 'test.invalid_attr'"; 194 return success(); 195 } 196 197 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 198 unsigned regionIndex, 199 unsigned argIndex, 200 NamedAttribute namedAttr) { 201 if (namedAttr.first == "test.invalid_attr") 202 return op->emitError() << "invalid to use 'test.invalid_attr'"; 203 return success(); 204 } 205 206 LogicalResult 207 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 208 unsigned resultIndex, 209 NamedAttribute namedAttr) { 210 if (namedAttr.first == "test.invalid_attr") 211 return op->emitError() << "invalid to use 'test.invalid_attr'"; 212 return success(); 213 } 214 215 //===----------------------------------------------------------------------===// 216 // TestBranchOp 217 //===----------------------------------------------------------------------===// 218 219 Optional<MutableOperandRange> 220 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 221 assert(index == 0 && "invalid successor index"); 222 return targetOperandsMutable(); 223 } 224 225 //===----------------------------------------------------------------------===// 226 // TestFoldToCallOp 227 //===----------------------------------------------------------------------===// 228 229 namespace { 230 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 231 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 232 233 LogicalResult matchAndRewrite(FoldToCallOp op, 234 PatternRewriter &rewriter) const override { 235 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(), 236 ValueRange()); 237 return success(); 238 } 239 }; 240 } // end anonymous namespace 241 242 void FoldToCallOp::getCanonicalizationPatterns( 243 OwningRewritePatternList &results, MLIRContext *context) { 244 results.insert<FoldToCallOpPattern>(context); 245 } 246 247 //===----------------------------------------------------------------------===// 248 // Test Format* operations 249 //===----------------------------------------------------------------------===// 250 251 //===----------------------------------------------------------------------===// 252 // Parsing 253 254 static ParseResult parseCustomDirectiveOperands( 255 OpAsmParser &parser, OpAsmParser::OperandType &operand, 256 Optional<OpAsmParser::OperandType> &optOperand, 257 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 258 if (parser.parseOperand(operand)) 259 return failure(); 260 if (succeeded(parser.parseOptionalComma())) { 261 optOperand.emplace(); 262 if (parser.parseOperand(*optOperand)) 263 return failure(); 264 } 265 if (parser.parseArrow() || parser.parseLParen() || 266 parser.parseOperandList(varOperands) || parser.parseRParen()) 267 return failure(); 268 return success(); 269 } 270 static ParseResult 271 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 272 Type &optOperandType, 273 SmallVectorImpl<Type> &varOperandTypes) { 274 if (parser.parseColon()) 275 return failure(); 276 277 if (parser.parseType(operandType)) 278 return failure(); 279 if (succeeded(parser.parseOptionalComma())) { 280 if (parser.parseType(optOperandType)) 281 return failure(); 282 } 283 if (parser.parseArrow() || parser.parseLParen() || 284 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 285 return failure(); 286 return success(); 287 } 288 static ParseResult 289 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 290 Type optOperandType, 291 const SmallVectorImpl<Type> &varOperandTypes) { 292 if (parser.parseKeyword("type_refs_capture")) 293 return failure(); 294 295 Type operandType2, optOperandType2; 296 SmallVector<Type, 1> varOperandTypes2; 297 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 298 varOperandTypes2)) 299 return failure(); 300 301 if (operandType != operandType2 || optOperandType != optOperandType2 || 302 varOperandTypes != varOperandTypes2) 303 return failure(); 304 305 return success(); 306 } 307 static ParseResult parseCustomDirectiveOperandsAndTypes( 308 OpAsmParser &parser, OpAsmParser::OperandType &operand, 309 Optional<OpAsmParser::OperandType> &optOperand, 310 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 311 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 312 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 313 parseCustomDirectiveResults(parser, operandType, optOperandType, 314 varOperandTypes)) 315 return failure(); 316 return success(); 317 } 318 static ParseResult parseCustomDirectiveRegions( 319 OpAsmParser &parser, Region ®ion, 320 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 321 if (parser.parseRegion(region)) 322 return failure(); 323 if (failed(parser.parseOptionalComma())) 324 return success(); 325 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 326 if (parser.parseRegion(*varRegion)) 327 return failure(); 328 varRegions.emplace_back(std::move(varRegion)); 329 return success(); 330 } 331 static ParseResult 332 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 333 SmallVectorImpl<Block *> &varSuccessors) { 334 if (parser.parseSuccessor(successor)) 335 return failure(); 336 if (failed(parser.parseOptionalComma())) 337 return success(); 338 Block *varSuccessor; 339 if (parser.parseSuccessor(varSuccessor)) 340 return failure(); 341 varSuccessors.append(2, varSuccessor); 342 return success(); 343 } 344 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 345 IntegerAttr &attr, 346 IntegerAttr &optAttr) { 347 if (parser.parseAttribute(attr)) 348 return failure(); 349 if (succeeded(parser.parseOptionalComma())) { 350 if (parser.parseAttribute(optAttr)) 351 return failure(); 352 } 353 return success(); 354 } 355 356 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 357 NamedAttrList &attrs) { 358 return parser.parseOptionalAttrDict(attrs); 359 } 360 static ParseResult parseCustomDirectiveOptionalOperandRef( 361 OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) { 362 int64_t operandCount = 0; 363 if (parser.parseInteger(operandCount)) 364 return failure(); 365 bool expectedOptionalOperand = operandCount == 0; 366 return success(expectedOptionalOperand != optOperand.hasValue()); 367 } 368 369 //===----------------------------------------------------------------------===// 370 // Printing 371 372 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 373 Value operand, Value optOperand, 374 OperandRange varOperands) { 375 printer << operand; 376 if (optOperand) 377 printer << ", " << optOperand; 378 printer << " -> (" << varOperands << ")"; 379 } 380 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 381 Type operandType, Type optOperandType, 382 TypeRange varOperandTypes) { 383 printer << " : " << operandType; 384 if (optOperandType) 385 printer << ", " << optOperandType; 386 printer << " -> (" << varOperandTypes << ")"; 387 } 388 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 389 Operation *op, Type operandType, 390 Type optOperandType, 391 TypeRange varOperandTypes) { 392 printer << " type_refs_capture "; 393 printCustomDirectiveResults(printer, op, operandType, optOperandType, 394 varOperandTypes); 395 } 396 static void printCustomDirectiveOperandsAndTypes( 397 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 398 OperandRange varOperands, Type operandType, Type optOperandType, 399 TypeRange varOperandTypes) { 400 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 401 printCustomDirectiveResults(printer, op, operandType, optOperandType, 402 varOperandTypes); 403 } 404 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 405 Region ®ion, 406 MutableArrayRef<Region> varRegions) { 407 printer.printRegion(region); 408 if (!varRegions.empty()) { 409 printer << ", "; 410 for (Region ®ion : varRegions) 411 printer.printRegion(region); 412 } 413 } 414 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 415 Block *successor, 416 SuccessorRange varSuccessors) { 417 printer << successor; 418 if (!varSuccessors.empty()) 419 printer << ", " << varSuccessors.front(); 420 } 421 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 422 Attribute attribute, 423 Attribute optAttribute) { 424 printer << attribute; 425 if (optAttribute) 426 printer << ", " << optAttribute; 427 } 428 429 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 430 DictionaryAttr attrs) { 431 printer.printOptionalAttrDict(attrs.getValue()); 432 } 433 434 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 435 Operation *op, 436 Value optOperand) { 437 printer << (optOperand ? "1" : "0"); 438 } 439 440 //===----------------------------------------------------------------------===// 441 // Test IsolatedRegionOp - parse passthrough region arguments. 442 //===----------------------------------------------------------------------===// 443 444 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 445 OperationState &result) { 446 OpAsmParser::OperandType argInfo; 447 Type argType = parser.getBuilder().getIndexType(); 448 449 // Parse the input operand. 450 if (parser.parseOperand(argInfo) || 451 parser.resolveOperand(argInfo, argType, result.operands)) 452 return failure(); 453 454 // Parse the body region, and reuse the operand info as the argument info. 455 Region *body = result.addRegion(); 456 return parser.parseRegion(*body, argInfo, argType, 457 /*enableNameShadowing=*/true); 458 } 459 460 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 461 p << "test.isolated_region "; 462 p.printOperand(op.getOperand()); 463 p.shadowRegionArgs(op.region(), op.getOperand()); 464 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 465 } 466 467 //===----------------------------------------------------------------------===// 468 // Test SSACFGRegionOp 469 //===----------------------------------------------------------------------===// 470 471 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 472 return RegionKind::SSACFG; 473 } 474 475 //===----------------------------------------------------------------------===// 476 // Test GraphRegionOp 477 //===----------------------------------------------------------------------===// 478 479 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 480 OperationState &result) { 481 // Parse the body region, and reuse the operand info as the argument info. 482 Region *body = result.addRegion(); 483 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 484 } 485 486 static void print(OpAsmPrinter &p, GraphRegionOp op) { 487 p << "test.graph_region "; 488 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 489 } 490 491 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 492 return RegionKind::Graph; 493 } 494 495 //===----------------------------------------------------------------------===// 496 // Test AffineScopeOp 497 //===----------------------------------------------------------------------===// 498 499 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 500 OperationState &result) { 501 // Parse the body region, and reuse the operand info as the argument info. 502 Region *body = result.addRegion(); 503 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 504 } 505 506 static void print(OpAsmPrinter &p, AffineScopeOp op) { 507 p << "test.affine_scope "; 508 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 509 } 510 511 //===----------------------------------------------------------------------===// 512 // Test parser. 513 //===----------------------------------------------------------------------===// 514 515 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, 516 OperationState &result) { 517 if (parser.parseOptionalColon()) 518 return success(); 519 uint64_t numResults; 520 if (parser.parseInteger(numResults)) 521 return failure(); 522 523 IndexType type = parser.getBuilder().getIndexType(); 524 for (unsigned i = 0; i < numResults; ++i) 525 result.addTypes(type); 526 return success(); 527 } 528 529 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { 530 p << ParseIntegerLiteralOp::getOperationName(); 531 if (unsigned numResults = op->getNumResults()) 532 p << " : " << numResults; 533 } 534 535 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, 536 OperationState &result) { 537 StringRef keyword; 538 if (parser.parseKeyword(&keyword)) 539 return failure(); 540 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 541 return success(); 542 } 543 544 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { 545 p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword(); 546 } 547 548 //===----------------------------------------------------------------------===// 549 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 550 551 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 552 OperationState &result) { 553 if (parser.parseKeyword("wraps")) 554 return failure(); 555 556 // Parse the wrapped op in a region 557 Region &body = *result.addRegion(); 558 body.push_back(new Block); 559 Block &block = body.back(); 560 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 561 if (!wrapped_op) 562 return failure(); 563 564 // Create a return terminator in the inner region, pass as operand to the 565 // terminator the returned values from the wrapped operation. 566 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 567 OpBuilder builder(parser.getBuilder().getContext()); 568 builder.setInsertionPointToEnd(&block); 569 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 570 571 // Get the results type for the wrapping op from the terminator operands. 572 Operation &return_op = body.back().back(); 573 result.types.append(return_op.operand_type_begin(), 574 return_op.operand_type_end()); 575 576 // Use the location of the wrapped op for the "test.wrapping_region" op. 577 result.location = wrapped_op->getLoc(); 578 579 return success(); 580 } 581 582 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 583 p << op.getOperationName() << " wraps "; 584 p.printGenericOp(&op.region().front().front()); 585 } 586 587 //===----------------------------------------------------------------------===// 588 // Test PolyForOp - parse list of region arguments. 589 //===----------------------------------------------------------------------===// 590 591 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 592 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 593 // Parse list of region arguments without a delimiter. 594 if (parser.parseRegionArgumentList(ivsInfo)) 595 return failure(); 596 597 // Parse the body region. 598 Region *body = result.addRegion(); 599 auto &builder = parser.getBuilder(); 600 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 601 return parser.parseRegion(*body, ivsInfo, argTypes); 602 } 603 604 //===----------------------------------------------------------------------===// 605 // Test removing op with inner ops. 606 //===----------------------------------------------------------------------===// 607 608 namespace { 609 struct TestRemoveOpWithInnerOps 610 : public OpRewritePattern<TestOpWithRegionPattern> { 611 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 612 613 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 614 PatternRewriter &rewriter) const override { 615 rewriter.eraseOp(op); 616 return success(); 617 } 618 }; 619 } // end anonymous namespace 620 621 void TestOpWithRegionPattern::getCanonicalizationPatterns( 622 OwningRewritePatternList &results, MLIRContext *context) { 623 results.insert<TestRemoveOpWithInnerOps>(context); 624 } 625 626 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 627 return operand(); 628 } 629 630 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 631 return getValue(); 632 } 633 634 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 635 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 636 for (Value input : this->operands()) { 637 results.push_back(input); 638 } 639 return success(); 640 } 641 642 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 643 assert(operands.size() == 1); 644 if (operands.front()) { 645 (*this)->setAttr("attr", operands.front()); 646 return getResult(); 647 } 648 return {}; 649 } 650 651 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 652 return getOperand(); 653 } 654 655 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 656 MLIRContext *, Optional<Location> location, ValueRange operands, 657 DictionaryAttr attributes, RegionRange regions, 658 SmallVectorImpl<Type> &inferredReturnTypes) { 659 if (operands[0].getType() != operands[1].getType()) { 660 return emitOptionalError(location, "operand type mismatch ", 661 operands[0].getType(), " vs ", 662 operands[1].getType()); 663 } 664 inferredReturnTypes.assign({operands[0].getType()}); 665 return success(); 666 } 667 668 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 669 MLIRContext *context, Optional<Location> location, ValueRange operands, 670 DictionaryAttr attributes, RegionRange regions, 671 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 672 // Create return type consisting of the last element of the first operand. 673 auto operandType = *operands.getTypes().begin(); 674 auto sval = operandType.dyn_cast<ShapedType>(); 675 if (!sval) { 676 return emitOptionalError(location, "only shaped type operands allowed"); 677 } 678 int64_t dim = 679 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 680 auto type = IntegerType::get(context, 17); 681 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 682 return success(); 683 } 684 685 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 686 OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) { 687 shapes = SmallVector<Value, 1>{ 688 builder.createOrFold<DimOp>(getLoc(), getOperand(0), 0)}; 689 return success(); 690 } 691 692 //===----------------------------------------------------------------------===// 693 // Test SideEffect interfaces 694 //===----------------------------------------------------------------------===// 695 696 namespace { 697 /// A test resource for side effects. 698 struct TestResource : public SideEffects::Resource::Base<TestResource> { 699 StringRef getName() final { return "<Test>"; } 700 }; 701 } // end anonymous namespace 702 703 void SideEffectOp::getEffects( 704 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 705 // Check for an effects attribute on the op instance. 706 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 707 if (!effectsAttr) 708 return; 709 710 // If there is one, it is an array of dictionary attributes that hold 711 // information on the effects of this operation. 712 for (Attribute element : effectsAttr) { 713 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 714 715 // Get the specific memory effect. 716 MemoryEffects::Effect *effect = 717 StringSwitch<MemoryEffects::Effect *>( 718 effectElement.get("effect").cast<StringAttr>().getValue()) 719 .Case("allocate", MemoryEffects::Allocate::get()) 720 .Case("free", MemoryEffects::Free::get()) 721 .Case("read", MemoryEffects::Read::get()) 722 .Case("write", MemoryEffects::Write::get()); 723 724 // Check for a non-default resource to use. 725 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 726 if (effectElement.get("test_resource")) 727 resource = TestResource::get(); 728 729 // Check for a result to affect. 730 if (effectElement.get("on_result")) 731 effects.emplace_back(effect, getResult(), resource); 732 else if (Attribute ref = effectElement.get("on_reference")) 733 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 734 else 735 effects.emplace_back(effect, resource); 736 } 737 } 738 739 void SideEffectOp::getEffects( 740 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 741 auto effectsAttr = (*this)->getAttrOfType<AffineMapAttr>("effect_parameter"); 742 if (!effectsAttr) 743 return; 744 745 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 746 } 747 748 //===----------------------------------------------------------------------===// 749 // StringAttrPrettyNameOp 750 //===----------------------------------------------------------------------===// 751 752 // This op has fancy handling of its SSA result name. 753 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 754 OperationState &result) { 755 // Add the result types. 756 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 757 result.addTypes(parser.getBuilder().getIntegerType(32)); 758 759 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 760 return failure(); 761 762 // If the attribute dictionary contains no 'names' attribute, infer it from 763 // the SSA name (if specified). 764 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 765 return attr.first == "names"; 766 }); 767 768 // If there was no name specified, check to see if there was a useful name 769 // specified in the asm file. 770 if (hadNames || parser.getNumResults() == 0) 771 return success(); 772 773 SmallVector<StringRef, 4> names; 774 auto *context = result.getContext(); 775 776 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 777 auto resultName = parser.getResultName(i); 778 StringRef nameStr; 779 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 780 nameStr = resultName.first; 781 782 names.push_back(nameStr); 783 } 784 785 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 786 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 787 return success(); 788 } 789 790 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 791 p << "test.string_attr_pretty_name"; 792 793 // Note that we only need to print the "name" attribute if the asmprinter 794 // result name disagrees with it. This can happen in strange cases, e.g. 795 // when there are conflicts. 796 bool namesDisagree = op.names().size() != op.getNumResults(); 797 798 SmallString<32> resultNameStr; 799 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 800 resultNameStr.clear(); 801 llvm::raw_svector_ostream tmpStream(resultNameStr); 802 p.printOperand(op.getResult(i), tmpStream); 803 804 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 805 if (!expectedName || 806 tmpStream.str().drop_front() != expectedName.getValue()) { 807 namesDisagree = true; 808 } 809 } 810 811 if (namesDisagree) 812 p.printOptionalAttrDictWithKeyword(op->getAttrs()); 813 else 814 p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"}); 815 } 816 817 // We set the SSA name in the asm syntax to the contents of the name 818 // attribute. 819 void StringAttrPrettyNameOp::getAsmResultNames( 820 function_ref<void(Value, StringRef)> setNameFn) { 821 822 auto value = names(); 823 for (size_t i = 0, e = value.size(); i != e; ++i) 824 if (auto str = value[i].dyn_cast<StringAttr>()) 825 if (!str.getValue().empty()) 826 setNameFn(getResult(i), str.getValue()); 827 } 828 829 //===----------------------------------------------------------------------===// 830 // RegionIfOp 831 //===----------------------------------------------------------------------===// 832 833 static void print(OpAsmPrinter &p, RegionIfOp op) { 834 p << RegionIfOp::getOperationName() << " "; 835 p.printOperands(op.getOperands()); 836 p << ": " << op.getOperandTypes(); 837 p.printArrowTypeList(op.getResultTypes()); 838 p << " then"; 839 p.printRegion(op.thenRegion(), 840 /*printEntryBlockArgs=*/true, 841 /*printBlockTerminators=*/true); 842 p << " else"; 843 p.printRegion(op.elseRegion(), 844 /*printEntryBlockArgs=*/true, 845 /*printBlockTerminators=*/true); 846 p << " join"; 847 p.printRegion(op.joinRegion(), 848 /*printEntryBlockArgs=*/true, 849 /*printBlockTerminators=*/true); 850 } 851 852 static ParseResult parseRegionIfOp(OpAsmParser &parser, 853 OperationState &result) { 854 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 855 SmallVector<Type, 2> operandTypes; 856 857 result.regions.reserve(3); 858 Region *thenRegion = result.addRegion(); 859 Region *elseRegion = result.addRegion(); 860 Region *joinRegion = result.addRegion(); 861 862 // Parse operand, type and arrow type lists. 863 if (parser.parseOperandList(operandInfos) || 864 parser.parseColonTypeList(operandTypes) || 865 parser.parseArrowTypeList(result.types)) 866 return failure(); 867 868 // Parse all attached regions. 869 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 870 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 871 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 872 return failure(); 873 874 return parser.resolveOperands(operandInfos, operandTypes, 875 parser.getCurrentLocation(), result.operands); 876 } 877 878 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 879 assert(index < 2 && "invalid region index"); 880 return getOperands(); 881 } 882 883 void RegionIfOp::getSuccessorRegions( 884 Optional<unsigned> index, ArrayRef<Attribute> operands, 885 SmallVectorImpl<RegionSuccessor> ®ions) { 886 // We always branch to the join region. 887 if (index.hasValue()) { 888 if (index.getValue() < 2) 889 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 890 else 891 regions.push_back(RegionSuccessor(getResults())); 892 return; 893 } 894 895 // The then and else regions are the entry regions of this op. 896 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 897 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 898 } 899 900 #include "TestOpEnums.cpp.inc" 901 #include "TestOpInterfaces.cpp.inc" 902 #include "TestOpStructs.cpp.inc" 903 #include "TestTypeInterfaces.cpp.inc" 904 905 #define GET_OP_CLASSES 906 #include "TestOps.cpp.inc" 907