1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestTypes.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/IR/DialectImplementation.h" 13 #include "mlir/IR/Function.h" 14 #include "mlir/IR/Module.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/IR/TypeUtilities.h" 17 #include "mlir/Transforms/FoldUtils.h" 18 #include "mlir/Transforms/InliningUtils.h" 19 #include "llvm/ADT/SetVector.h" 20 #include "llvm/ADT/StringSwitch.h" 21 22 using namespace mlir; 23 using namespace mlir::test; 24 25 void mlir::test::registerTestDialect(DialectRegistry ®istry) { 26 registry.insert<TestDialect>(); 27 } 28 29 //===----------------------------------------------------------------------===// 30 // TestDialect Interfaces 31 //===----------------------------------------------------------------------===// 32 33 namespace { 34 35 // Test support for interacting with the AsmPrinter. 36 struct TestOpAsmInterface : public OpAsmDialectInterface { 37 using OpAsmDialectInterface::OpAsmDialectInterface; 38 39 LogicalResult getAlias(Attribute attr, raw_ostream &os) const final { 40 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 41 if (!strAttr) 42 return failure(); 43 44 // Check the contents of the string attribute to see what the test alias 45 // should be named. 46 Optional<StringRef> aliasName = 47 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 48 .Case("alias_test:dot_in_name", StringRef("test.alias")) 49 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 50 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 51 .Case("alias_test:sanitize_conflict_a", 52 StringRef("test_alias_conflict0")) 53 .Case("alias_test:sanitize_conflict_b", 54 StringRef("test_alias_conflict0_")) 55 .Default(llvm::None); 56 if (!aliasName) 57 return failure(); 58 59 os << *aliasName; 60 return success(); 61 } 62 63 void getAsmResultNames(Operation *op, 64 OpAsmSetValueNameFn setNameFn) const final { 65 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 66 setNameFn(asmOp, "result"); 67 } 68 69 void getAsmBlockArgumentNames(Block *block, 70 OpAsmSetValueNameFn setNameFn) const final { 71 auto op = block->getParentOp(); 72 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 73 if (!arrayAttr) 74 return; 75 auto args = block->getArguments(); 76 auto e = std::min(arrayAttr.size(), args.size()); 77 for (unsigned i = 0; i < e; ++i) { 78 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 79 setNameFn(args[i], strAttr.getValue()); 80 } 81 } 82 }; 83 84 struct TestDialectFoldInterface : public DialectFoldInterface { 85 using DialectFoldInterface::DialectFoldInterface; 86 87 /// Registered hook to check if the given region, which is attached to an 88 /// operation that is *not* isolated from above, should be used when 89 /// materializing constants. 90 bool shouldMaterializeInto(Region *region) const final { 91 // If this is a one region operation, then insert into it. 92 return isa<OneRegionOp>(region->getParentOp()); 93 } 94 }; 95 96 /// This class defines the interface for handling inlining with standard 97 /// operations. 98 struct TestInlinerInterface : public DialectInlinerInterface { 99 using DialectInlinerInterface::DialectInlinerInterface; 100 101 //===--------------------------------------------------------------------===// 102 // Analysis Hooks 103 //===--------------------------------------------------------------------===// 104 105 bool isLegalToInline(Operation *call, Operation *callable, 106 bool wouldBeCloned) const final { 107 // Don't allow inlining calls that are marked `noinline`. 108 return !call->hasAttr("noinline"); 109 } 110 bool isLegalToInline(Region *, Region *, bool, 111 BlockAndValueMapping &) const final { 112 // Inlining into test dialect regions is legal. 113 return true; 114 } 115 bool isLegalToInline(Operation *, Region *, bool, 116 BlockAndValueMapping &) const final { 117 return true; 118 } 119 120 bool shouldAnalyzeRecursively(Operation *op) const final { 121 // Analyze recursively if this is not a functional region operation, it 122 // froms a separate functional scope. 123 return !isa<FunctionalRegionOp>(op); 124 } 125 126 //===--------------------------------------------------------------------===// 127 // Transformation Hooks 128 //===--------------------------------------------------------------------===// 129 130 /// Handle the given inlined terminator by replacing it with a new operation 131 /// as necessary. 132 void handleTerminator(Operation *op, 133 ArrayRef<Value> valuesToRepl) const final { 134 // Only handle "test.return" here. 135 auto returnOp = dyn_cast<TestReturnOp>(op); 136 if (!returnOp) 137 return; 138 139 // Replace the values directly with the return operands. 140 assert(returnOp.getNumOperands() == valuesToRepl.size()); 141 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 142 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 143 } 144 145 /// Attempt to materialize a conversion for a type mismatch between a call 146 /// from this dialect, and a callable region. This method should generate an 147 /// operation that takes 'input' as the only operand, and produces a single 148 /// result of 'resultType'. If a conversion can not be generated, nullptr 149 /// should be returned. 150 Operation *materializeCallConversion(OpBuilder &builder, Value input, 151 Type resultType, 152 Location conversionLoc) const final { 153 // Only allow conversion for i16/i32 types. 154 if (!(resultType.isSignlessInteger(16) || 155 resultType.isSignlessInteger(32)) || 156 !(input.getType().isSignlessInteger(16) || 157 input.getType().isSignlessInteger(32))) 158 return nullptr; 159 return builder.create<TestCastOp>(conversionLoc, resultType, input); 160 } 161 }; 162 } // end anonymous namespace 163 164 //===----------------------------------------------------------------------===// 165 // TestDialect 166 //===----------------------------------------------------------------------===// 167 168 void TestDialect::initialize() { 169 addOperations< 170 #define GET_OP_LIST 171 #include "TestOps.cpp.inc" 172 >(); 173 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 174 TestInlinerInterface>(); 175 addTypes<TestType, TestRecursiveType, 176 #define GET_TYPEDEF_LIST 177 #include "TestTypeDefs.cpp.inc" 178 >(); 179 allowUnknownOperations(); 180 } 181 182 static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser, 183 llvm::SetVector<Type> &stack) { 184 StringRef typeTag; 185 if (failed(parser.parseKeyword(&typeTag))) 186 return Type(); 187 188 auto genType = generatedTypeParser(ctxt, parser, typeTag); 189 if (genType != Type()) 190 return genType; 191 192 if (typeTag == "test_type") 193 return TestType::get(parser.getBuilder().getContext()); 194 195 if (typeTag != "test_rec") 196 return Type(); 197 198 StringRef name; 199 if (parser.parseLess() || parser.parseKeyword(&name)) 200 return Type(); 201 auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name); 202 203 // If this type already has been parsed above in the stack, expect just the 204 // name. 205 if (stack.contains(rec)) { 206 if (failed(parser.parseGreater())) 207 return Type(); 208 return rec; 209 } 210 211 // Otherwise, parse the body and update the type. 212 if (failed(parser.parseComma())) 213 return Type(); 214 stack.insert(rec); 215 Type subtype = parseTestType(ctxt, parser, stack); 216 stack.pop_back(); 217 if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype))) 218 return Type(); 219 220 return rec; 221 } 222 223 Type TestDialect::parseType(DialectAsmParser &parser) const { 224 llvm::SetVector<Type> stack; 225 return parseTestType(getContext(), parser, stack); 226 } 227 228 static void printTestType(Type type, DialectAsmPrinter &printer, 229 llvm::SetVector<Type> &stack) { 230 if (succeeded(generatedTypePrinter(type, printer))) 231 return; 232 if (type.isa<TestType>()) { 233 printer << "test_type"; 234 return; 235 } 236 237 auto rec = type.cast<TestRecursiveType>(); 238 printer << "test_rec<" << rec.getName(); 239 if (!stack.contains(rec)) { 240 printer << ", "; 241 stack.insert(rec); 242 printTestType(rec.getBody(), printer, stack); 243 stack.pop_back(); 244 } 245 printer << ">"; 246 } 247 248 void TestDialect::printType(Type type, DialectAsmPrinter &printer) const { 249 llvm::SetVector<Type> stack; 250 printTestType(type, printer, stack); 251 } 252 253 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 254 NamedAttribute namedAttr) { 255 if (namedAttr.first == "test.invalid_attr") 256 return op->emitError() << "invalid to use 'test.invalid_attr'"; 257 return success(); 258 } 259 260 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 261 unsigned regionIndex, 262 unsigned argIndex, 263 NamedAttribute namedAttr) { 264 if (namedAttr.first == "test.invalid_attr") 265 return op->emitError() << "invalid to use 'test.invalid_attr'"; 266 return success(); 267 } 268 269 LogicalResult 270 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 271 unsigned resultIndex, 272 NamedAttribute namedAttr) { 273 if (namedAttr.first == "test.invalid_attr") 274 return op->emitError() << "invalid to use 'test.invalid_attr'"; 275 return success(); 276 } 277 278 //===----------------------------------------------------------------------===// 279 // TestBranchOp 280 //===----------------------------------------------------------------------===// 281 282 Optional<MutableOperandRange> 283 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 284 assert(index == 0 && "invalid successor index"); 285 return targetOperandsMutable(); 286 } 287 288 //===----------------------------------------------------------------------===// 289 // TestFoldToCallOp 290 //===----------------------------------------------------------------------===// 291 292 namespace { 293 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 294 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 295 296 LogicalResult matchAndRewrite(FoldToCallOp op, 297 PatternRewriter &rewriter) const override { 298 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(), 299 ValueRange()); 300 return success(); 301 } 302 }; 303 } // end anonymous namespace 304 305 void FoldToCallOp::getCanonicalizationPatterns( 306 OwningRewritePatternList &results, MLIRContext *context) { 307 results.insert<FoldToCallOpPattern>(context); 308 } 309 310 //===----------------------------------------------------------------------===// 311 // Test Format* operations 312 //===----------------------------------------------------------------------===// 313 314 //===----------------------------------------------------------------------===// 315 // Parsing 316 317 static ParseResult parseCustomDirectiveOperands( 318 OpAsmParser &parser, OpAsmParser::OperandType &operand, 319 Optional<OpAsmParser::OperandType> &optOperand, 320 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 321 if (parser.parseOperand(operand)) 322 return failure(); 323 if (succeeded(parser.parseOptionalComma())) { 324 optOperand.emplace(); 325 if (parser.parseOperand(*optOperand)) 326 return failure(); 327 } 328 if (parser.parseArrow() || parser.parseLParen() || 329 parser.parseOperandList(varOperands) || parser.parseRParen()) 330 return failure(); 331 return success(); 332 } 333 static ParseResult 334 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 335 Type &optOperandType, 336 SmallVectorImpl<Type> &varOperandTypes) { 337 if (parser.parseColon()) 338 return failure(); 339 340 if (parser.parseType(operandType)) 341 return failure(); 342 if (succeeded(parser.parseOptionalComma())) { 343 if (parser.parseType(optOperandType)) 344 return failure(); 345 } 346 if (parser.parseArrow() || parser.parseLParen() || 347 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 348 return failure(); 349 return success(); 350 } 351 static ParseResult 352 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 353 Type optOperandType, 354 const SmallVectorImpl<Type> &varOperandTypes) { 355 if (parser.parseKeyword("type_refs_capture")) 356 return failure(); 357 358 Type operandType2, optOperandType2; 359 SmallVector<Type, 1> varOperandTypes2; 360 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 361 varOperandTypes2)) 362 return failure(); 363 364 if (operandType != operandType2 || optOperandType != optOperandType2 || 365 varOperandTypes != varOperandTypes2) 366 return failure(); 367 368 return success(); 369 } 370 static ParseResult parseCustomDirectiveOperandsAndTypes( 371 OpAsmParser &parser, OpAsmParser::OperandType &operand, 372 Optional<OpAsmParser::OperandType> &optOperand, 373 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 374 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 375 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 376 parseCustomDirectiveResults(parser, operandType, optOperandType, 377 varOperandTypes)) 378 return failure(); 379 return success(); 380 } 381 static ParseResult parseCustomDirectiveRegions( 382 OpAsmParser &parser, Region ®ion, 383 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 384 if (parser.parseRegion(region)) 385 return failure(); 386 if (failed(parser.parseOptionalComma())) 387 return success(); 388 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 389 if (parser.parseRegion(*varRegion)) 390 return failure(); 391 varRegions.emplace_back(std::move(varRegion)); 392 return success(); 393 } 394 static ParseResult 395 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 396 SmallVectorImpl<Block *> &varSuccessors) { 397 if (parser.parseSuccessor(successor)) 398 return failure(); 399 if (failed(parser.parseOptionalComma())) 400 return success(); 401 Block *varSuccessor; 402 if (parser.parseSuccessor(varSuccessor)) 403 return failure(); 404 varSuccessors.append(2, varSuccessor); 405 return success(); 406 } 407 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 408 IntegerAttr &attr, 409 IntegerAttr &optAttr) { 410 if (parser.parseAttribute(attr)) 411 return failure(); 412 if (succeeded(parser.parseOptionalComma())) { 413 if (parser.parseAttribute(optAttr)) 414 return failure(); 415 } 416 return success(); 417 } 418 419 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 420 NamedAttrList &attrs) { 421 return parser.parseOptionalAttrDict(attrs); 422 } 423 424 //===----------------------------------------------------------------------===// 425 // Printing 426 427 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 428 Value operand, Value optOperand, 429 OperandRange varOperands) { 430 printer << operand; 431 if (optOperand) 432 printer << ", " << optOperand; 433 printer << " -> (" << varOperands << ")"; 434 } 435 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 436 Type operandType, Type optOperandType, 437 TypeRange varOperandTypes) { 438 printer << " : " << operandType; 439 if (optOperandType) 440 printer << ", " << optOperandType; 441 printer << " -> (" << varOperandTypes << ")"; 442 } 443 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 444 Operation *op, Type operandType, 445 Type optOperandType, 446 TypeRange varOperandTypes) { 447 printer << " type_refs_capture "; 448 printCustomDirectiveResults(printer, op, operandType, optOperandType, 449 varOperandTypes); 450 } 451 static void printCustomDirectiveOperandsAndTypes( 452 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 453 OperandRange varOperands, Type operandType, Type optOperandType, 454 TypeRange varOperandTypes) { 455 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 456 printCustomDirectiveResults(printer, op, operandType, optOperandType, 457 varOperandTypes); 458 } 459 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 460 Region ®ion, 461 MutableArrayRef<Region> varRegions) { 462 printer.printRegion(region); 463 if (!varRegions.empty()) { 464 printer << ", "; 465 for (Region ®ion : varRegions) 466 printer.printRegion(region); 467 } 468 } 469 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 470 Block *successor, 471 SuccessorRange varSuccessors) { 472 printer << successor; 473 if (!varSuccessors.empty()) 474 printer << ", " << varSuccessors.front(); 475 } 476 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 477 Attribute attribute, 478 Attribute optAttribute) { 479 printer << attribute; 480 if (optAttribute) 481 printer << ", " << optAttribute; 482 } 483 484 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 485 MutableDictionaryAttr attrs) { 486 printer.printOptionalAttrDict(attrs.getAttrs()); 487 } 488 //===----------------------------------------------------------------------===// 489 // Test IsolatedRegionOp - parse passthrough region arguments. 490 //===----------------------------------------------------------------------===// 491 492 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 493 OperationState &result) { 494 OpAsmParser::OperandType argInfo; 495 Type argType = parser.getBuilder().getIndexType(); 496 497 // Parse the input operand. 498 if (parser.parseOperand(argInfo) || 499 parser.resolveOperand(argInfo, argType, result.operands)) 500 return failure(); 501 502 // Parse the body region, and reuse the operand info as the argument info. 503 Region *body = result.addRegion(); 504 return parser.parseRegion(*body, argInfo, argType, 505 /*enableNameShadowing=*/true); 506 } 507 508 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 509 p << "test.isolated_region "; 510 p.printOperand(op.getOperand()); 511 p.shadowRegionArgs(op.region(), op.getOperand()); 512 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 513 } 514 515 //===----------------------------------------------------------------------===// 516 // Test SSACFGRegionOp 517 //===----------------------------------------------------------------------===// 518 519 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 520 return RegionKind::SSACFG; 521 } 522 523 //===----------------------------------------------------------------------===// 524 // Test GraphRegionOp 525 //===----------------------------------------------------------------------===// 526 527 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 528 OperationState &result) { 529 // Parse the body region, and reuse the operand info as the argument info. 530 Region *body = result.addRegion(); 531 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 532 } 533 534 static void print(OpAsmPrinter &p, GraphRegionOp op) { 535 p << "test.graph_region "; 536 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 537 } 538 539 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 540 return RegionKind::Graph; 541 } 542 543 //===----------------------------------------------------------------------===// 544 // Test AffineScopeOp 545 //===----------------------------------------------------------------------===// 546 547 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 548 OperationState &result) { 549 // Parse the body region, and reuse the operand info as the argument info. 550 Region *body = result.addRegion(); 551 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 552 } 553 554 static void print(OpAsmPrinter &p, AffineScopeOp op) { 555 p << "test.affine_scope "; 556 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 557 } 558 559 //===----------------------------------------------------------------------===// 560 // Test parser. 561 //===----------------------------------------------------------------------===// 562 563 static ParseResult parseWrappedKeywordOp(OpAsmParser &parser, 564 OperationState &result) { 565 StringRef keyword; 566 if (parser.parseKeyword(&keyword)) 567 return failure(); 568 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 569 return success(); 570 } 571 572 static void print(OpAsmPrinter &p, WrappedKeywordOp op) { 573 p << WrappedKeywordOp::getOperationName() << " " << op.keyword(); 574 } 575 576 //===----------------------------------------------------------------------===// 577 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 578 579 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 580 OperationState &result) { 581 if (parser.parseKeyword("wraps")) 582 return failure(); 583 584 // Parse the wrapped op in a region 585 Region &body = *result.addRegion(); 586 body.push_back(new Block); 587 Block &block = body.back(); 588 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 589 if (!wrapped_op) 590 return failure(); 591 592 // Create a return terminator in the inner region, pass as operand to the 593 // terminator the returned values from the wrapped operation. 594 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 595 OpBuilder builder(parser.getBuilder().getContext()); 596 builder.setInsertionPointToEnd(&block); 597 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 598 599 // Get the results type for the wrapping op from the terminator operands. 600 Operation &return_op = body.back().back(); 601 result.types.append(return_op.operand_type_begin(), 602 return_op.operand_type_end()); 603 604 // Use the location of the wrapped op for the "test.wrapping_region" op. 605 result.location = wrapped_op->getLoc(); 606 607 return success(); 608 } 609 610 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 611 p << op.getOperationName() << " wraps "; 612 p.printGenericOp(&op.region().front().front()); 613 } 614 615 //===----------------------------------------------------------------------===// 616 // Test PolyForOp - parse list of region arguments. 617 //===----------------------------------------------------------------------===// 618 619 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 620 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 621 // Parse list of region arguments without a delimiter. 622 if (parser.parseRegionArgumentList(ivsInfo)) 623 return failure(); 624 625 // Parse the body region. 626 Region *body = result.addRegion(); 627 auto &builder = parser.getBuilder(); 628 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 629 return parser.parseRegion(*body, ivsInfo, argTypes); 630 } 631 632 //===----------------------------------------------------------------------===// 633 // Test removing op with inner ops. 634 //===----------------------------------------------------------------------===// 635 636 namespace { 637 struct TestRemoveOpWithInnerOps 638 : public OpRewritePattern<TestOpWithRegionPattern> { 639 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 640 641 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 642 PatternRewriter &rewriter) const override { 643 rewriter.eraseOp(op); 644 return success(); 645 } 646 }; 647 } // end anonymous namespace 648 649 void TestOpWithRegionPattern::getCanonicalizationPatterns( 650 OwningRewritePatternList &results, MLIRContext *context) { 651 results.insert<TestRemoveOpWithInnerOps>(context); 652 } 653 654 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 655 return operand(); 656 } 657 658 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 659 return getValue(); 660 } 661 662 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 663 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 664 for (Value input : this->operands()) { 665 results.push_back(input); 666 } 667 return success(); 668 } 669 670 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 671 assert(operands.size() == 1); 672 if (operands.front()) { 673 setAttr("attr", operands.front()); 674 return getResult(); 675 } 676 return {}; 677 } 678 679 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 680 MLIRContext *, Optional<Location> location, ValueRange operands, 681 DictionaryAttr attributes, RegionRange regions, 682 SmallVectorImpl<Type> &inferredReturnTypes) { 683 if (operands[0].getType() != operands[1].getType()) { 684 return emitOptionalError(location, "operand type mismatch ", 685 operands[0].getType(), " vs ", 686 operands[1].getType()); 687 } 688 inferredReturnTypes.assign({operands[0].getType()}); 689 return success(); 690 } 691 692 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 693 MLIRContext *context, Optional<Location> location, ValueRange operands, 694 DictionaryAttr attributes, RegionRange regions, 695 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 696 // Create return type consisting of the last element of the first operand. 697 auto operandType = *operands.getTypes().begin(); 698 auto sval = operandType.dyn_cast<ShapedType>(); 699 if (!sval) { 700 return emitOptionalError(location, "only shaped type operands allowed"); 701 } 702 int64_t dim = 703 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 704 auto type = IntegerType::get(17, context); 705 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 706 return success(); 707 } 708 709 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 710 OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) { 711 shapes = SmallVector<Value, 1>{ 712 builder.createOrFold<DimOp>(getLoc(), getOperand(0), 0)}; 713 return success(); 714 } 715 716 //===----------------------------------------------------------------------===// 717 // Test SideEffect interfaces 718 //===----------------------------------------------------------------------===// 719 720 namespace { 721 /// A test resource for side effects. 722 struct TestResource : public SideEffects::Resource::Base<TestResource> { 723 StringRef getName() final { return "<Test>"; } 724 }; 725 } // end anonymous namespace 726 727 void SideEffectOp::getEffects( 728 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 729 // Check for an effects attribute on the op instance. 730 ArrayAttr effectsAttr = getAttrOfType<ArrayAttr>("effects"); 731 if (!effectsAttr) 732 return; 733 734 // If there is one, it is an array of dictionary attributes that hold 735 // information on the effects of this operation. 736 for (Attribute element : effectsAttr) { 737 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 738 739 // Get the specific memory effect. 740 MemoryEffects::Effect *effect = 741 StringSwitch<MemoryEffects::Effect *>( 742 effectElement.get("effect").cast<StringAttr>().getValue()) 743 .Case("allocate", MemoryEffects::Allocate::get()) 744 .Case("free", MemoryEffects::Free::get()) 745 .Case("read", MemoryEffects::Read::get()) 746 .Case("write", MemoryEffects::Write::get()); 747 748 // Check for a result to affect. 749 Value value; 750 if (effectElement.get("on_result")) 751 value = getResult(); 752 753 // Check for a non-default resource to use. 754 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 755 if (effectElement.get("test_resource")) 756 resource = TestResource::get(); 757 758 effects.emplace_back(effect, value, resource); 759 } 760 } 761 762 //===----------------------------------------------------------------------===// 763 // StringAttrPrettyNameOp 764 //===----------------------------------------------------------------------===// 765 766 // This op has fancy handling of its SSA result name. 767 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 768 OperationState &result) { 769 // Add the result types. 770 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 771 result.addTypes(parser.getBuilder().getIntegerType(32)); 772 773 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 774 return failure(); 775 776 // If the attribute dictionary contains no 'names' attribute, infer it from 777 // the SSA name (if specified). 778 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 779 return attr.first == "names"; 780 }); 781 782 // If there was no name specified, check to see if there was a useful name 783 // specified in the asm file. 784 if (hadNames || parser.getNumResults() == 0) 785 return success(); 786 787 SmallVector<StringRef, 4> names; 788 auto *context = result.getContext(); 789 790 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 791 auto resultName = parser.getResultName(i); 792 StringRef nameStr; 793 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 794 nameStr = resultName.first; 795 796 names.push_back(nameStr); 797 } 798 799 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 800 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 801 return success(); 802 } 803 804 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 805 p << "test.string_attr_pretty_name"; 806 807 // Note that we only need to print the "name" attribute if the asmprinter 808 // result name disagrees with it. This can happen in strange cases, e.g. 809 // when there are conflicts. 810 bool namesDisagree = op.names().size() != op.getNumResults(); 811 812 SmallString<32> resultNameStr; 813 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 814 resultNameStr.clear(); 815 llvm::raw_svector_ostream tmpStream(resultNameStr); 816 p.printOperand(op.getResult(i), tmpStream); 817 818 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 819 if (!expectedName || 820 tmpStream.str().drop_front() != expectedName.getValue()) { 821 namesDisagree = true; 822 } 823 } 824 825 if (namesDisagree) 826 p.printOptionalAttrDictWithKeyword(op.getAttrs()); 827 else 828 p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"}); 829 } 830 831 // We set the SSA name in the asm syntax to the contents of the name 832 // attribute. 833 void StringAttrPrettyNameOp::getAsmResultNames( 834 function_ref<void(Value, StringRef)> setNameFn) { 835 836 auto value = names(); 837 for (size_t i = 0, e = value.size(); i != e; ++i) 838 if (auto str = value[i].dyn_cast<StringAttr>()) 839 if (!str.getValue().empty()) 840 setNameFn(getResult(i), str.getValue()); 841 } 842 843 //===----------------------------------------------------------------------===// 844 // RegionIfOp 845 //===----------------------------------------------------------------------===// 846 847 static void print(OpAsmPrinter &p, RegionIfOp op) { 848 p << RegionIfOp::getOperationName() << " "; 849 p.printOperands(op.getOperands()); 850 p << ": " << op.getOperandTypes(); 851 p.printArrowTypeList(op.getResultTypes()); 852 p << " then"; 853 p.printRegion(op.thenRegion(), 854 /*printEntryBlockArgs=*/true, 855 /*printBlockTerminators=*/true); 856 p << " else"; 857 p.printRegion(op.elseRegion(), 858 /*printEntryBlockArgs=*/true, 859 /*printBlockTerminators=*/true); 860 p << " join"; 861 p.printRegion(op.joinRegion(), 862 /*printEntryBlockArgs=*/true, 863 /*printBlockTerminators=*/true); 864 } 865 866 static ParseResult parseRegionIfOp(OpAsmParser &parser, 867 OperationState &result) { 868 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 869 SmallVector<Type, 2> operandTypes; 870 871 result.regions.reserve(3); 872 Region *thenRegion = result.addRegion(); 873 Region *elseRegion = result.addRegion(); 874 Region *joinRegion = result.addRegion(); 875 876 // Parse operand, type and arrow type lists. 877 if (parser.parseOperandList(operandInfos) || 878 parser.parseColonTypeList(operandTypes) || 879 parser.parseArrowTypeList(result.types)) 880 return failure(); 881 882 // Parse all attached regions. 883 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 884 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 885 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 886 return failure(); 887 888 return parser.resolveOperands(operandInfos, operandTypes, 889 parser.getCurrentLocation(), result.operands); 890 } 891 892 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 893 assert(index < 2 && "invalid region index"); 894 return getOperands(); 895 } 896 897 void RegionIfOp::getSuccessorRegions( 898 Optional<unsigned> index, ArrayRef<Attribute> operands, 899 SmallVectorImpl<RegionSuccessor> ®ions) { 900 // We always branch to the join region. 901 if (index.hasValue()) { 902 if (index.getValue() < 2) 903 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 904 else 905 regions.push_back(RegionSuccessor(getResults())); 906 return; 907 } 908 909 // The then and else regions are the entry regions of this op. 910 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 911 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 912 } 913 914 #include "TestOpEnums.cpp.inc" 915 #include "TestOpStructs.cpp.inc" 916 #include "TestTypeInterfaces.cpp.inc" 917 918 #define GET_OP_CLASSES 919 #include "TestOps.cpp.inc" 920