1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestTypes.h" 12 #include "mlir/Dialect/DLTI/DLTI.h" 13 #include "mlir/Dialect/MemRef/IR/MemRef.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/IR/BuiltinOps.h" 16 #include "mlir/IR/DialectImplementation.h" 17 #include "mlir/IR/PatternMatch.h" 18 #include "mlir/IR/TypeUtilities.h" 19 #include "mlir/Transforms/FoldUtils.h" 20 #include "mlir/Transforms/InliningUtils.h" 21 #include "llvm/ADT/StringSwitch.h" 22 23 using namespace mlir; 24 using namespace mlir::test; 25 26 void mlir::test::registerTestDialect(DialectRegistry ®istry) { 27 registry.insert<TestDialect>(); 28 } 29 30 //===----------------------------------------------------------------------===// 31 // TestDialect Interfaces 32 //===----------------------------------------------------------------------===// 33 34 namespace { 35 36 // Test support for interacting with the AsmPrinter. 37 struct TestOpAsmInterface : public OpAsmDialectInterface { 38 using OpAsmDialectInterface::OpAsmDialectInterface; 39 40 LogicalResult getAlias(Attribute attr, raw_ostream &os) const final { 41 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 42 if (!strAttr) 43 return failure(); 44 45 // Check the contents of the string attribute to see what the test alias 46 // should be named. 47 Optional<StringRef> aliasName = 48 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 49 .Case("alias_test:dot_in_name", StringRef("test.alias")) 50 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 51 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 52 .Case("alias_test:sanitize_conflict_a", 53 StringRef("test_alias_conflict0")) 54 .Case("alias_test:sanitize_conflict_b", 55 StringRef("test_alias_conflict0_")) 56 .Default(llvm::None); 57 if (!aliasName) 58 return failure(); 59 60 os << *aliasName; 61 return success(); 62 } 63 64 void getAsmResultNames(Operation *op, 65 OpAsmSetValueNameFn setNameFn) const final { 66 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 67 setNameFn(asmOp, "result"); 68 } 69 70 void getAsmBlockArgumentNames(Block *block, 71 OpAsmSetValueNameFn setNameFn) const final { 72 auto op = block->getParentOp(); 73 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 74 if (!arrayAttr) 75 return; 76 auto args = block->getArguments(); 77 auto e = std::min(arrayAttr.size(), args.size()); 78 for (unsigned i = 0; i < e; ++i) { 79 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 80 setNameFn(args[i], strAttr.getValue()); 81 } 82 } 83 }; 84 85 struct TestDialectFoldInterface : public DialectFoldInterface { 86 using DialectFoldInterface::DialectFoldInterface; 87 88 /// Registered hook to check if the given region, which is attached to an 89 /// operation that is *not* isolated from above, should be used when 90 /// materializing constants. 91 bool shouldMaterializeInto(Region *region) const final { 92 // If this is a one region operation, then insert into it. 93 return isa<OneRegionOp>(region->getParentOp()); 94 } 95 }; 96 97 /// This class defines the interface for handling inlining with standard 98 /// operations. 99 struct TestInlinerInterface : public DialectInlinerInterface { 100 using DialectInlinerInterface::DialectInlinerInterface; 101 102 //===--------------------------------------------------------------------===// 103 // Analysis Hooks 104 //===--------------------------------------------------------------------===// 105 106 bool isLegalToInline(Operation *call, Operation *callable, 107 bool wouldBeCloned) const final { 108 // Don't allow inlining calls that are marked `noinline`. 109 return !call->hasAttr("noinline"); 110 } 111 bool isLegalToInline(Region *, Region *, bool, 112 BlockAndValueMapping &) const final { 113 // Inlining into test dialect regions is legal. 114 return true; 115 } 116 bool isLegalToInline(Operation *, Region *, bool, 117 BlockAndValueMapping &) const final { 118 return true; 119 } 120 121 bool shouldAnalyzeRecursively(Operation *op) const final { 122 // Analyze recursively if this is not a functional region operation, it 123 // froms a separate functional scope. 124 return !isa<FunctionalRegionOp>(op); 125 } 126 127 //===--------------------------------------------------------------------===// 128 // Transformation Hooks 129 //===--------------------------------------------------------------------===// 130 131 /// Handle the given inlined terminator by replacing it with a new operation 132 /// as necessary. 133 void handleTerminator(Operation *op, 134 ArrayRef<Value> valuesToRepl) const final { 135 // Only handle "test.return" here. 136 auto returnOp = dyn_cast<TestReturnOp>(op); 137 if (!returnOp) 138 return; 139 140 // Replace the values directly with the return operands. 141 assert(returnOp.getNumOperands() == valuesToRepl.size()); 142 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 143 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 144 } 145 146 /// Attempt to materialize a conversion for a type mismatch between a call 147 /// from this dialect, and a callable region. This method should generate an 148 /// operation that takes 'input' as the only operand, and produces a single 149 /// result of 'resultType'. If a conversion can not be generated, nullptr 150 /// should be returned. 151 Operation *materializeCallConversion(OpBuilder &builder, Value input, 152 Type resultType, 153 Location conversionLoc) const final { 154 // Only allow conversion for i16/i32 types. 155 if (!(resultType.isSignlessInteger(16) || 156 resultType.isSignlessInteger(32)) || 157 !(input.getType().isSignlessInteger(16) || 158 input.getType().isSignlessInteger(32))) 159 return nullptr; 160 return builder.create<TestCastOp>(conversionLoc, resultType, input); 161 } 162 }; 163 } // end anonymous namespace 164 165 //===----------------------------------------------------------------------===// 166 // TestDialect 167 //===----------------------------------------------------------------------===// 168 169 void TestDialect::initialize() { 170 registerAttributes(); 171 registerTypes(); 172 addOperations< 173 #define GET_OP_LIST 174 #include "TestOps.cpp.inc" 175 >(); 176 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 177 TestInlinerInterface>(); 178 allowUnknownOperations(); 179 } 180 181 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 182 Type type, Location loc) { 183 return builder.create<TestOpConstant>(loc, type, value); 184 } 185 186 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 187 NamedAttribute namedAttr) { 188 if (namedAttr.first == "test.invalid_attr") 189 return op->emitError() << "invalid to use 'test.invalid_attr'"; 190 return success(); 191 } 192 193 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 194 unsigned regionIndex, 195 unsigned argIndex, 196 NamedAttribute namedAttr) { 197 if (namedAttr.first == "test.invalid_attr") 198 return op->emitError() << "invalid to use 'test.invalid_attr'"; 199 return success(); 200 } 201 202 LogicalResult 203 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 204 unsigned resultIndex, 205 NamedAttribute namedAttr) { 206 if (namedAttr.first == "test.invalid_attr") 207 return op->emitError() << "invalid to use 'test.invalid_attr'"; 208 return success(); 209 } 210 211 Optional<Dialect::ParseOpHook> 212 TestDialect::getParseOperationHook(StringRef opName) const { 213 if (opName == "test.dialect_custom_printer") { 214 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 215 return parser.parseKeyword("custom_format"); 216 }}; 217 } 218 return None; 219 } 220 221 LogicalResult TestDialect::printOperation(Operation *op, 222 OpAsmPrinter &printer) const { 223 StringRef opName = op->getName().getStringRef(); 224 if (opName == "test.dialect_custom_printer") { 225 printer.getStream() << opName << " custom_format"; 226 return success(); 227 } 228 return failure(); 229 } 230 231 //===----------------------------------------------------------------------===// 232 // TestBranchOp 233 //===----------------------------------------------------------------------===// 234 235 Optional<MutableOperandRange> 236 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 237 assert(index == 0 && "invalid successor index"); 238 return targetOperandsMutable(); 239 } 240 241 //===----------------------------------------------------------------------===// 242 // TestFoldToCallOp 243 //===----------------------------------------------------------------------===// 244 245 namespace { 246 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 247 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 248 249 LogicalResult matchAndRewrite(FoldToCallOp op, 250 PatternRewriter &rewriter) const override { 251 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(), 252 ValueRange()); 253 return success(); 254 } 255 }; 256 } // end anonymous namespace 257 258 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 259 MLIRContext *context) { 260 results.add<FoldToCallOpPattern>(context); 261 } 262 263 //===----------------------------------------------------------------------===// 264 // Test Format* operations 265 //===----------------------------------------------------------------------===// 266 267 //===----------------------------------------------------------------------===// 268 // Parsing 269 270 static ParseResult parseCustomDirectiveOperands( 271 OpAsmParser &parser, OpAsmParser::OperandType &operand, 272 Optional<OpAsmParser::OperandType> &optOperand, 273 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 274 if (parser.parseOperand(operand)) 275 return failure(); 276 if (succeeded(parser.parseOptionalComma())) { 277 optOperand.emplace(); 278 if (parser.parseOperand(*optOperand)) 279 return failure(); 280 } 281 if (parser.parseArrow() || parser.parseLParen() || 282 parser.parseOperandList(varOperands) || parser.parseRParen()) 283 return failure(); 284 return success(); 285 } 286 static ParseResult 287 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 288 Type &optOperandType, 289 SmallVectorImpl<Type> &varOperandTypes) { 290 if (parser.parseColon()) 291 return failure(); 292 293 if (parser.parseType(operandType)) 294 return failure(); 295 if (succeeded(parser.parseOptionalComma())) { 296 if (parser.parseType(optOperandType)) 297 return failure(); 298 } 299 if (parser.parseArrow() || parser.parseLParen() || 300 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 301 return failure(); 302 return success(); 303 } 304 static ParseResult 305 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 306 Type optOperandType, 307 const SmallVectorImpl<Type> &varOperandTypes) { 308 if (parser.parseKeyword("type_refs_capture")) 309 return failure(); 310 311 Type operandType2, optOperandType2; 312 SmallVector<Type, 1> varOperandTypes2; 313 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 314 varOperandTypes2)) 315 return failure(); 316 317 if (operandType != operandType2 || optOperandType != optOperandType2 || 318 varOperandTypes != varOperandTypes2) 319 return failure(); 320 321 return success(); 322 } 323 static ParseResult parseCustomDirectiveOperandsAndTypes( 324 OpAsmParser &parser, OpAsmParser::OperandType &operand, 325 Optional<OpAsmParser::OperandType> &optOperand, 326 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 327 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 328 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 329 parseCustomDirectiveResults(parser, operandType, optOperandType, 330 varOperandTypes)) 331 return failure(); 332 return success(); 333 } 334 static ParseResult parseCustomDirectiveRegions( 335 OpAsmParser &parser, Region ®ion, 336 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 337 if (parser.parseRegion(region)) 338 return failure(); 339 if (failed(parser.parseOptionalComma())) 340 return success(); 341 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 342 if (parser.parseRegion(*varRegion)) 343 return failure(); 344 varRegions.emplace_back(std::move(varRegion)); 345 return success(); 346 } 347 static ParseResult 348 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 349 SmallVectorImpl<Block *> &varSuccessors) { 350 if (parser.parseSuccessor(successor)) 351 return failure(); 352 if (failed(parser.parseOptionalComma())) 353 return success(); 354 Block *varSuccessor; 355 if (parser.parseSuccessor(varSuccessor)) 356 return failure(); 357 varSuccessors.append(2, varSuccessor); 358 return success(); 359 } 360 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 361 IntegerAttr &attr, 362 IntegerAttr &optAttr) { 363 if (parser.parseAttribute(attr)) 364 return failure(); 365 if (succeeded(parser.parseOptionalComma())) { 366 if (parser.parseAttribute(optAttr)) 367 return failure(); 368 } 369 return success(); 370 } 371 372 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 373 NamedAttrList &attrs) { 374 return parser.parseOptionalAttrDict(attrs); 375 } 376 static ParseResult parseCustomDirectiveOptionalOperandRef( 377 OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) { 378 int64_t operandCount = 0; 379 if (parser.parseInteger(operandCount)) 380 return failure(); 381 bool expectedOptionalOperand = operandCount == 0; 382 return success(expectedOptionalOperand != optOperand.hasValue()); 383 } 384 385 //===----------------------------------------------------------------------===// 386 // Printing 387 388 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 389 Value operand, Value optOperand, 390 OperandRange varOperands) { 391 printer << operand; 392 if (optOperand) 393 printer << ", " << optOperand; 394 printer << " -> (" << varOperands << ")"; 395 } 396 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 397 Type operandType, Type optOperandType, 398 TypeRange varOperandTypes) { 399 printer << " : " << operandType; 400 if (optOperandType) 401 printer << ", " << optOperandType; 402 printer << " -> (" << varOperandTypes << ")"; 403 } 404 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 405 Operation *op, Type operandType, 406 Type optOperandType, 407 TypeRange varOperandTypes) { 408 printer << " type_refs_capture "; 409 printCustomDirectiveResults(printer, op, operandType, optOperandType, 410 varOperandTypes); 411 } 412 static void printCustomDirectiveOperandsAndTypes( 413 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 414 OperandRange varOperands, Type operandType, Type optOperandType, 415 TypeRange varOperandTypes) { 416 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 417 printCustomDirectiveResults(printer, op, operandType, optOperandType, 418 varOperandTypes); 419 } 420 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 421 Region ®ion, 422 MutableArrayRef<Region> varRegions) { 423 printer.printRegion(region); 424 if (!varRegions.empty()) { 425 printer << ", "; 426 for (Region ®ion : varRegions) 427 printer.printRegion(region); 428 } 429 } 430 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 431 Block *successor, 432 SuccessorRange varSuccessors) { 433 printer << successor; 434 if (!varSuccessors.empty()) 435 printer << ", " << varSuccessors.front(); 436 } 437 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 438 Attribute attribute, 439 Attribute optAttribute) { 440 printer << attribute; 441 if (optAttribute) 442 printer << ", " << optAttribute; 443 } 444 445 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 446 DictionaryAttr attrs) { 447 printer.printOptionalAttrDict(attrs.getValue()); 448 } 449 450 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 451 Operation *op, 452 Value optOperand) { 453 printer << (optOperand ? "1" : "0"); 454 } 455 456 //===----------------------------------------------------------------------===// 457 // Test IsolatedRegionOp - parse passthrough region arguments. 458 //===----------------------------------------------------------------------===// 459 460 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 461 OperationState &result) { 462 OpAsmParser::OperandType argInfo; 463 Type argType = parser.getBuilder().getIndexType(); 464 465 // Parse the input operand. 466 if (parser.parseOperand(argInfo) || 467 parser.resolveOperand(argInfo, argType, result.operands)) 468 return failure(); 469 470 // Parse the body region, and reuse the operand info as the argument info. 471 Region *body = result.addRegion(); 472 return parser.parseRegion(*body, argInfo, argType, 473 /*enableNameShadowing=*/true); 474 } 475 476 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 477 p << "test.isolated_region "; 478 p.printOperand(op.getOperand()); 479 p.shadowRegionArgs(op.region(), op.getOperand()); 480 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 481 } 482 483 //===----------------------------------------------------------------------===// 484 // Test SSACFGRegionOp 485 //===----------------------------------------------------------------------===// 486 487 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 488 return RegionKind::SSACFG; 489 } 490 491 //===----------------------------------------------------------------------===// 492 // Test GraphRegionOp 493 //===----------------------------------------------------------------------===// 494 495 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 496 OperationState &result) { 497 // Parse the body region, and reuse the operand info as the argument info. 498 Region *body = result.addRegion(); 499 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 500 } 501 502 static void print(OpAsmPrinter &p, GraphRegionOp op) { 503 p << "test.graph_region "; 504 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 505 } 506 507 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 508 return RegionKind::Graph; 509 } 510 511 //===----------------------------------------------------------------------===// 512 // Test AffineScopeOp 513 //===----------------------------------------------------------------------===// 514 515 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 516 OperationState &result) { 517 // Parse the body region, and reuse the operand info as the argument info. 518 Region *body = result.addRegion(); 519 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 520 } 521 522 static void print(OpAsmPrinter &p, AffineScopeOp op) { 523 p << "test.affine_scope "; 524 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 525 } 526 527 //===----------------------------------------------------------------------===// 528 // Test parser. 529 //===----------------------------------------------------------------------===// 530 531 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, 532 OperationState &result) { 533 if (parser.parseOptionalColon()) 534 return success(); 535 uint64_t numResults; 536 if (parser.parseInteger(numResults)) 537 return failure(); 538 539 IndexType type = parser.getBuilder().getIndexType(); 540 for (unsigned i = 0; i < numResults; ++i) 541 result.addTypes(type); 542 return success(); 543 } 544 545 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { 546 p << ParseIntegerLiteralOp::getOperationName(); 547 if (unsigned numResults = op->getNumResults()) 548 p << " : " << numResults; 549 } 550 551 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, 552 OperationState &result) { 553 StringRef keyword; 554 if (parser.parseKeyword(&keyword)) 555 return failure(); 556 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 557 return success(); 558 } 559 560 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { 561 p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword(); 562 } 563 564 //===----------------------------------------------------------------------===// 565 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 566 567 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 568 OperationState &result) { 569 if (parser.parseKeyword("wraps")) 570 return failure(); 571 572 // Parse the wrapped op in a region 573 Region &body = *result.addRegion(); 574 body.push_back(new Block); 575 Block &block = body.back(); 576 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 577 if (!wrapped_op) 578 return failure(); 579 580 // Create a return terminator in the inner region, pass as operand to the 581 // terminator the returned values from the wrapped operation. 582 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 583 OpBuilder builder(parser.getBuilder().getContext()); 584 builder.setInsertionPointToEnd(&block); 585 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 586 587 // Get the results type for the wrapping op from the terminator operands. 588 Operation &return_op = body.back().back(); 589 result.types.append(return_op.operand_type_begin(), 590 return_op.operand_type_end()); 591 592 // Use the location of the wrapped op for the "test.wrapping_region" op. 593 result.location = wrapped_op->getLoc(); 594 595 return success(); 596 } 597 598 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 599 p << op.getOperationName() << " wraps "; 600 p.printGenericOp(&op.region().front().front()); 601 } 602 603 //===----------------------------------------------------------------------===// 604 // Test PolyForOp - parse list of region arguments. 605 //===----------------------------------------------------------------------===// 606 607 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 608 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 609 // Parse list of region arguments without a delimiter. 610 if (parser.parseRegionArgumentList(ivsInfo)) 611 return failure(); 612 613 // Parse the body region. 614 Region *body = result.addRegion(); 615 auto &builder = parser.getBuilder(); 616 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 617 return parser.parseRegion(*body, ivsInfo, argTypes); 618 } 619 620 //===----------------------------------------------------------------------===// 621 // Test removing op with inner ops. 622 //===----------------------------------------------------------------------===// 623 624 namespace { 625 struct TestRemoveOpWithInnerOps 626 : public OpRewritePattern<TestOpWithRegionPattern> { 627 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 628 629 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 630 PatternRewriter &rewriter) const override { 631 rewriter.eraseOp(op); 632 return success(); 633 } 634 }; 635 } // end anonymous namespace 636 637 void TestOpWithRegionPattern::getCanonicalizationPatterns( 638 RewritePatternSet &results, MLIRContext *context) { 639 results.add<TestRemoveOpWithInnerOps>(context); 640 } 641 642 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 643 return operand(); 644 } 645 646 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 647 return getValue(); 648 } 649 650 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 651 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 652 for (Value input : this->operands()) { 653 results.push_back(input); 654 } 655 return success(); 656 } 657 658 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 659 assert(operands.size() == 1); 660 if (operands.front()) { 661 (*this)->setAttr("attr", operands.front()); 662 return getResult(); 663 } 664 return {}; 665 } 666 667 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 668 return getOperand(); 669 } 670 671 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 672 MLIRContext *, Optional<Location> location, ValueRange operands, 673 DictionaryAttr attributes, RegionRange regions, 674 SmallVectorImpl<Type> &inferredReturnTypes) { 675 if (operands[0].getType() != operands[1].getType()) { 676 return emitOptionalError(location, "operand type mismatch ", 677 operands[0].getType(), " vs ", 678 operands[1].getType()); 679 } 680 inferredReturnTypes.assign({operands[0].getType()}); 681 return success(); 682 } 683 684 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 685 MLIRContext *context, Optional<Location> location, ValueRange operands, 686 DictionaryAttr attributes, RegionRange regions, 687 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 688 // Create return type consisting of the last element of the first operand. 689 auto operandType = *operands.getTypes().begin(); 690 auto sval = operandType.dyn_cast<ShapedType>(); 691 if (!sval) { 692 return emitOptionalError(location, "only shaped type operands allowed"); 693 } 694 int64_t dim = 695 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 696 auto type = IntegerType::get(context, 17); 697 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 698 return success(); 699 } 700 701 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 702 OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) { 703 shapes = SmallVector<Value, 1>{ 704 builder.createOrFold<memref::DimOp>(getLoc(), getOperand(0), 0)}; 705 return success(); 706 } 707 708 //===----------------------------------------------------------------------===// 709 // Test SideEffect interfaces 710 //===----------------------------------------------------------------------===// 711 712 namespace { 713 /// A test resource for side effects. 714 struct TestResource : public SideEffects::Resource::Base<TestResource> { 715 StringRef getName() final { return "<Test>"; } 716 }; 717 } // end anonymous namespace 718 719 void SideEffectOp::getEffects( 720 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 721 // Check for an effects attribute on the op instance. 722 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 723 if (!effectsAttr) 724 return; 725 726 // If there is one, it is an array of dictionary attributes that hold 727 // information on the effects of this operation. 728 for (Attribute element : effectsAttr) { 729 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 730 731 // Get the specific memory effect. 732 MemoryEffects::Effect *effect = 733 StringSwitch<MemoryEffects::Effect *>( 734 effectElement.get("effect").cast<StringAttr>().getValue()) 735 .Case("allocate", MemoryEffects::Allocate::get()) 736 .Case("free", MemoryEffects::Free::get()) 737 .Case("read", MemoryEffects::Read::get()) 738 .Case("write", MemoryEffects::Write::get()); 739 740 // Check for a non-default resource to use. 741 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 742 if (effectElement.get("test_resource")) 743 resource = TestResource::get(); 744 745 // Check for a result to affect. 746 if (effectElement.get("on_result")) 747 effects.emplace_back(effect, getResult(), resource); 748 else if (Attribute ref = effectElement.get("on_reference")) 749 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 750 else 751 effects.emplace_back(effect, resource); 752 } 753 } 754 755 void SideEffectOp::getEffects( 756 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 757 auto effectsAttr = (*this)->getAttrOfType<AffineMapAttr>("effect_parameter"); 758 if (!effectsAttr) 759 return; 760 761 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 762 } 763 764 //===----------------------------------------------------------------------===// 765 // StringAttrPrettyNameOp 766 //===----------------------------------------------------------------------===// 767 768 // This op has fancy handling of its SSA result name. 769 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 770 OperationState &result) { 771 // Add the result types. 772 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 773 result.addTypes(parser.getBuilder().getIntegerType(32)); 774 775 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 776 return failure(); 777 778 // If the attribute dictionary contains no 'names' attribute, infer it from 779 // the SSA name (if specified). 780 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 781 return attr.first == "names"; 782 }); 783 784 // If there was no name specified, check to see if there was a useful name 785 // specified in the asm file. 786 if (hadNames || parser.getNumResults() == 0) 787 return success(); 788 789 SmallVector<StringRef, 4> names; 790 auto *context = result.getContext(); 791 792 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 793 auto resultName = parser.getResultName(i); 794 StringRef nameStr; 795 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 796 nameStr = resultName.first; 797 798 names.push_back(nameStr); 799 } 800 801 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 802 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 803 return success(); 804 } 805 806 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 807 p << "test.string_attr_pretty_name"; 808 809 // Note that we only need to print the "name" attribute if the asmprinter 810 // result name disagrees with it. This can happen in strange cases, e.g. 811 // when there are conflicts. 812 bool namesDisagree = op.names().size() != op.getNumResults(); 813 814 SmallString<32> resultNameStr; 815 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 816 resultNameStr.clear(); 817 llvm::raw_svector_ostream tmpStream(resultNameStr); 818 p.printOperand(op.getResult(i), tmpStream); 819 820 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 821 if (!expectedName || 822 tmpStream.str().drop_front() != expectedName.getValue()) { 823 namesDisagree = true; 824 } 825 } 826 827 if (namesDisagree) 828 p.printOptionalAttrDictWithKeyword(op->getAttrs()); 829 else 830 p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"}); 831 } 832 833 // We set the SSA name in the asm syntax to the contents of the name 834 // attribute. 835 void StringAttrPrettyNameOp::getAsmResultNames( 836 function_ref<void(Value, StringRef)> setNameFn) { 837 838 auto value = names(); 839 for (size_t i = 0, e = value.size(); i != e; ++i) 840 if (auto str = value[i].dyn_cast<StringAttr>()) 841 if (!str.getValue().empty()) 842 setNameFn(getResult(i), str.getValue()); 843 } 844 845 //===----------------------------------------------------------------------===// 846 // RegionIfOp 847 //===----------------------------------------------------------------------===// 848 849 static void print(OpAsmPrinter &p, RegionIfOp op) { 850 p << RegionIfOp::getOperationName() << " "; 851 p.printOperands(op.getOperands()); 852 p << ": " << op.getOperandTypes(); 853 p.printArrowTypeList(op.getResultTypes()); 854 p << " then"; 855 p.printRegion(op.thenRegion(), 856 /*printEntryBlockArgs=*/true, 857 /*printBlockTerminators=*/true); 858 p << " else"; 859 p.printRegion(op.elseRegion(), 860 /*printEntryBlockArgs=*/true, 861 /*printBlockTerminators=*/true); 862 p << " join"; 863 p.printRegion(op.joinRegion(), 864 /*printEntryBlockArgs=*/true, 865 /*printBlockTerminators=*/true); 866 } 867 868 static ParseResult parseRegionIfOp(OpAsmParser &parser, 869 OperationState &result) { 870 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 871 SmallVector<Type, 2> operandTypes; 872 873 result.regions.reserve(3); 874 Region *thenRegion = result.addRegion(); 875 Region *elseRegion = result.addRegion(); 876 Region *joinRegion = result.addRegion(); 877 878 // Parse operand, type and arrow type lists. 879 if (parser.parseOperandList(operandInfos) || 880 parser.parseColonTypeList(operandTypes) || 881 parser.parseArrowTypeList(result.types)) 882 return failure(); 883 884 // Parse all attached regions. 885 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 886 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 887 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 888 return failure(); 889 890 return parser.resolveOperands(operandInfos, operandTypes, 891 parser.getCurrentLocation(), result.operands); 892 } 893 894 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 895 assert(index < 2 && "invalid region index"); 896 return getOperands(); 897 } 898 899 void RegionIfOp::getSuccessorRegions( 900 Optional<unsigned> index, ArrayRef<Attribute> operands, 901 SmallVectorImpl<RegionSuccessor> ®ions) { 902 // We always branch to the join region. 903 if (index.hasValue()) { 904 if (index.getValue() < 2) 905 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 906 else 907 regions.push_back(RegionSuccessor(getResults())); 908 return; 909 } 910 911 // The then and else regions are the entry regions of this op. 912 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 913 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 914 } 915 916 #include "TestOpEnums.cpp.inc" 917 #include "TestOpInterfaces.cpp.inc" 918 #include "TestOpStructs.cpp.inc" 919 #include "TestTypeInterfaces.cpp.inc" 920 921 #define GET_OP_CLASSES 922 #include "TestOps.cpp.inc" 923