1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestTypes.h" 12 #include "mlir/Dialect/DLTI/DLTI.h" 13 #include "mlir/Dialect/StandardOps/IR/Ops.h" 14 #include "mlir/IR/BuiltinOps.h" 15 #include "mlir/IR/DialectImplementation.h" 16 #include "mlir/IR/PatternMatch.h" 17 #include "mlir/IR/TypeUtilities.h" 18 #include "mlir/Transforms/FoldUtils.h" 19 #include "mlir/Transforms/InliningUtils.h" 20 #include "llvm/ADT/StringSwitch.h" 21 22 using namespace mlir; 23 using namespace mlir::test; 24 25 void mlir::test::registerTestDialect(DialectRegistry ®istry) { 26 registry.insert<TestDialect>(); 27 } 28 29 //===----------------------------------------------------------------------===// 30 // TestDialect Interfaces 31 //===----------------------------------------------------------------------===// 32 33 namespace { 34 35 // Test support for interacting with the AsmPrinter. 36 struct TestOpAsmInterface : public OpAsmDialectInterface { 37 using OpAsmDialectInterface::OpAsmDialectInterface; 38 39 LogicalResult getAlias(Attribute attr, raw_ostream &os) const final { 40 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 41 if (!strAttr) 42 return failure(); 43 44 // Check the contents of the string attribute to see what the test alias 45 // should be named. 46 Optional<StringRef> aliasName = 47 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 48 .Case("alias_test:dot_in_name", StringRef("test.alias")) 49 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 50 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 51 .Case("alias_test:sanitize_conflict_a", 52 StringRef("test_alias_conflict0")) 53 .Case("alias_test:sanitize_conflict_b", 54 StringRef("test_alias_conflict0_")) 55 .Default(llvm::None); 56 if (!aliasName) 57 return failure(); 58 59 os << *aliasName; 60 return success(); 61 } 62 63 void getAsmResultNames(Operation *op, 64 OpAsmSetValueNameFn setNameFn) const final { 65 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 66 setNameFn(asmOp, "result"); 67 } 68 69 void getAsmBlockArgumentNames(Block *block, 70 OpAsmSetValueNameFn setNameFn) const final { 71 auto op = block->getParentOp(); 72 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 73 if (!arrayAttr) 74 return; 75 auto args = block->getArguments(); 76 auto e = std::min(arrayAttr.size(), args.size()); 77 for (unsigned i = 0; i < e; ++i) { 78 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 79 setNameFn(args[i], strAttr.getValue()); 80 } 81 } 82 }; 83 84 struct TestDialectFoldInterface : public DialectFoldInterface { 85 using DialectFoldInterface::DialectFoldInterface; 86 87 /// Registered hook to check if the given region, which is attached to an 88 /// operation that is *not* isolated from above, should be used when 89 /// materializing constants. 90 bool shouldMaterializeInto(Region *region) const final { 91 // If this is a one region operation, then insert into it. 92 return isa<OneRegionOp>(region->getParentOp()); 93 } 94 }; 95 96 /// This class defines the interface for handling inlining with standard 97 /// operations. 98 struct TestInlinerInterface : public DialectInlinerInterface { 99 using DialectInlinerInterface::DialectInlinerInterface; 100 101 //===--------------------------------------------------------------------===// 102 // Analysis Hooks 103 //===--------------------------------------------------------------------===// 104 105 bool isLegalToInline(Operation *call, Operation *callable, 106 bool wouldBeCloned) const final { 107 // Don't allow inlining calls that are marked `noinline`. 108 return !call->hasAttr("noinline"); 109 } 110 bool isLegalToInline(Region *, Region *, bool, 111 BlockAndValueMapping &) const final { 112 // Inlining into test dialect regions is legal. 113 return true; 114 } 115 bool isLegalToInline(Operation *, Region *, bool, 116 BlockAndValueMapping &) const final { 117 return true; 118 } 119 120 bool shouldAnalyzeRecursively(Operation *op) const final { 121 // Analyze recursively if this is not a functional region operation, it 122 // froms a separate functional scope. 123 return !isa<FunctionalRegionOp>(op); 124 } 125 126 //===--------------------------------------------------------------------===// 127 // Transformation Hooks 128 //===--------------------------------------------------------------------===// 129 130 /// Handle the given inlined terminator by replacing it with a new operation 131 /// as necessary. 132 void handleTerminator(Operation *op, 133 ArrayRef<Value> valuesToRepl) const final { 134 // Only handle "test.return" here. 135 auto returnOp = dyn_cast<TestReturnOp>(op); 136 if (!returnOp) 137 return; 138 139 // Replace the values directly with the return operands. 140 assert(returnOp.getNumOperands() == valuesToRepl.size()); 141 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 142 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 143 } 144 145 /// Attempt to materialize a conversion for a type mismatch between a call 146 /// from this dialect, and a callable region. This method should generate an 147 /// operation that takes 'input' as the only operand, and produces a single 148 /// result of 'resultType'. If a conversion can not be generated, nullptr 149 /// should be returned. 150 Operation *materializeCallConversion(OpBuilder &builder, Value input, 151 Type resultType, 152 Location conversionLoc) const final { 153 // Only allow conversion for i16/i32 types. 154 if (!(resultType.isSignlessInteger(16) || 155 resultType.isSignlessInteger(32)) || 156 !(input.getType().isSignlessInteger(16) || 157 input.getType().isSignlessInteger(32))) 158 return nullptr; 159 return builder.create<TestCastOp>(conversionLoc, resultType, input); 160 } 161 }; 162 } // end anonymous namespace 163 164 //===----------------------------------------------------------------------===// 165 // TestDialect 166 //===----------------------------------------------------------------------===// 167 168 void TestDialect::initialize() { 169 addOperations< 170 #define GET_OP_LIST 171 #include "TestOps.cpp.inc" 172 >(); 173 addAttributes< 174 #define GET_ATTRDEF_LIST 175 #include "TestAttrDefs.cpp.inc" 176 >(); 177 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 178 TestInlinerInterface>(); 179 addTypes<TestType, TestTypeWithLayout, TestRecursiveType, 180 #define GET_TYPEDEF_LIST 181 #include "TestTypeDefs.cpp.inc" 182 >(); 183 allowUnknownOperations(); 184 } 185 186 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 187 Type type, Location loc) { 188 return builder.create<TestOpConstant>(loc, type, value); 189 } 190 191 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 192 NamedAttribute namedAttr) { 193 if (namedAttr.first == "test.invalid_attr") 194 return op->emitError() << "invalid to use 'test.invalid_attr'"; 195 return success(); 196 } 197 198 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 199 unsigned regionIndex, 200 unsigned argIndex, 201 NamedAttribute namedAttr) { 202 if (namedAttr.first == "test.invalid_attr") 203 return op->emitError() << "invalid to use 'test.invalid_attr'"; 204 return success(); 205 } 206 207 LogicalResult 208 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 209 unsigned resultIndex, 210 NamedAttribute namedAttr) { 211 if (namedAttr.first == "test.invalid_attr") 212 return op->emitError() << "invalid to use 'test.invalid_attr'"; 213 return success(); 214 } 215 216 //===----------------------------------------------------------------------===// 217 // TestBranchOp 218 //===----------------------------------------------------------------------===// 219 220 Optional<MutableOperandRange> 221 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 222 assert(index == 0 && "invalid successor index"); 223 return targetOperandsMutable(); 224 } 225 226 //===----------------------------------------------------------------------===// 227 // TestFoldToCallOp 228 //===----------------------------------------------------------------------===// 229 230 namespace { 231 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 232 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 233 234 LogicalResult matchAndRewrite(FoldToCallOp op, 235 PatternRewriter &rewriter) const override { 236 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(), 237 ValueRange()); 238 return success(); 239 } 240 }; 241 } // end anonymous namespace 242 243 void FoldToCallOp::getCanonicalizationPatterns( 244 OwningRewritePatternList &results, MLIRContext *context) { 245 results.insert<FoldToCallOpPattern>(context); 246 } 247 248 //===----------------------------------------------------------------------===// 249 // Test Format* operations 250 //===----------------------------------------------------------------------===// 251 252 //===----------------------------------------------------------------------===// 253 // Parsing 254 255 static ParseResult parseCustomDirectiveOperands( 256 OpAsmParser &parser, OpAsmParser::OperandType &operand, 257 Optional<OpAsmParser::OperandType> &optOperand, 258 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 259 if (parser.parseOperand(operand)) 260 return failure(); 261 if (succeeded(parser.parseOptionalComma())) { 262 optOperand.emplace(); 263 if (parser.parseOperand(*optOperand)) 264 return failure(); 265 } 266 if (parser.parseArrow() || parser.parseLParen() || 267 parser.parseOperandList(varOperands) || parser.parseRParen()) 268 return failure(); 269 return success(); 270 } 271 static ParseResult 272 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 273 Type &optOperandType, 274 SmallVectorImpl<Type> &varOperandTypes) { 275 if (parser.parseColon()) 276 return failure(); 277 278 if (parser.parseType(operandType)) 279 return failure(); 280 if (succeeded(parser.parseOptionalComma())) { 281 if (parser.parseType(optOperandType)) 282 return failure(); 283 } 284 if (parser.parseArrow() || parser.parseLParen() || 285 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 286 return failure(); 287 return success(); 288 } 289 static ParseResult 290 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 291 Type optOperandType, 292 const SmallVectorImpl<Type> &varOperandTypes) { 293 if (parser.parseKeyword("type_refs_capture")) 294 return failure(); 295 296 Type operandType2, optOperandType2; 297 SmallVector<Type, 1> varOperandTypes2; 298 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 299 varOperandTypes2)) 300 return failure(); 301 302 if (operandType != operandType2 || optOperandType != optOperandType2 || 303 varOperandTypes != varOperandTypes2) 304 return failure(); 305 306 return success(); 307 } 308 static ParseResult parseCustomDirectiveOperandsAndTypes( 309 OpAsmParser &parser, OpAsmParser::OperandType &operand, 310 Optional<OpAsmParser::OperandType> &optOperand, 311 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 312 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 313 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 314 parseCustomDirectiveResults(parser, operandType, optOperandType, 315 varOperandTypes)) 316 return failure(); 317 return success(); 318 } 319 static ParseResult parseCustomDirectiveRegions( 320 OpAsmParser &parser, Region ®ion, 321 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 322 if (parser.parseRegion(region)) 323 return failure(); 324 if (failed(parser.parseOptionalComma())) 325 return success(); 326 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 327 if (parser.parseRegion(*varRegion)) 328 return failure(); 329 varRegions.emplace_back(std::move(varRegion)); 330 return success(); 331 } 332 static ParseResult 333 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 334 SmallVectorImpl<Block *> &varSuccessors) { 335 if (parser.parseSuccessor(successor)) 336 return failure(); 337 if (failed(parser.parseOptionalComma())) 338 return success(); 339 Block *varSuccessor; 340 if (parser.parseSuccessor(varSuccessor)) 341 return failure(); 342 varSuccessors.append(2, varSuccessor); 343 return success(); 344 } 345 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 346 IntegerAttr &attr, 347 IntegerAttr &optAttr) { 348 if (parser.parseAttribute(attr)) 349 return failure(); 350 if (succeeded(parser.parseOptionalComma())) { 351 if (parser.parseAttribute(optAttr)) 352 return failure(); 353 } 354 return success(); 355 } 356 357 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 358 NamedAttrList &attrs) { 359 return parser.parseOptionalAttrDict(attrs); 360 } 361 static ParseResult parseCustomDirectiveOptionalOperandRef( 362 OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) { 363 int64_t operandCount = 0; 364 if (parser.parseInteger(operandCount)) 365 return failure(); 366 bool expectedOptionalOperand = operandCount == 0; 367 return success(expectedOptionalOperand != optOperand.hasValue()); 368 } 369 370 //===----------------------------------------------------------------------===// 371 // Printing 372 373 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 374 Value operand, Value optOperand, 375 OperandRange varOperands) { 376 printer << operand; 377 if (optOperand) 378 printer << ", " << optOperand; 379 printer << " -> (" << varOperands << ")"; 380 } 381 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 382 Type operandType, Type optOperandType, 383 TypeRange varOperandTypes) { 384 printer << " : " << operandType; 385 if (optOperandType) 386 printer << ", " << optOperandType; 387 printer << " -> (" << varOperandTypes << ")"; 388 } 389 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 390 Operation *op, Type operandType, 391 Type optOperandType, 392 TypeRange varOperandTypes) { 393 printer << " type_refs_capture "; 394 printCustomDirectiveResults(printer, op, operandType, optOperandType, 395 varOperandTypes); 396 } 397 static void printCustomDirectiveOperandsAndTypes( 398 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 399 OperandRange varOperands, Type operandType, Type optOperandType, 400 TypeRange varOperandTypes) { 401 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 402 printCustomDirectiveResults(printer, op, operandType, optOperandType, 403 varOperandTypes); 404 } 405 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 406 Region ®ion, 407 MutableArrayRef<Region> varRegions) { 408 printer.printRegion(region); 409 if (!varRegions.empty()) { 410 printer << ", "; 411 for (Region ®ion : varRegions) 412 printer.printRegion(region); 413 } 414 } 415 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 416 Block *successor, 417 SuccessorRange varSuccessors) { 418 printer << successor; 419 if (!varSuccessors.empty()) 420 printer << ", " << varSuccessors.front(); 421 } 422 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 423 Attribute attribute, 424 Attribute optAttribute) { 425 printer << attribute; 426 if (optAttribute) 427 printer << ", " << optAttribute; 428 } 429 430 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 431 DictionaryAttr attrs) { 432 printer.printOptionalAttrDict(attrs.getValue()); 433 } 434 435 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 436 Operation *op, 437 Value optOperand) { 438 printer << (optOperand ? "1" : "0"); 439 } 440 441 //===----------------------------------------------------------------------===// 442 // Test IsolatedRegionOp - parse passthrough region arguments. 443 //===----------------------------------------------------------------------===// 444 445 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 446 OperationState &result) { 447 OpAsmParser::OperandType argInfo; 448 Type argType = parser.getBuilder().getIndexType(); 449 450 // Parse the input operand. 451 if (parser.parseOperand(argInfo) || 452 parser.resolveOperand(argInfo, argType, result.operands)) 453 return failure(); 454 455 // Parse the body region, and reuse the operand info as the argument info. 456 Region *body = result.addRegion(); 457 return parser.parseRegion(*body, argInfo, argType, 458 /*enableNameShadowing=*/true); 459 } 460 461 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 462 p << "test.isolated_region "; 463 p.printOperand(op.getOperand()); 464 p.shadowRegionArgs(op.region(), op.getOperand()); 465 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 466 } 467 468 //===----------------------------------------------------------------------===// 469 // Test SSACFGRegionOp 470 //===----------------------------------------------------------------------===// 471 472 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 473 return RegionKind::SSACFG; 474 } 475 476 //===----------------------------------------------------------------------===// 477 // Test GraphRegionOp 478 //===----------------------------------------------------------------------===// 479 480 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 481 OperationState &result) { 482 // Parse the body region, and reuse the operand info as the argument info. 483 Region *body = result.addRegion(); 484 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 485 } 486 487 static void print(OpAsmPrinter &p, GraphRegionOp op) { 488 p << "test.graph_region "; 489 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 490 } 491 492 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 493 return RegionKind::Graph; 494 } 495 496 //===----------------------------------------------------------------------===// 497 // Test AffineScopeOp 498 //===----------------------------------------------------------------------===// 499 500 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 501 OperationState &result) { 502 // Parse the body region, and reuse the operand info as the argument info. 503 Region *body = result.addRegion(); 504 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 505 } 506 507 static void print(OpAsmPrinter &p, AffineScopeOp op) { 508 p << "test.affine_scope "; 509 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 510 } 511 512 //===----------------------------------------------------------------------===// 513 // Test parser. 514 //===----------------------------------------------------------------------===// 515 516 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, 517 OperationState &result) { 518 if (parser.parseOptionalColon()) 519 return success(); 520 uint64_t numResults; 521 if (parser.parseInteger(numResults)) 522 return failure(); 523 524 IndexType type = parser.getBuilder().getIndexType(); 525 for (unsigned i = 0; i < numResults; ++i) 526 result.addTypes(type); 527 return success(); 528 } 529 530 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { 531 p << ParseIntegerLiteralOp::getOperationName(); 532 if (unsigned numResults = op->getNumResults()) 533 p << " : " << numResults; 534 } 535 536 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, 537 OperationState &result) { 538 StringRef keyword; 539 if (parser.parseKeyword(&keyword)) 540 return failure(); 541 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 542 return success(); 543 } 544 545 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { 546 p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword(); 547 } 548 549 //===----------------------------------------------------------------------===// 550 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 551 552 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 553 OperationState &result) { 554 if (parser.parseKeyword("wraps")) 555 return failure(); 556 557 // Parse the wrapped op in a region 558 Region &body = *result.addRegion(); 559 body.push_back(new Block); 560 Block &block = body.back(); 561 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 562 if (!wrapped_op) 563 return failure(); 564 565 // Create a return terminator in the inner region, pass as operand to the 566 // terminator the returned values from the wrapped operation. 567 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 568 OpBuilder builder(parser.getBuilder().getContext()); 569 builder.setInsertionPointToEnd(&block); 570 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 571 572 // Get the results type for the wrapping op from the terminator operands. 573 Operation &return_op = body.back().back(); 574 result.types.append(return_op.operand_type_begin(), 575 return_op.operand_type_end()); 576 577 // Use the location of the wrapped op for the "test.wrapping_region" op. 578 result.location = wrapped_op->getLoc(); 579 580 return success(); 581 } 582 583 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 584 p << op.getOperationName() << " wraps "; 585 p.printGenericOp(&op.region().front().front()); 586 } 587 588 //===----------------------------------------------------------------------===// 589 // Test PolyForOp - parse list of region arguments. 590 //===----------------------------------------------------------------------===// 591 592 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 593 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 594 // Parse list of region arguments without a delimiter. 595 if (parser.parseRegionArgumentList(ivsInfo)) 596 return failure(); 597 598 // Parse the body region. 599 Region *body = result.addRegion(); 600 auto &builder = parser.getBuilder(); 601 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 602 return parser.parseRegion(*body, ivsInfo, argTypes); 603 } 604 605 //===----------------------------------------------------------------------===// 606 // Test removing op with inner ops. 607 //===----------------------------------------------------------------------===// 608 609 namespace { 610 struct TestRemoveOpWithInnerOps 611 : public OpRewritePattern<TestOpWithRegionPattern> { 612 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 613 614 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 615 PatternRewriter &rewriter) const override { 616 rewriter.eraseOp(op); 617 return success(); 618 } 619 }; 620 } // end anonymous namespace 621 622 void TestOpWithRegionPattern::getCanonicalizationPatterns( 623 OwningRewritePatternList &results, MLIRContext *context) { 624 results.insert<TestRemoveOpWithInnerOps>(context); 625 } 626 627 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 628 return operand(); 629 } 630 631 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 632 return getValue(); 633 } 634 635 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 636 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 637 for (Value input : this->operands()) { 638 results.push_back(input); 639 } 640 return success(); 641 } 642 643 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 644 assert(operands.size() == 1); 645 if (operands.front()) { 646 (*this)->setAttr("attr", operands.front()); 647 return getResult(); 648 } 649 return {}; 650 } 651 652 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 653 return getOperand(); 654 } 655 656 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 657 MLIRContext *, Optional<Location> location, ValueRange operands, 658 DictionaryAttr attributes, RegionRange regions, 659 SmallVectorImpl<Type> &inferredReturnTypes) { 660 if (operands[0].getType() != operands[1].getType()) { 661 return emitOptionalError(location, "operand type mismatch ", 662 operands[0].getType(), " vs ", 663 operands[1].getType()); 664 } 665 inferredReturnTypes.assign({operands[0].getType()}); 666 return success(); 667 } 668 669 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 670 MLIRContext *context, Optional<Location> location, ValueRange operands, 671 DictionaryAttr attributes, RegionRange regions, 672 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 673 // Create return type consisting of the last element of the first operand. 674 auto operandType = *operands.getTypes().begin(); 675 auto sval = operandType.dyn_cast<ShapedType>(); 676 if (!sval) { 677 return emitOptionalError(location, "only shaped type operands allowed"); 678 } 679 int64_t dim = 680 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 681 auto type = IntegerType::get(context, 17); 682 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 683 return success(); 684 } 685 686 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 687 OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) { 688 shapes = SmallVector<Value, 1>{ 689 builder.createOrFold<DimOp>(getLoc(), getOperand(0), 0)}; 690 return success(); 691 } 692 693 //===----------------------------------------------------------------------===// 694 // Test SideEffect interfaces 695 //===----------------------------------------------------------------------===// 696 697 namespace { 698 /// A test resource for side effects. 699 struct TestResource : public SideEffects::Resource::Base<TestResource> { 700 StringRef getName() final { return "<Test>"; } 701 }; 702 } // end anonymous namespace 703 704 void SideEffectOp::getEffects( 705 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 706 // Check for an effects attribute on the op instance. 707 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 708 if (!effectsAttr) 709 return; 710 711 // If there is one, it is an array of dictionary attributes that hold 712 // information on the effects of this operation. 713 for (Attribute element : effectsAttr) { 714 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 715 716 // Get the specific memory effect. 717 MemoryEffects::Effect *effect = 718 StringSwitch<MemoryEffects::Effect *>( 719 effectElement.get("effect").cast<StringAttr>().getValue()) 720 .Case("allocate", MemoryEffects::Allocate::get()) 721 .Case("free", MemoryEffects::Free::get()) 722 .Case("read", MemoryEffects::Read::get()) 723 .Case("write", MemoryEffects::Write::get()); 724 725 // Check for a non-default resource to use. 726 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 727 if (effectElement.get("test_resource")) 728 resource = TestResource::get(); 729 730 // Check for a result to affect. 731 if (effectElement.get("on_result")) 732 effects.emplace_back(effect, getResult(), resource); 733 else if (Attribute ref = effectElement.get("on_reference")) 734 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 735 else 736 effects.emplace_back(effect, resource); 737 } 738 } 739 740 void SideEffectOp::getEffects( 741 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 742 auto effectsAttr = (*this)->getAttrOfType<AffineMapAttr>("effect_parameter"); 743 if (!effectsAttr) 744 return; 745 746 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 747 } 748 749 //===----------------------------------------------------------------------===// 750 // StringAttrPrettyNameOp 751 //===----------------------------------------------------------------------===// 752 753 // This op has fancy handling of its SSA result name. 754 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 755 OperationState &result) { 756 // Add the result types. 757 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 758 result.addTypes(parser.getBuilder().getIntegerType(32)); 759 760 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 761 return failure(); 762 763 // If the attribute dictionary contains no 'names' attribute, infer it from 764 // the SSA name (if specified). 765 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 766 return attr.first == "names"; 767 }); 768 769 // If there was no name specified, check to see if there was a useful name 770 // specified in the asm file. 771 if (hadNames || parser.getNumResults() == 0) 772 return success(); 773 774 SmallVector<StringRef, 4> names; 775 auto *context = result.getContext(); 776 777 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 778 auto resultName = parser.getResultName(i); 779 StringRef nameStr; 780 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 781 nameStr = resultName.first; 782 783 names.push_back(nameStr); 784 } 785 786 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 787 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 788 return success(); 789 } 790 791 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 792 p << "test.string_attr_pretty_name"; 793 794 // Note that we only need to print the "name" attribute if the asmprinter 795 // result name disagrees with it. This can happen in strange cases, e.g. 796 // when there are conflicts. 797 bool namesDisagree = op.names().size() != op.getNumResults(); 798 799 SmallString<32> resultNameStr; 800 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 801 resultNameStr.clear(); 802 llvm::raw_svector_ostream tmpStream(resultNameStr); 803 p.printOperand(op.getResult(i), tmpStream); 804 805 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 806 if (!expectedName || 807 tmpStream.str().drop_front() != expectedName.getValue()) { 808 namesDisagree = true; 809 } 810 } 811 812 if (namesDisagree) 813 p.printOptionalAttrDictWithKeyword(op->getAttrs()); 814 else 815 p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"}); 816 } 817 818 // We set the SSA name in the asm syntax to the contents of the name 819 // attribute. 820 void StringAttrPrettyNameOp::getAsmResultNames( 821 function_ref<void(Value, StringRef)> setNameFn) { 822 823 auto value = names(); 824 for (size_t i = 0, e = value.size(); i != e; ++i) 825 if (auto str = value[i].dyn_cast<StringAttr>()) 826 if (!str.getValue().empty()) 827 setNameFn(getResult(i), str.getValue()); 828 } 829 830 //===----------------------------------------------------------------------===// 831 // RegionIfOp 832 //===----------------------------------------------------------------------===// 833 834 static void print(OpAsmPrinter &p, RegionIfOp op) { 835 p << RegionIfOp::getOperationName() << " "; 836 p.printOperands(op.getOperands()); 837 p << ": " << op.getOperandTypes(); 838 p.printArrowTypeList(op.getResultTypes()); 839 p << " then"; 840 p.printRegion(op.thenRegion(), 841 /*printEntryBlockArgs=*/true, 842 /*printBlockTerminators=*/true); 843 p << " else"; 844 p.printRegion(op.elseRegion(), 845 /*printEntryBlockArgs=*/true, 846 /*printBlockTerminators=*/true); 847 p << " join"; 848 p.printRegion(op.joinRegion(), 849 /*printEntryBlockArgs=*/true, 850 /*printBlockTerminators=*/true); 851 } 852 853 static ParseResult parseRegionIfOp(OpAsmParser &parser, 854 OperationState &result) { 855 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 856 SmallVector<Type, 2> operandTypes; 857 858 result.regions.reserve(3); 859 Region *thenRegion = result.addRegion(); 860 Region *elseRegion = result.addRegion(); 861 Region *joinRegion = result.addRegion(); 862 863 // Parse operand, type and arrow type lists. 864 if (parser.parseOperandList(operandInfos) || 865 parser.parseColonTypeList(operandTypes) || 866 parser.parseArrowTypeList(result.types)) 867 return failure(); 868 869 // Parse all attached regions. 870 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 871 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 872 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 873 return failure(); 874 875 return parser.resolveOperands(operandInfos, operandTypes, 876 parser.getCurrentLocation(), result.operands); 877 } 878 879 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 880 assert(index < 2 && "invalid region index"); 881 return getOperands(); 882 } 883 884 void RegionIfOp::getSuccessorRegions( 885 Optional<unsigned> index, ArrayRef<Attribute> operands, 886 SmallVectorImpl<RegionSuccessor> ®ions) { 887 // We always branch to the join region. 888 if (index.hasValue()) { 889 if (index.getValue() < 2) 890 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 891 else 892 regions.push_back(RegionSuccessor(getResults())); 893 return; 894 } 895 896 // The then and else regions are the entry regions of this op. 897 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 898 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 899 } 900 901 #include "TestOpEnums.cpp.inc" 902 #include "TestOpInterfaces.cpp.inc" 903 #include "TestOpStructs.cpp.inc" 904 #include "TestTypeInterfaces.cpp.inc" 905 906 #define GET_OP_CLASSES 907 #include "TestOps.cpp.inc" 908