1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestTypes.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/IR/BuiltinOps.h" 13 #include "mlir/IR/DialectImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/TypeUtilities.h" 16 #include "mlir/Transforms/FoldUtils.h" 17 #include "mlir/Transforms/InliningUtils.h" 18 #include "llvm/ADT/StringSwitch.h" 19 20 using namespace mlir; 21 using namespace mlir::test; 22 23 void mlir::test::registerTestDialect(DialectRegistry ®istry) { 24 registry.insert<TestDialect>(); 25 } 26 27 //===----------------------------------------------------------------------===// 28 // TestDialect Interfaces 29 //===----------------------------------------------------------------------===// 30 31 namespace { 32 33 // Test support for interacting with the AsmPrinter. 34 struct TestOpAsmInterface : public OpAsmDialectInterface { 35 using OpAsmDialectInterface::OpAsmDialectInterface; 36 37 LogicalResult getAlias(Attribute attr, raw_ostream &os) const final { 38 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 39 if (!strAttr) 40 return failure(); 41 42 // Check the contents of the string attribute to see what the test alias 43 // should be named. 44 Optional<StringRef> aliasName = 45 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 46 .Case("alias_test:dot_in_name", StringRef("test.alias")) 47 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 48 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 49 .Case("alias_test:sanitize_conflict_a", 50 StringRef("test_alias_conflict0")) 51 .Case("alias_test:sanitize_conflict_b", 52 StringRef("test_alias_conflict0_")) 53 .Default(llvm::None); 54 if (!aliasName) 55 return failure(); 56 57 os << *aliasName; 58 return success(); 59 } 60 61 void getAsmResultNames(Operation *op, 62 OpAsmSetValueNameFn setNameFn) const final { 63 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 64 setNameFn(asmOp, "result"); 65 } 66 67 void getAsmBlockArgumentNames(Block *block, 68 OpAsmSetValueNameFn setNameFn) const final { 69 auto op = block->getParentOp(); 70 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 71 if (!arrayAttr) 72 return; 73 auto args = block->getArguments(); 74 auto e = std::min(arrayAttr.size(), args.size()); 75 for (unsigned i = 0; i < e; ++i) { 76 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 77 setNameFn(args[i], strAttr.getValue()); 78 } 79 } 80 }; 81 82 struct TestDialectFoldInterface : public DialectFoldInterface { 83 using DialectFoldInterface::DialectFoldInterface; 84 85 /// Registered hook to check if the given region, which is attached to an 86 /// operation that is *not* isolated from above, should be used when 87 /// materializing constants. 88 bool shouldMaterializeInto(Region *region) const final { 89 // If this is a one region operation, then insert into it. 90 return isa<OneRegionOp>(region->getParentOp()); 91 } 92 }; 93 94 /// This class defines the interface for handling inlining with standard 95 /// operations. 96 struct TestInlinerInterface : public DialectInlinerInterface { 97 using DialectInlinerInterface::DialectInlinerInterface; 98 99 //===--------------------------------------------------------------------===// 100 // Analysis Hooks 101 //===--------------------------------------------------------------------===// 102 103 bool isLegalToInline(Operation *call, Operation *callable, 104 bool wouldBeCloned) const final { 105 // Don't allow inlining calls that are marked `noinline`. 106 return !call->hasAttr("noinline"); 107 } 108 bool isLegalToInline(Region *, Region *, bool, 109 BlockAndValueMapping &) const final { 110 // Inlining into test dialect regions is legal. 111 return true; 112 } 113 bool isLegalToInline(Operation *, Region *, bool, 114 BlockAndValueMapping &) const final { 115 return true; 116 } 117 118 bool shouldAnalyzeRecursively(Operation *op) const final { 119 // Analyze recursively if this is not a functional region operation, it 120 // froms a separate functional scope. 121 return !isa<FunctionalRegionOp>(op); 122 } 123 124 //===--------------------------------------------------------------------===// 125 // Transformation Hooks 126 //===--------------------------------------------------------------------===// 127 128 /// Handle the given inlined terminator by replacing it with a new operation 129 /// as necessary. 130 void handleTerminator(Operation *op, 131 ArrayRef<Value> valuesToRepl) const final { 132 // Only handle "test.return" here. 133 auto returnOp = dyn_cast<TestReturnOp>(op); 134 if (!returnOp) 135 return; 136 137 // Replace the values directly with the return operands. 138 assert(returnOp.getNumOperands() == valuesToRepl.size()); 139 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 140 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 141 } 142 143 /// Attempt to materialize a conversion for a type mismatch between a call 144 /// from this dialect, and a callable region. This method should generate an 145 /// operation that takes 'input' as the only operand, and produces a single 146 /// result of 'resultType'. If a conversion can not be generated, nullptr 147 /// should be returned. 148 Operation *materializeCallConversion(OpBuilder &builder, Value input, 149 Type resultType, 150 Location conversionLoc) const final { 151 // Only allow conversion for i16/i32 types. 152 if (!(resultType.isSignlessInteger(16) || 153 resultType.isSignlessInteger(32)) || 154 !(input.getType().isSignlessInteger(16) || 155 input.getType().isSignlessInteger(32))) 156 return nullptr; 157 return builder.create<TestCastOp>(conversionLoc, resultType, input); 158 } 159 }; 160 } // end anonymous namespace 161 162 //===----------------------------------------------------------------------===// 163 // TestDialect 164 //===----------------------------------------------------------------------===// 165 166 void TestDialect::initialize() { 167 addOperations< 168 #define GET_OP_LIST 169 #include "TestOps.cpp.inc" 170 >(); 171 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 172 TestInlinerInterface>(); 173 addTypes<TestType, TestRecursiveType, 174 #define GET_TYPEDEF_LIST 175 #include "TestTypeDefs.cpp.inc" 176 >(); 177 allowUnknownOperations(); 178 } 179 180 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 181 Type type, Location loc) { 182 return builder.create<TestOpConstant>(loc, type, value); 183 } 184 185 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 186 NamedAttribute namedAttr) { 187 if (namedAttr.first == "test.invalid_attr") 188 return op->emitError() << "invalid to use 'test.invalid_attr'"; 189 return success(); 190 } 191 192 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 193 unsigned regionIndex, 194 unsigned argIndex, 195 NamedAttribute namedAttr) { 196 if (namedAttr.first == "test.invalid_attr") 197 return op->emitError() << "invalid to use 'test.invalid_attr'"; 198 return success(); 199 } 200 201 LogicalResult 202 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 203 unsigned resultIndex, 204 NamedAttribute namedAttr) { 205 if (namedAttr.first == "test.invalid_attr") 206 return op->emitError() << "invalid to use 'test.invalid_attr'"; 207 return success(); 208 } 209 210 //===----------------------------------------------------------------------===// 211 // TestBranchOp 212 //===----------------------------------------------------------------------===// 213 214 Optional<MutableOperandRange> 215 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 216 assert(index == 0 && "invalid successor index"); 217 return targetOperandsMutable(); 218 } 219 220 //===----------------------------------------------------------------------===// 221 // TestFoldToCallOp 222 //===----------------------------------------------------------------------===// 223 224 namespace { 225 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 226 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 227 228 LogicalResult matchAndRewrite(FoldToCallOp op, 229 PatternRewriter &rewriter) const override { 230 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(), 231 ValueRange()); 232 return success(); 233 } 234 }; 235 } // end anonymous namespace 236 237 void FoldToCallOp::getCanonicalizationPatterns( 238 OwningRewritePatternList &results, MLIRContext *context) { 239 results.insert<FoldToCallOpPattern>(context); 240 } 241 242 //===----------------------------------------------------------------------===// 243 // Test Format* operations 244 //===----------------------------------------------------------------------===// 245 246 //===----------------------------------------------------------------------===// 247 // Parsing 248 249 static ParseResult parseCustomDirectiveOperands( 250 OpAsmParser &parser, OpAsmParser::OperandType &operand, 251 Optional<OpAsmParser::OperandType> &optOperand, 252 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 253 if (parser.parseOperand(operand)) 254 return failure(); 255 if (succeeded(parser.parseOptionalComma())) { 256 optOperand.emplace(); 257 if (parser.parseOperand(*optOperand)) 258 return failure(); 259 } 260 if (parser.parseArrow() || parser.parseLParen() || 261 parser.parseOperandList(varOperands) || parser.parseRParen()) 262 return failure(); 263 return success(); 264 } 265 static ParseResult 266 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 267 Type &optOperandType, 268 SmallVectorImpl<Type> &varOperandTypes) { 269 if (parser.parseColon()) 270 return failure(); 271 272 if (parser.parseType(operandType)) 273 return failure(); 274 if (succeeded(parser.parseOptionalComma())) { 275 if (parser.parseType(optOperandType)) 276 return failure(); 277 } 278 if (parser.parseArrow() || parser.parseLParen() || 279 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 280 return failure(); 281 return success(); 282 } 283 static ParseResult 284 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 285 Type optOperandType, 286 const SmallVectorImpl<Type> &varOperandTypes) { 287 if (parser.parseKeyword("type_refs_capture")) 288 return failure(); 289 290 Type operandType2, optOperandType2; 291 SmallVector<Type, 1> varOperandTypes2; 292 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 293 varOperandTypes2)) 294 return failure(); 295 296 if (operandType != operandType2 || optOperandType != optOperandType2 || 297 varOperandTypes != varOperandTypes2) 298 return failure(); 299 300 return success(); 301 } 302 static ParseResult parseCustomDirectiveOperandsAndTypes( 303 OpAsmParser &parser, OpAsmParser::OperandType &operand, 304 Optional<OpAsmParser::OperandType> &optOperand, 305 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 306 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 307 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 308 parseCustomDirectiveResults(parser, operandType, optOperandType, 309 varOperandTypes)) 310 return failure(); 311 return success(); 312 } 313 static ParseResult parseCustomDirectiveRegions( 314 OpAsmParser &parser, Region ®ion, 315 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 316 if (parser.parseRegion(region)) 317 return failure(); 318 if (failed(parser.parseOptionalComma())) 319 return success(); 320 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 321 if (parser.parseRegion(*varRegion)) 322 return failure(); 323 varRegions.emplace_back(std::move(varRegion)); 324 return success(); 325 } 326 static ParseResult 327 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 328 SmallVectorImpl<Block *> &varSuccessors) { 329 if (parser.parseSuccessor(successor)) 330 return failure(); 331 if (failed(parser.parseOptionalComma())) 332 return success(); 333 Block *varSuccessor; 334 if (parser.parseSuccessor(varSuccessor)) 335 return failure(); 336 varSuccessors.append(2, varSuccessor); 337 return success(); 338 } 339 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 340 IntegerAttr &attr, 341 IntegerAttr &optAttr) { 342 if (parser.parseAttribute(attr)) 343 return failure(); 344 if (succeeded(parser.parseOptionalComma())) { 345 if (parser.parseAttribute(optAttr)) 346 return failure(); 347 } 348 return success(); 349 } 350 351 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 352 NamedAttrList &attrs) { 353 return parser.parseOptionalAttrDict(attrs); 354 } 355 static ParseResult parseCustomDirectiveOptionalOperandRef( 356 OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) { 357 int64_t operandCount = 0; 358 if (parser.parseInteger(operandCount)) 359 return failure(); 360 bool expectedOptionalOperand = operandCount == 0; 361 return success(expectedOptionalOperand != optOperand.hasValue()); 362 } 363 364 //===----------------------------------------------------------------------===// 365 // Printing 366 367 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 368 Value operand, Value optOperand, 369 OperandRange varOperands) { 370 printer << operand; 371 if (optOperand) 372 printer << ", " << optOperand; 373 printer << " -> (" << varOperands << ")"; 374 } 375 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 376 Type operandType, Type optOperandType, 377 TypeRange varOperandTypes) { 378 printer << " : " << operandType; 379 if (optOperandType) 380 printer << ", " << optOperandType; 381 printer << " -> (" << varOperandTypes << ")"; 382 } 383 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 384 Operation *op, Type operandType, 385 Type optOperandType, 386 TypeRange varOperandTypes) { 387 printer << " type_refs_capture "; 388 printCustomDirectiveResults(printer, op, operandType, optOperandType, 389 varOperandTypes); 390 } 391 static void printCustomDirectiveOperandsAndTypes( 392 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 393 OperandRange varOperands, Type operandType, Type optOperandType, 394 TypeRange varOperandTypes) { 395 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 396 printCustomDirectiveResults(printer, op, operandType, optOperandType, 397 varOperandTypes); 398 } 399 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 400 Region ®ion, 401 MutableArrayRef<Region> varRegions) { 402 printer.printRegion(region); 403 if (!varRegions.empty()) { 404 printer << ", "; 405 for (Region ®ion : varRegions) 406 printer.printRegion(region); 407 } 408 } 409 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 410 Block *successor, 411 SuccessorRange varSuccessors) { 412 printer << successor; 413 if (!varSuccessors.empty()) 414 printer << ", " << varSuccessors.front(); 415 } 416 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 417 Attribute attribute, 418 Attribute optAttribute) { 419 printer << attribute; 420 if (optAttribute) 421 printer << ", " << optAttribute; 422 } 423 424 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 425 DictionaryAttr attrs) { 426 printer.printOptionalAttrDict(attrs.getValue()); 427 } 428 429 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 430 Operation *op, 431 Value optOperand) { 432 printer << (optOperand ? "1" : "0"); 433 } 434 435 //===----------------------------------------------------------------------===// 436 // Test IsolatedRegionOp - parse passthrough region arguments. 437 //===----------------------------------------------------------------------===// 438 439 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 440 OperationState &result) { 441 OpAsmParser::OperandType argInfo; 442 Type argType = parser.getBuilder().getIndexType(); 443 444 // Parse the input operand. 445 if (parser.parseOperand(argInfo) || 446 parser.resolveOperand(argInfo, argType, result.operands)) 447 return failure(); 448 449 // Parse the body region, and reuse the operand info as the argument info. 450 Region *body = result.addRegion(); 451 return parser.parseRegion(*body, argInfo, argType, 452 /*enableNameShadowing=*/true); 453 } 454 455 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 456 p << "test.isolated_region "; 457 p.printOperand(op.getOperand()); 458 p.shadowRegionArgs(op.region(), op.getOperand()); 459 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 460 } 461 462 //===----------------------------------------------------------------------===// 463 // Test SSACFGRegionOp 464 //===----------------------------------------------------------------------===// 465 466 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 467 return RegionKind::SSACFG; 468 } 469 470 //===----------------------------------------------------------------------===// 471 // Test GraphRegionOp 472 //===----------------------------------------------------------------------===// 473 474 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 475 OperationState &result) { 476 // Parse the body region, and reuse the operand info as the argument info. 477 Region *body = result.addRegion(); 478 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 479 } 480 481 static void print(OpAsmPrinter &p, GraphRegionOp op) { 482 p << "test.graph_region "; 483 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 484 } 485 486 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 487 return RegionKind::Graph; 488 } 489 490 //===----------------------------------------------------------------------===// 491 // Test AffineScopeOp 492 //===----------------------------------------------------------------------===// 493 494 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 495 OperationState &result) { 496 // Parse the body region, and reuse the operand info as the argument info. 497 Region *body = result.addRegion(); 498 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 499 } 500 501 static void print(OpAsmPrinter &p, AffineScopeOp op) { 502 p << "test.affine_scope "; 503 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 504 } 505 506 //===----------------------------------------------------------------------===// 507 // Test parser. 508 //===----------------------------------------------------------------------===// 509 510 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, 511 OperationState &result) { 512 if (parser.parseOptionalColon()) 513 return success(); 514 uint64_t numResults; 515 if (parser.parseInteger(numResults)) 516 return failure(); 517 518 IndexType type = parser.getBuilder().getIndexType(); 519 for (unsigned i = 0; i < numResults; ++i) 520 result.addTypes(type); 521 return success(); 522 } 523 524 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { 525 p << ParseIntegerLiteralOp::getOperationName(); 526 if (unsigned numResults = op->getNumResults()) 527 p << " : " << numResults; 528 } 529 530 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, 531 OperationState &result) { 532 StringRef keyword; 533 if (parser.parseKeyword(&keyword)) 534 return failure(); 535 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 536 return success(); 537 } 538 539 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { 540 p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword(); 541 } 542 543 //===----------------------------------------------------------------------===// 544 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 545 546 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 547 OperationState &result) { 548 if (parser.parseKeyword("wraps")) 549 return failure(); 550 551 // Parse the wrapped op in a region 552 Region &body = *result.addRegion(); 553 body.push_back(new Block); 554 Block &block = body.back(); 555 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 556 if (!wrapped_op) 557 return failure(); 558 559 // Create a return terminator in the inner region, pass as operand to the 560 // terminator the returned values from the wrapped operation. 561 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 562 OpBuilder builder(parser.getBuilder().getContext()); 563 builder.setInsertionPointToEnd(&block); 564 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 565 566 // Get the results type for the wrapping op from the terminator operands. 567 Operation &return_op = body.back().back(); 568 result.types.append(return_op.operand_type_begin(), 569 return_op.operand_type_end()); 570 571 // Use the location of the wrapped op for the "test.wrapping_region" op. 572 result.location = wrapped_op->getLoc(); 573 574 return success(); 575 } 576 577 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 578 p << op.getOperationName() << " wraps "; 579 p.printGenericOp(&op.region().front().front()); 580 } 581 582 //===----------------------------------------------------------------------===// 583 // Test PolyForOp - parse list of region arguments. 584 //===----------------------------------------------------------------------===// 585 586 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 587 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 588 // Parse list of region arguments without a delimiter. 589 if (parser.parseRegionArgumentList(ivsInfo)) 590 return failure(); 591 592 // Parse the body region. 593 Region *body = result.addRegion(); 594 auto &builder = parser.getBuilder(); 595 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 596 return parser.parseRegion(*body, ivsInfo, argTypes); 597 } 598 599 //===----------------------------------------------------------------------===// 600 // Test removing op with inner ops. 601 //===----------------------------------------------------------------------===// 602 603 namespace { 604 struct TestRemoveOpWithInnerOps 605 : public OpRewritePattern<TestOpWithRegionPattern> { 606 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 607 608 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 609 PatternRewriter &rewriter) const override { 610 rewriter.eraseOp(op); 611 return success(); 612 } 613 }; 614 } // end anonymous namespace 615 616 void TestOpWithRegionPattern::getCanonicalizationPatterns( 617 OwningRewritePatternList &results, MLIRContext *context) { 618 results.insert<TestRemoveOpWithInnerOps>(context); 619 } 620 621 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 622 return operand(); 623 } 624 625 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 626 return getValue(); 627 } 628 629 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 630 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 631 for (Value input : this->operands()) { 632 results.push_back(input); 633 } 634 return success(); 635 } 636 637 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 638 assert(operands.size() == 1); 639 if (operands.front()) { 640 (*this)->setAttr("attr", operands.front()); 641 return getResult(); 642 } 643 return {}; 644 } 645 646 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 647 return getOperand(); 648 } 649 650 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 651 MLIRContext *, Optional<Location> location, ValueRange operands, 652 DictionaryAttr attributes, RegionRange regions, 653 SmallVectorImpl<Type> &inferredReturnTypes) { 654 if (operands[0].getType() != operands[1].getType()) { 655 return emitOptionalError(location, "operand type mismatch ", 656 operands[0].getType(), " vs ", 657 operands[1].getType()); 658 } 659 inferredReturnTypes.assign({operands[0].getType()}); 660 return success(); 661 } 662 663 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 664 MLIRContext *context, Optional<Location> location, ValueRange operands, 665 DictionaryAttr attributes, RegionRange regions, 666 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 667 // Create return type consisting of the last element of the first operand. 668 auto operandType = *operands.getTypes().begin(); 669 auto sval = operandType.dyn_cast<ShapedType>(); 670 if (!sval) { 671 return emitOptionalError(location, "only shaped type operands allowed"); 672 } 673 int64_t dim = 674 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 675 auto type = IntegerType::get(context, 17); 676 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 677 return success(); 678 } 679 680 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 681 OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) { 682 shapes = SmallVector<Value, 1>{ 683 builder.createOrFold<DimOp>(getLoc(), getOperand(0), 0)}; 684 return success(); 685 } 686 687 //===----------------------------------------------------------------------===// 688 // Test SideEffect interfaces 689 //===----------------------------------------------------------------------===// 690 691 namespace { 692 /// A test resource for side effects. 693 struct TestResource : public SideEffects::Resource::Base<TestResource> { 694 StringRef getName() final { return "<Test>"; } 695 }; 696 } // end anonymous namespace 697 698 void SideEffectOp::getEffects( 699 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 700 // Check for an effects attribute on the op instance. 701 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 702 if (!effectsAttr) 703 return; 704 705 // If there is one, it is an array of dictionary attributes that hold 706 // information on the effects of this operation. 707 for (Attribute element : effectsAttr) { 708 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 709 710 // Get the specific memory effect. 711 MemoryEffects::Effect *effect = 712 StringSwitch<MemoryEffects::Effect *>( 713 effectElement.get("effect").cast<StringAttr>().getValue()) 714 .Case("allocate", MemoryEffects::Allocate::get()) 715 .Case("free", MemoryEffects::Free::get()) 716 .Case("read", MemoryEffects::Read::get()) 717 .Case("write", MemoryEffects::Write::get()); 718 719 // Check for a non-default resource to use. 720 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 721 if (effectElement.get("test_resource")) 722 resource = TestResource::get(); 723 724 // Check for a result to affect. 725 if (effectElement.get("on_result")) 726 effects.emplace_back(effect, getResult(), resource); 727 else if (Attribute ref = effectElement.get("on_reference")) 728 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 729 else 730 effects.emplace_back(effect, resource); 731 } 732 } 733 734 void SideEffectOp::getEffects( 735 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 736 auto effectsAttr = (*this)->getAttrOfType<AffineMapAttr>("effect_parameter"); 737 if (!effectsAttr) 738 return; 739 740 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 741 } 742 743 //===----------------------------------------------------------------------===// 744 // StringAttrPrettyNameOp 745 //===----------------------------------------------------------------------===// 746 747 // This op has fancy handling of its SSA result name. 748 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 749 OperationState &result) { 750 // Add the result types. 751 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 752 result.addTypes(parser.getBuilder().getIntegerType(32)); 753 754 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 755 return failure(); 756 757 // If the attribute dictionary contains no 'names' attribute, infer it from 758 // the SSA name (if specified). 759 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 760 return attr.first == "names"; 761 }); 762 763 // If there was no name specified, check to see if there was a useful name 764 // specified in the asm file. 765 if (hadNames || parser.getNumResults() == 0) 766 return success(); 767 768 SmallVector<StringRef, 4> names; 769 auto *context = result.getContext(); 770 771 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 772 auto resultName = parser.getResultName(i); 773 StringRef nameStr; 774 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 775 nameStr = resultName.first; 776 777 names.push_back(nameStr); 778 } 779 780 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 781 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 782 return success(); 783 } 784 785 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 786 p << "test.string_attr_pretty_name"; 787 788 // Note that we only need to print the "name" attribute if the asmprinter 789 // result name disagrees with it. This can happen in strange cases, e.g. 790 // when there are conflicts. 791 bool namesDisagree = op.names().size() != op.getNumResults(); 792 793 SmallString<32> resultNameStr; 794 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 795 resultNameStr.clear(); 796 llvm::raw_svector_ostream tmpStream(resultNameStr); 797 p.printOperand(op.getResult(i), tmpStream); 798 799 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 800 if (!expectedName || 801 tmpStream.str().drop_front() != expectedName.getValue()) { 802 namesDisagree = true; 803 } 804 } 805 806 if (namesDisagree) 807 p.printOptionalAttrDictWithKeyword(op->getAttrs()); 808 else 809 p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"}); 810 } 811 812 // We set the SSA name in the asm syntax to the contents of the name 813 // attribute. 814 void StringAttrPrettyNameOp::getAsmResultNames( 815 function_ref<void(Value, StringRef)> setNameFn) { 816 817 auto value = names(); 818 for (size_t i = 0, e = value.size(); i != e; ++i) 819 if (auto str = value[i].dyn_cast<StringAttr>()) 820 if (!str.getValue().empty()) 821 setNameFn(getResult(i), str.getValue()); 822 } 823 824 //===----------------------------------------------------------------------===// 825 // RegionIfOp 826 //===----------------------------------------------------------------------===// 827 828 static void print(OpAsmPrinter &p, RegionIfOp op) { 829 p << RegionIfOp::getOperationName() << " "; 830 p.printOperands(op.getOperands()); 831 p << ": " << op.getOperandTypes(); 832 p.printArrowTypeList(op.getResultTypes()); 833 p << " then"; 834 p.printRegion(op.thenRegion(), 835 /*printEntryBlockArgs=*/true, 836 /*printBlockTerminators=*/true); 837 p << " else"; 838 p.printRegion(op.elseRegion(), 839 /*printEntryBlockArgs=*/true, 840 /*printBlockTerminators=*/true); 841 p << " join"; 842 p.printRegion(op.joinRegion(), 843 /*printEntryBlockArgs=*/true, 844 /*printBlockTerminators=*/true); 845 } 846 847 static ParseResult parseRegionIfOp(OpAsmParser &parser, 848 OperationState &result) { 849 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 850 SmallVector<Type, 2> operandTypes; 851 852 result.regions.reserve(3); 853 Region *thenRegion = result.addRegion(); 854 Region *elseRegion = result.addRegion(); 855 Region *joinRegion = result.addRegion(); 856 857 // Parse operand, type and arrow type lists. 858 if (parser.parseOperandList(operandInfos) || 859 parser.parseColonTypeList(operandTypes) || 860 parser.parseArrowTypeList(result.types)) 861 return failure(); 862 863 // Parse all attached regions. 864 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 865 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 866 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 867 return failure(); 868 869 return parser.resolveOperands(operandInfos, operandTypes, 870 parser.getCurrentLocation(), result.operands); 871 } 872 873 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 874 assert(index < 2 && "invalid region index"); 875 return getOperands(); 876 } 877 878 void RegionIfOp::getSuccessorRegions( 879 Optional<unsigned> index, ArrayRef<Attribute> operands, 880 SmallVectorImpl<RegionSuccessor> ®ions) { 881 // We always branch to the join region. 882 if (index.hasValue()) { 883 if (index.getValue() < 2) 884 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 885 else 886 regions.push_back(RegionSuccessor(getResults())); 887 return; 888 } 889 890 // The then and else regions are the entry regions of this op. 891 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 892 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 893 } 894 895 #include "TestOpEnums.cpp.inc" 896 #include "TestOpInterfaces.cpp.inc" 897 #include "TestOpStructs.cpp.inc" 898 #include "TestTypeInterfaces.cpp.inc" 899 900 #define GET_OP_CLASSES 901 #include "TestOps.cpp.inc" 902