1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestTypes.h" 12 #include "mlir/Dialect/DLTI/DLTI.h" 13 #include "mlir/Dialect/MemRef/IR/MemRef.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/IR/BuiltinOps.h" 16 #include "mlir/IR/DialectImplementation.h" 17 #include "mlir/IR/PatternMatch.h" 18 #include "mlir/IR/TypeUtilities.h" 19 #include "mlir/Transforms/FoldUtils.h" 20 #include "mlir/Transforms/InliningUtils.h" 21 #include "llvm/ADT/StringSwitch.h" 22 23 using namespace mlir; 24 using namespace mlir::test; 25 26 void mlir::test::registerTestDialect(DialectRegistry ®istry) { 27 registry.insert<TestDialect>(); 28 } 29 30 //===----------------------------------------------------------------------===// 31 // TestDialect Interfaces 32 //===----------------------------------------------------------------------===// 33 34 namespace { 35 36 // Test support for interacting with the AsmPrinter. 37 struct TestOpAsmInterface : public OpAsmDialectInterface { 38 using OpAsmDialectInterface::OpAsmDialectInterface; 39 40 LogicalResult getAlias(Attribute attr, raw_ostream &os) const final { 41 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 42 if (!strAttr) 43 return failure(); 44 45 // Check the contents of the string attribute to see what the test alias 46 // should be named. 47 Optional<StringRef> aliasName = 48 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 49 .Case("alias_test:dot_in_name", StringRef("test.alias")) 50 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 51 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 52 .Case("alias_test:sanitize_conflict_a", 53 StringRef("test_alias_conflict0")) 54 .Case("alias_test:sanitize_conflict_b", 55 StringRef("test_alias_conflict0_")) 56 .Default(llvm::None); 57 if (!aliasName) 58 return failure(); 59 60 os << *aliasName; 61 return success(); 62 } 63 64 void getAsmResultNames(Operation *op, 65 OpAsmSetValueNameFn setNameFn) const final { 66 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 67 setNameFn(asmOp, "result"); 68 } 69 70 void getAsmBlockArgumentNames(Block *block, 71 OpAsmSetValueNameFn setNameFn) const final { 72 auto op = block->getParentOp(); 73 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 74 if (!arrayAttr) 75 return; 76 auto args = block->getArguments(); 77 auto e = std::min(arrayAttr.size(), args.size()); 78 for (unsigned i = 0; i < e; ++i) { 79 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 80 setNameFn(args[i], strAttr.getValue()); 81 } 82 } 83 }; 84 85 struct TestDialectFoldInterface : public DialectFoldInterface { 86 using DialectFoldInterface::DialectFoldInterface; 87 88 /// Registered hook to check if the given region, which is attached to an 89 /// operation that is *not* isolated from above, should be used when 90 /// materializing constants. 91 bool shouldMaterializeInto(Region *region) const final { 92 // If this is a one region operation, then insert into it. 93 return isa<OneRegionOp>(region->getParentOp()); 94 } 95 }; 96 97 /// This class defines the interface for handling inlining with standard 98 /// operations. 99 struct TestInlinerInterface : public DialectInlinerInterface { 100 using DialectInlinerInterface::DialectInlinerInterface; 101 102 //===--------------------------------------------------------------------===// 103 // Analysis Hooks 104 //===--------------------------------------------------------------------===// 105 106 bool isLegalToInline(Operation *call, Operation *callable, 107 bool wouldBeCloned) const final { 108 // Don't allow inlining calls that are marked `noinline`. 109 return !call->hasAttr("noinline"); 110 } 111 bool isLegalToInline(Region *, Region *, bool, 112 BlockAndValueMapping &) const final { 113 // Inlining into test dialect regions is legal. 114 return true; 115 } 116 bool isLegalToInline(Operation *, Region *, bool, 117 BlockAndValueMapping &) const final { 118 return true; 119 } 120 121 bool shouldAnalyzeRecursively(Operation *op) const final { 122 // Analyze recursively if this is not a functional region operation, it 123 // froms a separate functional scope. 124 return !isa<FunctionalRegionOp>(op); 125 } 126 127 //===--------------------------------------------------------------------===// 128 // Transformation Hooks 129 //===--------------------------------------------------------------------===// 130 131 /// Handle the given inlined terminator by replacing it with a new operation 132 /// as necessary. 133 void handleTerminator(Operation *op, 134 ArrayRef<Value> valuesToRepl) const final { 135 // Only handle "test.return" here. 136 auto returnOp = dyn_cast<TestReturnOp>(op); 137 if (!returnOp) 138 return; 139 140 // Replace the values directly with the return operands. 141 assert(returnOp.getNumOperands() == valuesToRepl.size()); 142 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 143 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 144 } 145 146 /// Attempt to materialize a conversion for a type mismatch between a call 147 /// from this dialect, and a callable region. This method should generate an 148 /// operation that takes 'input' as the only operand, and produces a single 149 /// result of 'resultType'. If a conversion can not be generated, nullptr 150 /// should be returned. 151 Operation *materializeCallConversion(OpBuilder &builder, Value input, 152 Type resultType, 153 Location conversionLoc) const final { 154 // Only allow conversion for i16/i32 types. 155 if (!(resultType.isSignlessInteger(16) || 156 resultType.isSignlessInteger(32)) || 157 !(input.getType().isSignlessInteger(16) || 158 input.getType().isSignlessInteger(32))) 159 return nullptr; 160 return builder.create<TestCastOp>(conversionLoc, resultType, input); 161 } 162 }; 163 } // end anonymous namespace 164 165 //===----------------------------------------------------------------------===// 166 // TestDialect 167 //===----------------------------------------------------------------------===// 168 169 void TestDialect::initialize() { 170 registerAttributes(); 171 registerTypes(); 172 addOperations< 173 #define GET_OP_LIST 174 #include "TestOps.cpp.inc" 175 >(); 176 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 177 TestInlinerInterface>(); 178 allowUnknownOperations(); 179 } 180 181 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 182 Type type, Location loc) { 183 return builder.create<TestOpConstant>(loc, type, value); 184 } 185 186 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 187 NamedAttribute namedAttr) { 188 if (namedAttr.first == "test.invalid_attr") 189 return op->emitError() << "invalid to use 'test.invalid_attr'"; 190 return success(); 191 } 192 193 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 194 unsigned regionIndex, 195 unsigned argIndex, 196 NamedAttribute namedAttr) { 197 if (namedAttr.first == "test.invalid_attr") 198 return op->emitError() << "invalid to use 'test.invalid_attr'"; 199 return success(); 200 } 201 202 LogicalResult 203 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 204 unsigned resultIndex, 205 NamedAttribute namedAttr) { 206 if (namedAttr.first == "test.invalid_attr") 207 return op->emitError() << "invalid to use 'test.invalid_attr'"; 208 return success(); 209 } 210 211 //===----------------------------------------------------------------------===// 212 // TestBranchOp 213 //===----------------------------------------------------------------------===// 214 215 Optional<MutableOperandRange> 216 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 217 assert(index == 0 && "invalid successor index"); 218 return targetOperandsMutable(); 219 } 220 221 //===----------------------------------------------------------------------===// 222 // TestFoldToCallOp 223 //===----------------------------------------------------------------------===// 224 225 namespace { 226 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 227 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 228 229 LogicalResult matchAndRewrite(FoldToCallOp op, 230 PatternRewriter &rewriter) const override { 231 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(), 232 ValueRange()); 233 return success(); 234 } 235 }; 236 } // end anonymous namespace 237 238 void FoldToCallOp::getCanonicalizationPatterns( 239 OwningRewritePatternList &results, MLIRContext *context) { 240 results.insert<FoldToCallOpPattern>(context); 241 } 242 243 //===----------------------------------------------------------------------===// 244 // Test Format* operations 245 //===----------------------------------------------------------------------===// 246 247 //===----------------------------------------------------------------------===// 248 // Parsing 249 250 static ParseResult parseCustomDirectiveOperands( 251 OpAsmParser &parser, OpAsmParser::OperandType &operand, 252 Optional<OpAsmParser::OperandType> &optOperand, 253 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 254 if (parser.parseOperand(operand)) 255 return failure(); 256 if (succeeded(parser.parseOptionalComma())) { 257 optOperand.emplace(); 258 if (parser.parseOperand(*optOperand)) 259 return failure(); 260 } 261 if (parser.parseArrow() || parser.parseLParen() || 262 parser.parseOperandList(varOperands) || parser.parseRParen()) 263 return failure(); 264 return success(); 265 } 266 static ParseResult 267 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 268 Type &optOperandType, 269 SmallVectorImpl<Type> &varOperandTypes) { 270 if (parser.parseColon()) 271 return failure(); 272 273 if (parser.parseType(operandType)) 274 return failure(); 275 if (succeeded(parser.parseOptionalComma())) { 276 if (parser.parseType(optOperandType)) 277 return failure(); 278 } 279 if (parser.parseArrow() || parser.parseLParen() || 280 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 281 return failure(); 282 return success(); 283 } 284 static ParseResult 285 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 286 Type optOperandType, 287 const SmallVectorImpl<Type> &varOperandTypes) { 288 if (parser.parseKeyword("type_refs_capture")) 289 return failure(); 290 291 Type operandType2, optOperandType2; 292 SmallVector<Type, 1> varOperandTypes2; 293 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 294 varOperandTypes2)) 295 return failure(); 296 297 if (operandType != operandType2 || optOperandType != optOperandType2 || 298 varOperandTypes != varOperandTypes2) 299 return failure(); 300 301 return success(); 302 } 303 static ParseResult parseCustomDirectiveOperandsAndTypes( 304 OpAsmParser &parser, OpAsmParser::OperandType &operand, 305 Optional<OpAsmParser::OperandType> &optOperand, 306 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 307 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 308 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 309 parseCustomDirectiveResults(parser, operandType, optOperandType, 310 varOperandTypes)) 311 return failure(); 312 return success(); 313 } 314 static ParseResult parseCustomDirectiveRegions( 315 OpAsmParser &parser, Region ®ion, 316 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 317 if (parser.parseRegion(region)) 318 return failure(); 319 if (failed(parser.parseOptionalComma())) 320 return success(); 321 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 322 if (parser.parseRegion(*varRegion)) 323 return failure(); 324 varRegions.emplace_back(std::move(varRegion)); 325 return success(); 326 } 327 static ParseResult 328 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 329 SmallVectorImpl<Block *> &varSuccessors) { 330 if (parser.parseSuccessor(successor)) 331 return failure(); 332 if (failed(parser.parseOptionalComma())) 333 return success(); 334 Block *varSuccessor; 335 if (parser.parseSuccessor(varSuccessor)) 336 return failure(); 337 varSuccessors.append(2, varSuccessor); 338 return success(); 339 } 340 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 341 IntegerAttr &attr, 342 IntegerAttr &optAttr) { 343 if (parser.parseAttribute(attr)) 344 return failure(); 345 if (succeeded(parser.parseOptionalComma())) { 346 if (parser.parseAttribute(optAttr)) 347 return failure(); 348 } 349 return success(); 350 } 351 352 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 353 NamedAttrList &attrs) { 354 return parser.parseOptionalAttrDict(attrs); 355 } 356 static ParseResult parseCustomDirectiveOptionalOperandRef( 357 OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) { 358 int64_t operandCount = 0; 359 if (parser.parseInteger(operandCount)) 360 return failure(); 361 bool expectedOptionalOperand = operandCount == 0; 362 return success(expectedOptionalOperand != optOperand.hasValue()); 363 } 364 365 //===----------------------------------------------------------------------===// 366 // Printing 367 368 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 369 Value operand, Value optOperand, 370 OperandRange varOperands) { 371 printer << operand; 372 if (optOperand) 373 printer << ", " << optOperand; 374 printer << " -> (" << varOperands << ")"; 375 } 376 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 377 Type operandType, Type optOperandType, 378 TypeRange varOperandTypes) { 379 printer << " : " << operandType; 380 if (optOperandType) 381 printer << ", " << optOperandType; 382 printer << " -> (" << varOperandTypes << ")"; 383 } 384 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 385 Operation *op, Type operandType, 386 Type optOperandType, 387 TypeRange varOperandTypes) { 388 printer << " type_refs_capture "; 389 printCustomDirectiveResults(printer, op, operandType, optOperandType, 390 varOperandTypes); 391 } 392 static void printCustomDirectiveOperandsAndTypes( 393 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 394 OperandRange varOperands, Type operandType, Type optOperandType, 395 TypeRange varOperandTypes) { 396 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 397 printCustomDirectiveResults(printer, op, operandType, optOperandType, 398 varOperandTypes); 399 } 400 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 401 Region ®ion, 402 MutableArrayRef<Region> varRegions) { 403 printer.printRegion(region); 404 if (!varRegions.empty()) { 405 printer << ", "; 406 for (Region ®ion : varRegions) 407 printer.printRegion(region); 408 } 409 } 410 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 411 Block *successor, 412 SuccessorRange varSuccessors) { 413 printer << successor; 414 if (!varSuccessors.empty()) 415 printer << ", " << varSuccessors.front(); 416 } 417 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 418 Attribute attribute, 419 Attribute optAttribute) { 420 printer << attribute; 421 if (optAttribute) 422 printer << ", " << optAttribute; 423 } 424 425 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 426 DictionaryAttr attrs) { 427 printer.printOptionalAttrDict(attrs.getValue()); 428 } 429 430 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 431 Operation *op, 432 Value optOperand) { 433 printer << (optOperand ? "1" : "0"); 434 } 435 436 //===----------------------------------------------------------------------===// 437 // Test IsolatedRegionOp - parse passthrough region arguments. 438 //===----------------------------------------------------------------------===// 439 440 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 441 OperationState &result) { 442 OpAsmParser::OperandType argInfo; 443 Type argType = parser.getBuilder().getIndexType(); 444 445 // Parse the input operand. 446 if (parser.parseOperand(argInfo) || 447 parser.resolveOperand(argInfo, argType, result.operands)) 448 return failure(); 449 450 // Parse the body region, and reuse the operand info as the argument info. 451 Region *body = result.addRegion(); 452 return parser.parseRegion(*body, argInfo, argType, 453 /*enableNameShadowing=*/true); 454 } 455 456 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 457 p << "test.isolated_region "; 458 p.printOperand(op.getOperand()); 459 p.shadowRegionArgs(op.region(), op.getOperand()); 460 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 461 } 462 463 //===----------------------------------------------------------------------===// 464 // Test SSACFGRegionOp 465 //===----------------------------------------------------------------------===// 466 467 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 468 return RegionKind::SSACFG; 469 } 470 471 //===----------------------------------------------------------------------===// 472 // Test GraphRegionOp 473 //===----------------------------------------------------------------------===// 474 475 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 476 OperationState &result) { 477 // Parse the body region, and reuse the operand info as the argument info. 478 Region *body = result.addRegion(); 479 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 480 } 481 482 static void print(OpAsmPrinter &p, GraphRegionOp op) { 483 p << "test.graph_region "; 484 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 485 } 486 487 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 488 return RegionKind::Graph; 489 } 490 491 //===----------------------------------------------------------------------===// 492 // Test AffineScopeOp 493 //===----------------------------------------------------------------------===// 494 495 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 496 OperationState &result) { 497 // Parse the body region, and reuse the operand info as the argument info. 498 Region *body = result.addRegion(); 499 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 500 } 501 502 static void print(OpAsmPrinter &p, AffineScopeOp op) { 503 p << "test.affine_scope "; 504 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 505 } 506 507 //===----------------------------------------------------------------------===// 508 // Test parser. 509 //===----------------------------------------------------------------------===// 510 511 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, 512 OperationState &result) { 513 if (parser.parseOptionalColon()) 514 return success(); 515 uint64_t numResults; 516 if (parser.parseInteger(numResults)) 517 return failure(); 518 519 IndexType type = parser.getBuilder().getIndexType(); 520 for (unsigned i = 0; i < numResults; ++i) 521 result.addTypes(type); 522 return success(); 523 } 524 525 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { 526 p << ParseIntegerLiteralOp::getOperationName(); 527 if (unsigned numResults = op->getNumResults()) 528 p << " : " << numResults; 529 } 530 531 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, 532 OperationState &result) { 533 StringRef keyword; 534 if (parser.parseKeyword(&keyword)) 535 return failure(); 536 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 537 return success(); 538 } 539 540 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { 541 p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword(); 542 } 543 544 //===----------------------------------------------------------------------===// 545 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 546 547 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 548 OperationState &result) { 549 if (parser.parseKeyword("wraps")) 550 return failure(); 551 552 // Parse the wrapped op in a region 553 Region &body = *result.addRegion(); 554 body.push_back(new Block); 555 Block &block = body.back(); 556 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 557 if (!wrapped_op) 558 return failure(); 559 560 // Create a return terminator in the inner region, pass as operand to the 561 // terminator the returned values from the wrapped operation. 562 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 563 OpBuilder builder(parser.getBuilder().getContext()); 564 builder.setInsertionPointToEnd(&block); 565 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 566 567 // Get the results type for the wrapping op from the terminator operands. 568 Operation &return_op = body.back().back(); 569 result.types.append(return_op.operand_type_begin(), 570 return_op.operand_type_end()); 571 572 // Use the location of the wrapped op for the "test.wrapping_region" op. 573 result.location = wrapped_op->getLoc(); 574 575 return success(); 576 } 577 578 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 579 p << op.getOperationName() << " wraps "; 580 p.printGenericOp(&op.region().front().front()); 581 } 582 583 //===----------------------------------------------------------------------===// 584 // Test PolyForOp - parse list of region arguments. 585 //===----------------------------------------------------------------------===// 586 587 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 588 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 589 // Parse list of region arguments without a delimiter. 590 if (parser.parseRegionArgumentList(ivsInfo)) 591 return failure(); 592 593 // Parse the body region. 594 Region *body = result.addRegion(); 595 auto &builder = parser.getBuilder(); 596 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 597 return parser.parseRegion(*body, ivsInfo, argTypes); 598 } 599 600 //===----------------------------------------------------------------------===// 601 // Test removing op with inner ops. 602 //===----------------------------------------------------------------------===// 603 604 namespace { 605 struct TestRemoveOpWithInnerOps 606 : public OpRewritePattern<TestOpWithRegionPattern> { 607 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 608 609 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 610 PatternRewriter &rewriter) const override { 611 rewriter.eraseOp(op); 612 return success(); 613 } 614 }; 615 } // end anonymous namespace 616 617 void TestOpWithRegionPattern::getCanonicalizationPatterns( 618 OwningRewritePatternList &results, MLIRContext *context) { 619 results.insert<TestRemoveOpWithInnerOps>(context); 620 } 621 622 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 623 return operand(); 624 } 625 626 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 627 return getValue(); 628 } 629 630 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 631 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 632 for (Value input : this->operands()) { 633 results.push_back(input); 634 } 635 return success(); 636 } 637 638 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 639 assert(operands.size() == 1); 640 if (operands.front()) { 641 (*this)->setAttr("attr", operands.front()); 642 return getResult(); 643 } 644 return {}; 645 } 646 647 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 648 return getOperand(); 649 } 650 651 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 652 MLIRContext *, Optional<Location> location, ValueRange operands, 653 DictionaryAttr attributes, RegionRange regions, 654 SmallVectorImpl<Type> &inferredReturnTypes) { 655 if (operands[0].getType() != operands[1].getType()) { 656 return emitOptionalError(location, "operand type mismatch ", 657 operands[0].getType(), " vs ", 658 operands[1].getType()); 659 } 660 inferredReturnTypes.assign({operands[0].getType()}); 661 return success(); 662 } 663 664 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 665 MLIRContext *context, Optional<Location> location, ValueRange operands, 666 DictionaryAttr attributes, RegionRange regions, 667 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 668 // Create return type consisting of the last element of the first operand. 669 auto operandType = *operands.getTypes().begin(); 670 auto sval = operandType.dyn_cast<ShapedType>(); 671 if (!sval) { 672 return emitOptionalError(location, "only shaped type operands allowed"); 673 } 674 int64_t dim = 675 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 676 auto type = IntegerType::get(context, 17); 677 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 678 return success(); 679 } 680 681 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 682 OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) { 683 shapes = SmallVector<Value, 1>{ 684 builder.createOrFold<memref::DimOp>(getLoc(), getOperand(0), 0)}; 685 return success(); 686 } 687 688 //===----------------------------------------------------------------------===// 689 // Test SideEffect interfaces 690 //===----------------------------------------------------------------------===// 691 692 namespace { 693 /// A test resource for side effects. 694 struct TestResource : public SideEffects::Resource::Base<TestResource> { 695 StringRef getName() final { return "<Test>"; } 696 }; 697 } // end anonymous namespace 698 699 void SideEffectOp::getEffects( 700 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 701 // Check for an effects attribute on the op instance. 702 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 703 if (!effectsAttr) 704 return; 705 706 // If there is one, it is an array of dictionary attributes that hold 707 // information on the effects of this operation. 708 for (Attribute element : effectsAttr) { 709 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 710 711 // Get the specific memory effect. 712 MemoryEffects::Effect *effect = 713 StringSwitch<MemoryEffects::Effect *>( 714 effectElement.get("effect").cast<StringAttr>().getValue()) 715 .Case("allocate", MemoryEffects::Allocate::get()) 716 .Case("free", MemoryEffects::Free::get()) 717 .Case("read", MemoryEffects::Read::get()) 718 .Case("write", MemoryEffects::Write::get()); 719 720 // Check for a non-default resource to use. 721 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 722 if (effectElement.get("test_resource")) 723 resource = TestResource::get(); 724 725 // Check for a result to affect. 726 if (effectElement.get("on_result")) 727 effects.emplace_back(effect, getResult(), resource); 728 else if (Attribute ref = effectElement.get("on_reference")) 729 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 730 else 731 effects.emplace_back(effect, resource); 732 } 733 } 734 735 void SideEffectOp::getEffects( 736 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 737 auto effectsAttr = (*this)->getAttrOfType<AffineMapAttr>("effect_parameter"); 738 if (!effectsAttr) 739 return; 740 741 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 742 } 743 744 //===----------------------------------------------------------------------===// 745 // StringAttrPrettyNameOp 746 //===----------------------------------------------------------------------===// 747 748 // This op has fancy handling of its SSA result name. 749 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 750 OperationState &result) { 751 // Add the result types. 752 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 753 result.addTypes(parser.getBuilder().getIntegerType(32)); 754 755 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 756 return failure(); 757 758 // If the attribute dictionary contains no 'names' attribute, infer it from 759 // the SSA name (if specified). 760 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 761 return attr.first == "names"; 762 }); 763 764 // If there was no name specified, check to see if there was a useful name 765 // specified in the asm file. 766 if (hadNames || parser.getNumResults() == 0) 767 return success(); 768 769 SmallVector<StringRef, 4> names; 770 auto *context = result.getContext(); 771 772 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 773 auto resultName = parser.getResultName(i); 774 StringRef nameStr; 775 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 776 nameStr = resultName.first; 777 778 names.push_back(nameStr); 779 } 780 781 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 782 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 783 return success(); 784 } 785 786 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 787 p << "test.string_attr_pretty_name"; 788 789 // Note that we only need to print the "name" attribute if the asmprinter 790 // result name disagrees with it. This can happen in strange cases, e.g. 791 // when there are conflicts. 792 bool namesDisagree = op.names().size() != op.getNumResults(); 793 794 SmallString<32> resultNameStr; 795 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 796 resultNameStr.clear(); 797 llvm::raw_svector_ostream tmpStream(resultNameStr); 798 p.printOperand(op.getResult(i), tmpStream); 799 800 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 801 if (!expectedName || 802 tmpStream.str().drop_front() != expectedName.getValue()) { 803 namesDisagree = true; 804 } 805 } 806 807 if (namesDisagree) 808 p.printOptionalAttrDictWithKeyword(op->getAttrs()); 809 else 810 p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"}); 811 } 812 813 // We set the SSA name in the asm syntax to the contents of the name 814 // attribute. 815 void StringAttrPrettyNameOp::getAsmResultNames( 816 function_ref<void(Value, StringRef)> setNameFn) { 817 818 auto value = names(); 819 for (size_t i = 0, e = value.size(); i != e; ++i) 820 if (auto str = value[i].dyn_cast<StringAttr>()) 821 if (!str.getValue().empty()) 822 setNameFn(getResult(i), str.getValue()); 823 } 824 825 //===----------------------------------------------------------------------===// 826 // RegionIfOp 827 //===----------------------------------------------------------------------===// 828 829 static void print(OpAsmPrinter &p, RegionIfOp op) { 830 p << RegionIfOp::getOperationName() << " "; 831 p.printOperands(op.getOperands()); 832 p << ": " << op.getOperandTypes(); 833 p.printArrowTypeList(op.getResultTypes()); 834 p << " then"; 835 p.printRegion(op.thenRegion(), 836 /*printEntryBlockArgs=*/true, 837 /*printBlockTerminators=*/true); 838 p << " else"; 839 p.printRegion(op.elseRegion(), 840 /*printEntryBlockArgs=*/true, 841 /*printBlockTerminators=*/true); 842 p << " join"; 843 p.printRegion(op.joinRegion(), 844 /*printEntryBlockArgs=*/true, 845 /*printBlockTerminators=*/true); 846 } 847 848 static ParseResult parseRegionIfOp(OpAsmParser &parser, 849 OperationState &result) { 850 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 851 SmallVector<Type, 2> operandTypes; 852 853 result.regions.reserve(3); 854 Region *thenRegion = result.addRegion(); 855 Region *elseRegion = result.addRegion(); 856 Region *joinRegion = result.addRegion(); 857 858 // Parse operand, type and arrow type lists. 859 if (parser.parseOperandList(operandInfos) || 860 parser.parseColonTypeList(operandTypes) || 861 parser.parseArrowTypeList(result.types)) 862 return failure(); 863 864 // Parse all attached regions. 865 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 866 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 867 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 868 return failure(); 869 870 return parser.resolveOperands(operandInfos, operandTypes, 871 parser.getCurrentLocation(), result.operands); 872 } 873 874 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 875 assert(index < 2 && "invalid region index"); 876 return getOperands(); 877 } 878 879 void RegionIfOp::getSuccessorRegions( 880 Optional<unsigned> index, ArrayRef<Attribute> operands, 881 SmallVectorImpl<RegionSuccessor> ®ions) { 882 // We always branch to the join region. 883 if (index.hasValue()) { 884 if (index.getValue() < 2) 885 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 886 else 887 regions.push_back(RegionSuccessor(getResults())); 888 return; 889 } 890 891 // The then and else regions are the entry regions of this op. 892 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 893 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 894 } 895 896 #include "TestOpEnums.cpp.inc" 897 #include "TestOpInterfaces.cpp.inc" 898 #include "TestOpStructs.cpp.inc" 899 #include "TestTypeInterfaces.cpp.inc" 900 901 #define GET_OP_CLASSES 902 #include "TestOps.cpp.inc" 903