1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestTypes.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/IR/BuiltinOps.h" 13 #include "mlir/IR/DialectImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/TypeUtilities.h" 16 #include "mlir/Transforms/FoldUtils.h" 17 #include "mlir/Transforms/InliningUtils.h" 18 #include "llvm/ADT/SetVector.h" 19 #include "llvm/ADT/StringSwitch.h" 20 21 using namespace mlir; 22 using namespace mlir::test; 23 24 void mlir::test::registerTestDialect(DialectRegistry ®istry) { 25 registry.insert<TestDialect>(); 26 } 27 28 //===----------------------------------------------------------------------===// 29 // TestDialect Interfaces 30 //===----------------------------------------------------------------------===// 31 32 namespace { 33 34 // Test support for interacting with the AsmPrinter. 35 struct TestOpAsmInterface : public OpAsmDialectInterface { 36 using OpAsmDialectInterface::OpAsmDialectInterface; 37 38 LogicalResult getAlias(Attribute attr, raw_ostream &os) const final { 39 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 40 if (!strAttr) 41 return failure(); 42 43 // Check the contents of the string attribute to see what the test alias 44 // should be named. 45 Optional<StringRef> aliasName = 46 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 47 .Case("alias_test:dot_in_name", StringRef("test.alias")) 48 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 49 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 50 .Case("alias_test:sanitize_conflict_a", 51 StringRef("test_alias_conflict0")) 52 .Case("alias_test:sanitize_conflict_b", 53 StringRef("test_alias_conflict0_")) 54 .Default(llvm::None); 55 if (!aliasName) 56 return failure(); 57 58 os << *aliasName; 59 return success(); 60 } 61 62 void getAsmResultNames(Operation *op, 63 OpAsmSetValueNameFn setNameFn) const final { 64 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 65 setNameFn(asmOp, "result"); 66 } 67 68 void getAsmBlockArgumentNames(Block *block, 69 OpAsmSetValueNameFn setNameFn) const final { 70 auto op = block->getParentOp(); 71 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 72 if (!arrayAttr) 73 return; 74 auto args = block->getArguments(); 75 auto e = std::min(arrayAttr.size(), args.size()); 76 for (unsigned i = 0; i < e; ++i) { 77 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 78 setNameFn(args[i], strAttr.getValue()); 79 } 80 } 81 }; 82 83 struct TestDialectFoldInterface : public DialectFoldInterface { 84 using DialectFoldInterface::DialectFoldInterface; 85 86 /// Registered hook to check if the given region, which is attached to an 87 /// operation that is *not* isolated from above, should be used when 88 /// materializing constants. 89 bool shouldMaterializeInto(Region *region) const final { 90 // If this is a one region operation, then insert into it. 91 return isa<OneRegionOp>(region->getParentOp()); 92 } 93 }; 94 95 /// This class defines the interface for handling inlining with standard 96 /// operations. 97 struct TestInlinerInterface : public DialectInlinerInterface { 98 using DialectInlinerInterface::DialectInlinerInterface; 99 100 //===--------------------------------------------------------------------===// 101 // Analysis Hooks 102 //===--------------------------------------------------------------------===// 103 104 bool isLegalToInline(Operation *call, Operation *callable, 105 bool wouldBeCloned) const final { 106 // Don't allow inlining calls that are marked `noinline`. 107 return !call->hasAttr("noinline"); 108 } 109 bool isLegalToInline(Region *, Region *, bool, 110 BlockAndValueMapping &) const final { 111 // Inlining into test dialect regions is legal. 112 return true; 113 } 114 bool isLegalToInline(Operation *, Region *, bool, 115 BlockAndValueMapping &) const final { 116 return true; 117 } 118 119 bool shouldAnalyzeRecursively(Operation *op) const final { 120 // Analyze recursively if this is not a functional region operation, it 121 // froms a separate functional scope. 122 return !isa<FunctionalRegionOp>(op); 123 } 124 125 //===--------------------------------------------------------------------===// 126 // Transformation Hooks 127 //===--------------------------------------------------------------------===// 128 129 /// Handle the given inlined terminator by replacing it with a new operation 130 /// as necessary. 131 void handleTerminator(Operation *op, 132 ArrayRef<Value> valuesToRepl) const final { 133 // Only handle "test.return" here. 134 auto returnOp = dyn_cast<TestReturnOp>(op); 135 if (!returnOp) 136 return; 137 138 // Replace the values directly with the return operands. 139 assert(returnOp.getNumOperands() == valuesToRepl.size()); 140 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 141 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 142 } 143 144 /// Attempt to materialize a conversion for a type mismatch between a call 145 /// from this dialect, and a callable region. This method should generate an 146 /// operation that takes 'input' as the only operand, and produces a single 147 /// result of 'resultType'. If a conversion can not be generated, nullptr 148 /// should be returned. 149 Operation *materializeCallConversion(OpBuilder &builder, Value input, 150 Type resultType, 151 Location conversionLoc) const final { 152 // Only allow conversion for i16/i32 types. 153 if (!(resultType.isSignlessInteger(16) || 154 resultType.isSignlessInteger(32)) || 155 !(input.getType().isSignlessInteger(16) || 156 input.getType().isSignlessInteger(32))) 157 return nullptr; 158 return builder.create<TestCastOp>(conversionLoc, resultType, input); 159 } 160 }; 161 } // end anonymous namespace 162 163 //===----------------------------------------------------------------------===// 164 // TestDialect 165 //===----------------------------------------------------------------------===// 166 167 void TestDialect::initialize() { 168 addOperations< 169 #define GET_OP_LIST 170 #include "TestOps.cpp.inc" 171 >(); 172 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 173 TestInlinerInterface>(); 174 addTypes<TestType, TestRecursiveType, 175 #define GET_TYPEDEF_LIST 176 #include "TestTypeDefs.cpp.inc" 177 >(); 178 allowUnknownOperations(); 179 } 180 181 static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser, 182 llvm::SetVector<Type> &stack) { 183 StringRef typeTag; 184 if (failed(parser.parseKeyword(&typeTag))) 185 return Type(); 186 187 auto genType = generatedTypeParser(ctxt, parser, typeTag); 188 if (genType != Type()) 189 return genType; 190 191 if (typeTag == "test_type") 192 return TestType::get(parser.getBuilder().getContext()); 193 194 if (typeTag != "test_rec") 195 return Type(); 196 197 StringRef name; 198 if (parser.parseLess() || parser.parseKeyword(&name)) 199 return Type(); 200 auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name); 201 202 // If this type already has been parsed above in the stack, expect just the 203 // name. 204 if (stack.contains(rec)) { 205 if (failed(parser.parseGreater())) 206 return Type(); 207 return rec; 208 } 209 210 // Otherwise, parse the body and update the type. 211 if (failed(parser.parseComma())) 212 return Type(); 213 stack.insert(rec); 214 Type subtype = parseTestType(ctxt, parser, stack); 215 stack.pop_back(); 216 if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype))) 217 return Type(); 218 219 return rec; 220 } 221 222 Type TestDialect::parseType(DialectAsmParser &parser) const { 223 llvm::SetVector<Type> stack; 224 return parseTestType(getContext(), parser, stack); 225 } 226 227 static void printTestType(Type type, DialectAsmPrinter &printer, 228 llvm::SetVector<Type> &stack) { 229 if (succeeded(generatedTypePrinter(type, printer))) 230 return; 231 if (type.isa<TestType>()) { 232 printer << "test_type"; 233 return; 234 } 235 236 auto rec = type.cast<TestRecursiveType>(); 237 printer << "test_rec<" << rec.getName(); 238 if (!stack.contains(rec)) { 239 printer << ", "; 240 stack.insert(rec); 241 printTestType(rec.getBody(), printer, stack); 242 stack.pop_back(); 243 } 244 printer << ">"; 245 } 246 247 void TestDialect::printType(Type type, DialectAsmPrinter &printer) const { 248 llvm::SetVector<Type> stack; 249 printTestType(type, printer, stack); 250 } 251 252 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 253 NamedAttribute namedAttr) { 254 if (namedAttr.first == "test.invalid_attr") 255 return op->emitError() << "invalid to use 'test.invalid_attr'"; 256 return success(); 257 } 258 259 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 260 unsigned regionIndex, 261 unsigned argIndex, 262 NamedAttribute namedAttr) { 263 if (namedAttr.first == "test.invalid_attr") 264 return op->emitError() << "invalid to use 'test.invalid_attr'"; 265 return success(); 266 } 267 268 LogicalResult 269 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 270 unsigned resultIndex, 271 NamedAttribute namedAttr) { 272 if (namedAttr.first == "test.invalid_attr") 273 return op->emitError() << "invalid to use 'test.invalid_attr'"; 274 return success(); 275 } 276 277 //===----------------------------------------------------------------------===// 278 // TestBranchOp 279 //===----------------------------------------------------------------------===// 280 281 Optional<MutableOperandRange> 282 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 283 assert(index == 0 && "invalid successor index"); 284 return targetOperandsMutable(); 285 } 286 287 //===----------------------------------------------------------------------===// 288 // TestFoldToCallOp 289 //===----------------------------------------------------------------------===// 290 291 namespace { 292 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 293 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 294 295 LogicalResult matchAndRewrite(FoldToCallOp op, 296 PatternRewriter &rewriter) const override { 297 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(), 298 ValueRange()); 299 return success(); 300 } 301 }; 302 } // end anonymous namespace 303 304 void FoldToCallOp::getCanonicalizationPatterns( 305 OwningRewritePatternList &results, MLIRContext *context) { 306 results.insert<FoldToCallOpPattern>(context); 307 } 308 309 //===----------------------------------------------------------------------===// 310 // Test Format* operations 311 //===----------------------------------------------------------------------===// 312 313 //===----------------------------------------------------------------------===// 314 // Parsing 315 316 static ParseResult parseCustomDirectiveOperands( 317 OpAsmParser &parser, OpAsmParser::OperandType &operand, 318 Optional<OpAsmParser::OperandType> &optOperand, 319 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 320 if (parser.parseOperand(operand)) 321 return failure(); 322 if (succeeded(parser.parseOptionalComma())) { 323 optOperand.emplace(); 324 if (parser.parseOperand(*optOperand)) 325 return failure(); 326 } 327 if (parser.parseArrow() || parser.parseLParen() || 328 parser.parseOperandList(varOperands) || parser.parseRParen()) 329 return failure(); 330 return success(); 331 } 332 static ParseResult 333 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 334 Type &optOperandType, 335 SmallVectorImpl<Type> &varOperandTypes) { 336 if (parser.parseColon()) 337 return failure(); 338 339 if (parser.parseType(operandType)) 340 return failure(); 341 if (succeeded(parser.parseOptionalComma())) { 342 if (parser.parseType(optOperandType)) 343 return failure(); 344 } 345 if (parser.parseArrow() || parser.parseLParen() || 346 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 347 return failure(); 348 return success(); 349 } 350 static ParseResult 351 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 352 Type optOperandType, 353 const SmallVectorImpl<Type> &varOperandTypes) { 354 if (parser.parseKeyword("type_refs_capture")) 355 return failure(); 356 357 Type operandType2, optOperandType2; 358 SmallVector<Type, 1> varOperandTypes2; 359 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 360 varOperandTypes2)) 361 return failure(); 362 363 if (operandType != operandType2 || optOperandType != optOperandType2 || 364 varOperandTypes != varOperandTypes2) 365 return failure(); 366 367 return success(); 368 } 369 static ParseResult parseCustomDirectiveOperandsAndTypes( 370 OpAsmParser &parser, OpAsmParser::OperandType &operand, 371 Optional<OpAsmParser::OperandType> &optOperand, 372 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 373 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 374 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 375 parseCustomDirectiveResults(parser, operandType, optOperandType, 376 varOperandTypes)) 377 return failure(); 378 return success(); 379 } 380 static ParseResult parseCustomDirectiveRegions( 381 OpAsmParser &parser, Region ®ion, 382 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 383 if (parser.parseRegion(region)) 384 return failure(); 385 if (failed(parser.parseOptionalComma())) 386 return success(); 387 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 388 if (parser.parseRegion(*varRegion)) 389 return failure(); 390 varRegions.emplace_back(std::move(varRegion)); 391 return success(); 392 } 393 static ParseResult 394 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 395 SmallVectorImpl<Block *> &varSuccessors) { 396 if (parser.parseSuccessor(successor)) 397 return failure(); 398 if (failed(parser.parseOptionalComma())) 399 return success(); 400 Block *varSuccessor; 401 if (parser.parseSuccessor(varSuccessor)) 402 return failure(); 403 varSuccessors.append(2, varSuccessor); 404 return success(); 405 } 406 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 407 IntegerAttr &attr, 408 IntegerAttr &optAttr) { 409 if (parser.parseAttribute(attr)) 410 return failure(); 411 if (succeeded(parser.parseOptionalComma())) { 412 if (parser.parseAttribute(optAttr)) 413 return failure(); 414 } 415 return success(); 416 } 417 418 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 419 NamedAttrList &attrs) { 420 return parser.parseOptionalAttrDict(attrs); 421 } 422 423 //===----------------------------------------------------------------------===// 424 // Printing 425 426 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 427 Value operand, Value optOperand, 428 OperandRange varOperands) { 429 printer << operand; 430 if (optOperand) 431 printer << ", " << optOperand; 432 printer << " -> (" << varOperands << ")"; 433 } 434 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 435 Type operandType, Type optOperandType, 436 TypeRange varOperandTypes) { 437 printer << " : " << operandType; 438 if (optOperandType) 439 printer << ", " << optOperandType; 440 printer << " -> (" << varOperandTypes << ")"; 441 } 442 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 443 Operation *op, Type operandType, 444 Type optOperandType, 445 TypeRange varOperandTypes) { 446 printer << " type_refs_capture "; 447 printCustomDirectiveResults(printer, op, operandType, optOperandType, 448 varOperandTypes); 449 } 450 static void printCustomDirectiveOperandsAndTypes( 451 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 452 OperandRange varOperands, Type operandType, Type optOperandType, 453 TypeRange varOperandTypes) { 454 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 455 printCustomDirectiveResults(printer, op, operandType, optOperandType, 456 varOperandTypes); 457 } 458 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 459 Region ®ion, 460 MutableArrayRef<Region> varRegions) { 461 printer.printRegion(region); 462 if (!varRegions.empty()) { 463 printer << ", "; 464 for (Region ®ion : varRegions) 465 printer.printRegion(region); 466 } 467 } 468 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 469 Block *successor, 470 SuccessorRange varSuccessors) { 471 printer << successor; 472 if (!varSuccessors.empty()) 473 printer << ", " << varSuccessors.front(); 474 } 475 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 476 Attribute attribute, 477 Attribute optAttribute) { 478 printer << attribute; 479 if (optAttribute) 480 printer << ", " << optAttribute; 481 } 482 483 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 484 MutableDictionaryAttr attrs) { 485 printer.printOptionalAttrDict(attrs.getAttrs()); 486 } 487 //===----------------------------------------------------------------------===// 488 // Test IsolatedRegionOp - parse passthrough region arguments. 489 //===----------------------------------------------------------------------===// 490 491 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 492 OperationState &result) { 493 OpAsmParser::OperandType argInfo; 494 Type argType = parser.getBuilder().getIndexType(); 495 496 // Parse the input operand. 497 if (parser.parseOperand(argInfo) || 498 parser.resolveOperand(argInfo, argType, result.operands)) 499 return failure(); 500 501 // Parse the body region, and reuse the operand info as the argument info. 502 Region *body = result.addRegion(); 503 return parser.parseRegion(*body, argInfo, argType, 504 /*enableNameShadowing=*/true); 505 } 506 507 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 508 p << "test.isolated_region "; 509 p.printOperand(op.getOperand()); 510 p.shadowRegionArgs(op.region(), op.getOperand()); 511 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 512 } 513 514 //===----------------------------------------------------------------------===// 515 // Test SSACFGRegionOp 516 //===----------------------------------------------------------------------===// 517 518 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 519 return RegionKind::SSACFG; 520 } 521 522 //===----------------------------------------------------------------------===// 523 // Test GraphRegionOp 524 //===----------------------------------------------------------------------===// 525 526 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 527 OperationState &result) { 528 // Parse the body region, and reuse the operand info as the argument info. 529 Region *body = result.addRegion(); 530 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 531 } 532 533 static void print(OpAsmPrinter &p, GraphRegionOp op) { 534 p << "test.graph_region "; 535 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 536 } 537 538 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 539 return RegionKind::Graph; 540 } 541 542 //===----------------------------------------------------------------------===// 543 // Test AffineScopeOp 544 //===----------------------------------------------------------------------===// 545 546 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 547 OperationState &result) { 548 // Parse the body region, and reuse the operand info as the argument info. 549 Region *body = result.addRegion(); 550 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 551 } 552 553 static void print(OpAsmPrinter &p, AffineScopeOp op) { 554 p << "test.affine_scope "; 555 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 556 } 557 558 //===----------------------------------------------------------------------===// 559 // Test parser. 560 //===----------------------------------------------------------------------===// 561 562 static ParseResult parseWrappedKeywordOp(OpAsmParser &parser, 563 OperationState &result) { 564 StringRef keyword; 565 if (parser.parseKeyword(&keyword)) 566 return failure(); 567 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 568 return success(); 569 } 570 571 static void print(OpAsmPrinter &p, WrappedKeywordOp op) { 572 p << WrappedKeywordOp::getOperationName() << " " << op.keyword(); 573 } 574 575 //===----------------------------------------------------------------------===// 576 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 577 578 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 579 OperationState &result) { 580 if (parser.parseKeyword("wraps")) 581 return failure(); 582 583 // Parse the wrapped op in a region 584 Region &body = *result.addRegion(); 585 body.push_back(new Block); 586 Block &block = body.back(); 587 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 588 if (!wrapped_op) 589 return failure(); 590 591 // Create a return terminator in the inner region, pass as operand to the 592 // terminator the returned values from the wrapped operation. 593 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 594 OpBuilder builder(parser.getBuilder().getContext()); 595 builder.setInsertionPointToEnd(&block); 596 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 597 598 // Get the results type for the wrapping op from the terminator operands. 599 Operation &return_op = body.back().back(); 600 result.types.append(return_op.operand_type_begin(), 601 return_op.operand_type_end()); 602 603 // Use the location of the wrapped op for the "test.wrapping_region" op. 604 result.location = wrapped_op->getLoc(); 605 606 return success(); 607 } 608 609 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 610 p << op.getOperationName() << " wraps "; 611 p.printGenericOp(&op.region().front().front()); 612 } 613 614 //===----------------------------------------------------------------------===// 615 // Test PolyForOp - parse list of region arguments. 616 //===----------------------------------------------------------------------===// 617 618 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 619 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 620 // Parse list of region arguments without a delimiter. 621 if (parser.parseRegionArgumentList(ivsInfo)) 622 return failure(); 623 624 // Parse the body region. 625 Region *body = result.addRegion(); 626 auto &builder = parser.getBuilder(); 627 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 628 return parser.parseRegion(*body, ivsInfo, argTypes); 629 } 630 631 //===----------------------------------------------------------------------===// 632 // Test removing op with inner ops. 633 //===----------------------------------------------------------------------===// 634 635 namespace { 636 struct TestRemoveOpWithInnerOps 637 : public OpRewritePattern<TestOpWithRegionPattern> { 638 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 639 640 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 641 PatternRewriter &rewriter) const override { 642 rewriter.eraseOp(op); 643 return success(); 644 } 645 }; 646 } // end anonymous namespace 647 648 void TestOpWithRegionPattern::getCanonicalizationPatterns( 649 OwningRewritePatternList &results, MLIRContext *context) { 650 results.insert<TestRemoveOpWithInnerOps>(context); 651 } 652 653 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 654 return operand(); 655 } 656 657 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 658 return getValue(); 659 } 660 661 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 662 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 663 for (Value input : this->operands()) { 664 results.push_back(input); 665 } 666 return success(); 667 } 668 669 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 670 assert(operands.size() == 1); 671 if (operands.front()) { 672 setAttr("attr", operands.front()); 673 return getResult(); 674 } 675 return {}; 676 } 677 678 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 679 MLIRContext *, Optional<Location> location, ValueRange operands, 680 DictionaryAttr attributes, RegionRange regions, 681 SmallVectorImpl<Type> &inferredReturnTypes) { 682 if (operands[0].getType() != operands[1].getType()) { 683 return emitOptionalError(location, "operand type mismatch ", 684 operands[0].getType(), " vs ", 685 operands[1].getType()); 686 } 687 inferredReturnTypes.assign({operands[0].getType()}); 688 return success(); 689 } 690 691 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 692 MLIRContext *context, Optional<Location> location, ValueRange operands, 693 DictionaryAttr attributes, RegionRange regions, 694 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 695 // Create return type consisting of the last element of the first operand. 696 auto operandType = *operands.getTypes().begin(); 697 auto sval = operandType.dyn_cast<ShapedType>(); 698 if (!sval) { 699 return emitOptionalError(location, "only shaped type operands allowed"); 700 } 701 int64_t dim = 702 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 703 auto type = IntegerType::get(17, context); 704 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 705 return success(); 706 } 707 708 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 709 OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) { 710 shapes = SmallVector<Value, 1>{ 711 builder.createOrFold<DimOp>(getLoc(), getOperand(0), 0)}; 712 return success(); 713 } 714 715 //===----------------------------------------------------------------------===// 716 // Test SideEffect interfaces 717 //===----------------------------------------------------------------------===// 718 719 namespace { 720 /// A test resource for side effects. 721 struct TestResource : public SideEffects::Resource::Base<TestResource> { 722 StringRef getName() final { return "<Test>"; } 723 }; 724 } // end anonymous namespace 725 726 void SideEffectOp::getEffects( 727 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 728 // Check for an effects attribute on the op instance. 729 ArrayAttr effectsAttr = getAttrOfType<ArrayAttr>("effects"); 730 if (!effectsAttr) 731 return; 732 733 // If there is one, it is an array of dictionary attributes that hold 734 // information on the effects of this operation. 735 for (Attribute element : effectsAttr) { 736 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 737 738 // Get the specific memory effect. 739 MemoryEffects::Effect *effect = 740 StringSwitch<MemoryEffects::Effect *>( 741 effectElement.get("effect").cast<StringAttr>().getValue()) 742 .Case("allocate", MemoryEffects::Allocate::get()) 743 .Case("free", MemoryEffects::Free::get()) 744 .Case("read", MemoryEffects::Read::get()) 745 .Case("write", MemoryEffects::Write::get()); 746 747 // Check for a non-default resource to use. 748 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 749 if (effectElement.get("test_resource")) 750 resource = TestResource::get(); 751 752 // Check for a result to affect. 753 if (effectElement.get("on_result")) 754 effects.emplace_back(effect, getResult(), resource); 755 else if (Attribute ref = effectElement.get("on_reference")) 756 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 757 else 758 effects.emplace_back(effect, resource); 759 } 760 } 761 762 void SideEffectOp::getEffects( 763 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 764 auto effectsAttr = getAttrOfType<AffineMapAttr>("effect_parameter"); 765 if (!effectsAttr) 766 return; 767 768 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 769 } 770 771 //===----------------------------------------------------------------------===// 772 // StringAttrPrettyNameOp 773 //===----------------------------------------------------------------------===// 774 775 // This op has fancy handling of its SSA result name. 776 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 777 OperationState &result) { 778 // Add the result types. 779 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 780 result.addTypes(parser.getBuilder().getIntegerType(32)); 781 782 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 783 return failure(); 784 785 // If the attribute dictionary contains no 'names' attribute, infer it from 786 // the SSA name (if specified). 787 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 788 return attr.first == "names"; 789 }); 790 791 // If there was no name specified, check to see if there was a useful name 792 // specified in the asm file. 793 if (hadNames || parser.getNumResults() == 0) 794 return success(); 795 796 SmallVector<StringRef, 4> names; 797 auto *context = result.getContext(); 798 799 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 800 auto resultName = parser.getResultName(i); 801 StringRef nameStr; 802 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 803 nameStr = resultName.first; 804 805 names.push_back(nameStr); 806 } 807 808 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 809 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 810 return success(); 811 } 812 813 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 814 p << "test.string_attr_pretty_name"; 815 816 // Note that we only need to print the "name" attribute if the asmprinter 817 // result name disagrees with it. This can happen in strange cases, e.g. 818 // when there are conflicts. 819 bool namesDisagree = op.names().size() != op.getNumResults(); 820 821 SmallString<32> resultNameStr; 822 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 823 resultNameStr.clear(); 824 llvm::raw_svector_ostream tmpStream(resultNameStr); 825 p.printOperand(op.getResult(i), tmpStream); 826 827 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 828 if (!expectedName || 829 tmpStream.str().drop_front() != expectedName.getValue()) { 830 namesDisagree = true; 831 } 832 } 833 834 if (namesDisagree) 835 p.printOptionalAttrDictWithKeyword(op.getAttrs()); 836 else 837 p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"}); 838 } 839 840 // We set the SSA name in the asm syntax to the contents of the name 841 // attribute. 842 void StringAttrPrettyNameOp::getAsmResultNames( 843 function_ref<void(Value, StringRef)> setNameFn) { 844 845 auto value = names(); 846 for (size_t i = 0, e = value.size(); i != e; ++i) 847 if (auto str = value[i].dyn_cast<StringAttr>()) 848 if (!str.getValue().empty()) 849 setNameFn(getResult(i), str.getValue()); 850 } 851 852 //===----------------------------------------------------------------------===// 853 // RegionIfOp 854 //===----------------------------------------------------------------------===// 855 856 static void print(OpAsmPrinter &p, RegionIfOp op) { 857 p << RegionIfOp::getOperationName() << " "; 858 p.printOperands(op.getOperands()); 859 p << ": " << op.getOperandTypes(); 860 p.printArrowTypeList(op.getResultTypes()); 861 p << " then"; 862 p.printRegion(op.thenRegion(), 863 /*printEntryBlockArgs=*/true, 864 /*printBlockTerminators=*/true); 865 p << " else"; 866 p.printRegion(op.elseRegion(), 867 /*printEntryBlockArgs=*/true, 868 /*printBlockTerminators=*/true); 869 p << " join"; 870 p.printRegion(op.joinRegion(), 871 /*printEntryBlockArgs=*/true, 872 /*printBlockTerminators=*/true); 873 } 874 875 static ParseResult parseRegionIfOp(OpAsmParser &parser, 876 OperationState &result) { 877 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 878 SmallVector<Type, 2> operandTypes; 879 880 result.regions.reserve(3); 881 Region *thenRegion = result.addRegion(); 882 Region *elseRegion = result.addRegion(); 883 Region *joinRegion = result.addRegion(); 884 885 // Parse operand, type and arrow type lists. 886 if (parser.parseOperandList(operandInfos) || 887 parser.parseColonTypeList(operandTypes) || 888 parser.parseArrowTypeList(result.types)) 889 return failure(); 890 891 // Parse all attached regions. 892 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 893 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 894 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 895 return failure(); 896 897 return parser.resolveOperands(operandInfos, operandTypes, 898 parser.getCurrentLocation(), result.operands); 899 } 900 901 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 902 assert(index < 2 && "invalid region index"); 903 return getOperands(); 904 } 905 906 void RegionIfOp::getSuccessorRegions( 907 Optional<unsigned> index, ArrayRef<Attribute> operands, 908 SmallVectorImpl<RegionSuccessor> ®ions) { 909 // We always branch to the join region. 910 if (index.hasValue()) { 911 if (index.getValue() < 2) 912 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 913 else 914 regions.push_back(RegionSuccessor(getResults())); 915 return; 916 } 917 918 // The then and else regions are the entry regions of this op. 919 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 920 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 921 } 922 923 #include "TestOpEnums.cpp.inc" 924 #include "TestOpInterfaces.cpp.inc" 925 #include "TestOpStructs.cpp.inc" 926 #include "TestTypeInterfaces.cpp.inc" 927 928 #define GET_OP_CLASSES 929 #include "TestOps.cpp.inc" 930