1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestTypes.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/IR/BuiltinOps.h" 13 #include "mlir/IR/DialectImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/TypeUtilities.h" 16 #include "mlir/Transforms/FoldUtils.h" 17 #include "mlir/Transforms/InliningUtils.h" 18 #include "llvm/ADT/SetVector.h" 19 #include "llvm/ADT/StringSwitch.h" 20 21 using namespace mlir; 22 using namespace mlir::test; 23 24 void mlir::test::registerTestDialect(DialectRegistry ®istry) { 25 registry.insert<TestDialect>(); 26 } 27 28 //===----------------------------------------------------------------------===// 29 // TestDialect Interfaces 30 //===----------------------------------------------------------------------===// 31 32 namespace { 33 34 // Test support for interacting with the AsmPrinter. 35 struct TestOpAsmInterface : public OpAsmDialectInterface { 36 using OpAsmDialectInterface::OpAsmDialectInterface; 37 38 LogicalResult getAlias(Attribute attr, raw_ostream &os) const final { 39 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 40 if (!strAttr) 41 return failure(); 42 43 // Check the contents of the string attribute to see what the test alias 44 // should be named. 45 Optional<StringRef> aliasName = 46 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 47 .Case("alias_test:dot_in_name", StringRef("test.alias")) 48 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 49 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 50 .Case("alias_test:sanitize_conflict_a", 51 StringRef("test_alias_conflict0")) 52 .Case("alias_test:sanitize_conflict_b", 53 StringRef("test_alias_conflict0_")) 54 .Default(llvm::None); 55 if (!aliasName) 56 return failure(); 57 58 os << *aliasName; 59 return success(); 60 } 61 62 void getAsmResultNames(Operation *op, 63 OpAsmSetValueNameFn setNameFn) const final { 64 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 65 setNameFn(asmOp, "result"); 66 } 67 68 void getAsmBlockArgumentNames(Block *block, 69 OpAsmSetValueNameFn setNameFn) const final { 70 auto op = block->getParentOp(); 71 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 72 if (!arrayAttr) 73 return; 74 auto args = block->getArguments(); 75 auto e = std::min(arrayAttr.size(), args.size()); 76 for (unsigned i = 0; i < e; ++i) { 77 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 78 setNameFn(args[i], strAttr.getValue()); 79 } 80 } 81 }; 82 83 struct TestDialectFoldInterface : public DialectFoldInterface { 84 using DialectFoldInterface::DialectFoldInterface; 85 86 /// Registered hook to check if the given region, which is attached to an 87 /// operation that is *not* isolated from above, should be used when 88 /// materializing constants. 89 bool shouldMaterializeInto(Region *region) const final { 90 // If this is a one region operation, then insert into it. 91 return isa<OneRegionOp>(region->getParentOp()); 92 } 93 }; 94 95 /// This class defines the interface for handling inlining with standard 96 /// operations. 97 struct TestInlinerInterface : public DialectInlinerInterface { 98 using DialectInlinerInterface::DialectInlinerInterface; 99 100 //===--------------------------------------------------------------------===// 101 // Analysis Hooks 102 //===--------------------------------------------------------------------===// 103 104 bool isLegalToInline(Operation *call, Operation *callable, 105 bool wouldBeCloned) const final { 106 // Don't allow inlining calls that are marked `noinline`. 107 return !call->hasAttr("noinline"); 108 } 109 bool isLegalToInline(Region *, Region *, bool, 110 BlockAndValueMapping &) const final { 111 // Inlining into test dialect regions is legal. 112 return true; 113 } 114 bool isLegalToInline(Operation *, Region *, bool, 115 BlockAndValueMapping &) const final { 116 return true; 117 } 118 119 bool shouldAnalyzeRecursively(Operation *op) const final { 120 // Analyze recursively if this is not a functional region operation, it 121 // froms a separate functional scope. 122 return !isa<FunctionalRegionOp>(op); 123 } 124 125 //===--------------------------------------------------------------------===// 126 // Transformation Hooks 127 //===--------------------------------------------------------------------===// 128 129 /// Handle the given inlined terminator by replacing it with a new operation 130 /// as necessary. 131 void handleTerminator(Operation *op, 132 ArrayRef<Value> valuesToRepl) const final { 133 // Only handle "test.return" here. 134 auto returnOp = dyn_cast<TestReturnOp>(op); 135 if (!returnOp) 136 return; 137 138 // Replace the values directly with the return operands. 139 assert(returnOp.getNumOperands() == valuesToRepl.size()); 140 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 141 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 142 } 143 144 /// Attempt to materialize a conversion for a type mismatch between a call 145 /// from this dialect, and a callable region. This method should generate an 146 /// operation that takes 'input' as the only operand, and produces a single 147 /// result of 'resultType'. If a conversion can not be generated, nullptr 148 /// should be returned. 149 Operation *materializeCallConversion(OpBuilder &builder, Value input, 150 Type resultType, 151 Location conversionLoc) const final { 152 // Only allow conversion for i16/i32 types. 153 if (!(resultType.isSignlessInteger(16) || 154 resultType.isSignlessInteger(32)) || 155 !(input.getType().isSignlessInteger(16) || 156 input.getType().isSignlessInteger(32))) 157 return nullptr; 158 return builder.create<TestCastOp>(conversionLoc, resultType, input); 159 } 160 }; 161 } // end anonymous namespace 162 163 //===----------------------------------------------------------------------===// 164 // TestDialect 165 //===----------------------------------------------------------------------===// 166 167 void TestDialect::initialize() { 168 addOperations< 169 #define GET_OP_LIST 170 #include "TestOps.cpp.inc" 171 >(); 172 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 173 TestInlinerInterface>(); 174 addTypes<TestType, TestRecursiveType, 175 #define GET_TYPEDEF_LIST 176 #include "TestTypeDefs.cpp.inc" 177 >(); 178 allowUnknownOperations(); 179 } 180 181 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 182 Type type, Location loc) { 183 return builder.create<TestOpConstant>(loc, type, value); 184 } 185 186 static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser, 187 llvm::SetVector<Type> &stack) { 188 StringRef typeTag; 189 if (failed(parser.parseKeyword(&typeTag))) 190 return Type(); 191 192 auto genType = generatedTypeParser(ctxt, parser, typeTag); 193 if (genType != Type()) 194 return genType; 195 196 if (typeTag == "test_type") 197 return TestType::get(parser.getBuilder().getContext()); 198 199 if (typeTag != "test_rec") 200 return Type(); 201 202 StringRef name; 203 if (parser.parseLess() || parser.parseKeyword(&name)) 204 return Type(); 205 auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name); 206 207 // If this type already has been parsed above in the stack, expect just the 208 // name. 209 if (stack.contains(rec)) { 210 if (failed(parser.parseGreater())) 211 return Type(); 212 return rec; 213 } 214 215 // Otherwise, parse the body and update the type. 216 if (failed(parser.parseComma())) 217 return Type(); 218 stack.insert(rec); 219 Type subtype = parseTestType(ctxt, parser, stack); 220 stack.pop_back(); 221 if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype))) 222 return Type(); 223 224 return rec; 225 } 226 227 Type TestDialect::parseType(DialectAsmParser &parser) const { 228 llvm::SetVector<Type> stack; 229 return parseTestType(getContext(), parser, stack); 230 } 231 232 static void printTestType(Type type, DialectAsmPrinter &printer, 233 llvm::SetVector<Type> &stack) { 234 if (succeeded(generatedTypePrinter(type, printer))) 235 return; 236 if (type.isa<TestType>()) { 237 printer << "test_type"; 238 return; 239 } 240 241 auto rec = type.cast<TestRecursiveType>(); 242 printer << "test_rec<" << rec.getName(); 243 if (!stack.contains(rec)) { 244 printer << ", "; 245 stack.insert(rec); 246 printTestType(rec.getBody(), printer, stack); 247 stack.pop_back(); 248 } 249 printer << ">"; 250 } 251 252 void TestDialect::printType(Type type, DialectAsmPrinter &printer) const { 253 llvm::SetVector<Type> stack; 254 printTestType(type, printer, stack); 255 } 256 257 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 258 NamedAttribute namedAttr) { 259 if (namedAttr.first == "test.invalid_attr") 260 return op->emitError() << "invalid to use 'test.invalid_attr'"; 261 return success(); 262 } 263 264 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 265 unsigned regionIndex, 266 unsigned argIndex, 267 NamedAttribute namedAttr) { 268 if (namedAttr.first == "test.invalid_attr") 269 return op->emitError() << "invalid to use 'test.invalid_attr'"; 270 return success(); 271 } 272 273 LogicalResult 274 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 275 unsigned resultIndex, 276 NamedAttribute namedAttr) { 277 if (namedAttr.first == "test.invalid_attr") 278 return op->emitError() << "invalid to use 'test.invalid_attr'"; 279 return success(); 280 } 281 282 //===----------------------------------------------------------------------===// 283 // TestBranchOp 284 //===----------------------------------------------------------------------===// 285 286 Optional<MutableOperandRange> 287 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 288 assert(index == 0 && "invalid successor index"); 289 return targetOperandsMutable(); 290 } 291 292 //===----------------------------------------------------------------------===// 293 // TestFoldToCallOp 294 //===----------------------------------------------------------------------===// 295 296 namespace { 297 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 298 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 299 300 LogicalResult matchAndRewrite(FoldToCallOp op, 301 PatternRewriter &rewriter) const override { 302 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(), 303 ValueRange()); 304 return success(); 305 } 306 }; 307 } // end anonymous namespace 308 309 void FoldToCallOp::getCanonicalizationPatterns( 310 OwningRewritePatternList &results, MLIRContext *context) { 311 results.insert<FoldToCallOpPattern>(context); 312 } 313 314 //===----------------------------------------------------------------------===// 315 // Test Format* operations 316 //===----------------------------------------------------------------------===// 317 318 //===----------------------------------------------------------------------===// 319 // Parsing 320 321 static ParseResult parseCustomDirectiveOperands( 322 OpAsmParser &parser, OpAsmParser::OperandType &operand, 323 Optional<OpAsmParser::OperandType> &optOperand, 324 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 325 if (parser.parseOperand(operand)) 326 return failure(); 327 if (succeeded(parser.parseOptionalComma())) { 328 optOperand.emplace(); 329 if (parser.parseOperand(*optOperand)) 330 return failure(); 331 } 332 if (parser.parseArrow() || parser.parseLParen() || 333 parser.parseOperandList(varOperands) || parser.parseRParen()) 334 return failure(); 335 return success(); 336 } 337 static ParseResult 338 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 339 Type &optOperandType, 340 SmallVectorImpl<Type> &varOperandTypes) { 341 if (parser.parseColon()) 342 return failure(); 343 344 if (parser.parseType(operandType)) 345 return failure(); 346 if (succeeded(parser.parseOptionalComma())) { 347 if (parser.parseType(optOperandType)) 348 return failure(); 349 } 350 if (parser.parseArrow() || parser.parseLParen() || 351 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 352 return failure(); 353 return success(); 354 } 355 static ParseResult 356 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 357 Type optOperandType, 358 const SmallVectorImpl<Type> &varOperandTypes) { 359 if (parser.parseKeyword("type_refs_capture")) 360 return failure(); 361 362 Type operandType2, optOperandType2; 363 SmallVector<Type, 1> varOperandTypes2; 364 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 365 varOperandTypes2)) 366 return failure(); 367 368 if (operandType != operandType2 || optOperandType != optOperandType2 || 369 varOperandTypes != varOperandTypes2) 370 return failure(); 371 372 return success(); 373 } 374 static ParseResult parseCustomDirectiveOperandsAndTypes( 375 OpAsmParser &parser, OpAsmParser::OperandType &operand, 376 Optional<OpAsmParser::OperandType> &optOperand, 377 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 378 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 379 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 380 parseCustomDirectiveResults(parser, operandType, optOperandType, 381 varOperandTypes)) 382 return failure(); 383 return success(); 384 } 385 static ParseResult parseCustomDirectiveRegions( 386 OpAsmParser &parser, Region ®ion, 387 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 388 if (parser.parseRegion(region)) 389 return failure(); 390 if (failed(parser.parseOptionalComma())) 391 return success(); 392 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 393 if (parser.parseRegion(*varRegion)) 394 return failure(); 395 varRegions.emplace_back(std::move(varRegion)); 396 return success(); 397 } 398 static ParseResult 399 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 400 SmallVectorImpl<Block *> &varSuccessors) { 401 if (parser.parseSuccessor(successor)) 402 return failure(); 403 if (failed(parser.parseOptionalComma())) 404 return success(); 405 Block *varSuccessor; 406 if (parser.parseSuccessor(varSuccessor)) 407 return failure(); 408 varSuccessors.append(2, varSuccessor); 409 return success(); 410 } 411 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 412 IntegerAttr &attr, 413 IntegerAttr &optAttr) { 414 if (parser.parseAttribute(attr)) 415 return failure(); 416 if (succeeded(parser.parseOptionalComma())) { 417 if (parser.parseAttribute(optAttr)) 418 return failure(); 419 } 420 return success(); 421 } 422 423 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 424 NamedAttrList &attrs) { 425 return parser.parseOptionalAttrDict(attrs); 426 } 427 428 //===----------------------------------------------------------------------===// 429 // Printing 430 431 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 432 Value operand, Value optOperand, 433 OperandRange varOperands) { 434 printer << operand; 435 if (optOperand) 436 printer << ", " << optOperand; 437 printer << " -> (" << varOperands << ")"; 438 } 439 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 440 Type operandType, Type optOperandType, 441 TypeRange varOperandTypes) { 442 printer << " : " << operandType; 443 if (optOperandType) 444 printer << ", " << optOperandType; 445 printer << " -> (" << varOperandTypes << ")"; 446 } 447 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 448 Operation *op, Type operandType, 449 Type optOperandType, 450 TypeRange varOperandTypes) { 451 printer << " type_refs_capture "; 452 printCustomDirectiveResults(printer, op, operandType, optOperandType, 453 varOperandTypes); 454 } 455 static void printCustomDirectiveOperandsAndTypes( 456 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 457 OperandRange varOperands, Type operandType, Type optOperandType, 458 TypeRange varOperandTypes) { 459 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 460 printCustomDirectiveResults(printer, op, operandType, optOperandType, 461 varOperandTypes); 462 } 463 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 464 Region ®ion, 465 MutableArrayRef<Region> varRegions) { 466 printer.printRegion(region); 467 if (!varRegions.empty()) { 468 printer << ", "; 469 for (Region ®ion : varRegions) 470 printer.printRegion(region); 471 } 472 } 473 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 474 Block *successor, 475 SuccessorRange varSuccessors) { 476 printer << successor; 477 if (!varSuccessors.empty()) 478 printer << ", " << varSuccessors.front(); 479 } 480 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 481 Attribute attribute, 482 Attribute optAttribute) { 483 printer << attribute; 484 if (optAttribute) 485 printer << ", " << optAttribute; 486 } 487 488 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 489 MutableDictionaryAttr attrs) { 490 printer.printOptionalAttrDict(attrs.getAttrs()); 491 } 492 //===----------------------------------------------------------------------===// 493 // Test IsolatedRegionOp - parse passthrough region arguments. 494 //===----------------------------------------------------------------------===// 495 496 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 497 OperationState &result) { 498 OpAsmParser::OperandType argInfo; 499 Type argType = parser.getBuilder().getIndexType(); 500 501 // Parse the input operand. 502 if (parser.parseOperand(argInfo) || 503 parser.resolveOperand(argInfo, argType, result.operands)) 504 return failure(); 505 506 // Parse the body region, and reuse the operand info as the argument info. 507 Region *body = result.addRegion(); 508 return parser.parseRegion(*body, argInfo, argType, 509 /*enableNameShadowing=*/true); 510 } 511 512 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 513 p << "test.isolated_region "; 514 p.printOperand(op.getOperand()); 515 p.shadowRegionArgs(op.region(), op.getOperand()); 516 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 517 } 518 519 //===----------------------------------------------------------------------===// 520 // Test SSACFGRegionOp 521 //===----------------------------------------------------------------------===// 522 523 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 524 return RegionKind::SSACFG; 525 } 526 527 //===----------------------------------------------------------------------===// 528 // Test GraphRegionOp 529 //===----------------------------------------------------------------------===// 530 531 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 532 OperationState &result) { 533 // Parse the body region, and reuse the operand info as the argument info. 534 Region *body = result.addRegion(); 535 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 536 } 537 538 static void print(OpAsmPrinter &p, GraphRegionOp op) { 539 p << "test.graph_region "; 540 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 541 } 542 543 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 544 return RegionKind::Graph; 545 } 546 547 //===----------------------------------------------------------------------===// 548 // Test AffineScopeOp 549 //===----------------------------------------------------------------------===// 550 551 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 552 OperationState &result) { 553 // Parse the body region, and reuse the operand info as the argument info. 554 Region *body = result.addRegion(); 555 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 556 } 557 558 static void print(OpAsmPrinter &p, AffineScopeOp op) { 559 p << "test.affine_scope "; 560 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 561 } 562 563 //===----------------------------------------------------------------------===// 564 // Test parser. 565 //===----------------------------------------------------------------------===// 566 567 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, 568 OperationState &result) { 569 if (parser.parseOptionalColon()) 570 return success(); 571 uint64_t numResults; 572 if (parser.parseInteger(numResults)) 573 return failure(); 574 575 IndexType type = parser.getBuilder().getIndexType(); 576 for (unsigned i = 0; i < numResults; ++i) 577 result.addTypes(type); 578 return success(); 579 } 580 581 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { 582 p << ParseIntegerLiteralOp::getOperationName(); 583 if (unsigned numResults = op->getNumResults()) 584 p << " : " << numResults; 585 } 586 587 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, 588 OperationState &result) { 589 StringRef keyword; 590 if (parser.parseKeyword(&keyword)) 591 return failure(); 592 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 593 return success(); 594 } 595 596 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { 597 p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword(); 598 } 599 600 //===----------------------------------------------------------------------===// 601 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 602 603 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 604 OperationState &result) { 605 if (parser.parseKeyword("wraps")) 606 return failure(); 607 608 // Parse the wrapped op in a region 609 Region &body = *result.addRegion(); 610 body.push_back(new Block); 611 Block &block = body.back(); 612 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 613 if (!wrapped_op) 614 return failure(); 615 616 // Create a return terminator in the inner region, pass as operand to the 617 // terminator the returned values from the wrapped operation. 618 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 619 OpBuilder builder(parser.getBuilder().getContext()); 620 builder.setInsertionPointToEnd(&block); 621 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 622 623 // Get the results type for the wrapping op from the terminator operands. 624 Operation &return_op = body.back().back(); 625 result.types.append(return_op.operand_type_begin(), 626 return_op.operand_type_end()); 627 628 // Use the location of the wrapped op for the "test.wrapping_region" op. 629 result.location = wrapped_op->getLoc(); 630 631 return success(); 632 } 633 634 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 635 p << op.getOperationName() << " wraps "; 636 p.printGenericOp(&op.region().front().front()); 637 } 638 639 //===----------------------------------------------------------------------===// 640 // Test PolyForOp - parse list of region arguments. 641 //===----------------------------------------------------------------------===// 642 643 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 644 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 645 // Parse list of region arguments without a delimiter. 646 if (parser.parseRegionArgumentList(ivsInfo)) 647 return failure(); 648 649 // Parse the body region. 650 Region *body = result.addRegion(); 651 auto &builder = parser.getBuilder(); 652 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 653 return parser.parseRegion(*body, ivsInfo, argTypes); 654 } 655 656 //===----------------------------------------------------------------------===// 657 // Test removing op with inner ops. 658 //===----------------------------------------------------------------------===// 659 660 namespace { 661 struct TestRemoveOpWithInnerOps 662 : public OpRewritePattern<TestOpWithRegionPattern> { 663 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 664 665 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 666 PatternRewriter &rewriter) const override { 667 rewriter.eraseOp(op); 668 return success(); 669 } 670 }; 671 } // end anonymous namespace 672 673 void TestOpWithRegionPattern::getCanonicalizationPatterns( 674 OwningRewritePatternList &results, MLIRContext *context) { 675 results.insert<TestRemoveOpWithInnerOps>(context); 676 } 677 678 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 679 return operand(); 680 } 681 682 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 683 return getValue(); 684 } 685 686 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 687 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 688 for (Value input : this->operands()) { 689 results.push_back(input); 690 } 691 return success(); 692 } 693 694 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 695 assert(operands.size() == 1); 696 if (operands.front()) { 697 (*this)->setAttr("attr", operands.front()); 698 return getResult(); 699 } 700 return {}; 701 } 702 703 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 704 MLIRContext *, Optional<Location> location, ValueRange operands, 705 DictionaryAttr attributes, RegionRange regions, 706 SmallVectorImpl<Type> &inferredReturnTypes) { 707 if (operands[0].getType() != operands[1].getType()) { 708 return emitOptionalError(location, "operand type mismatch ", 709 operands[0].getType(), " vs ", 710 operands[1].getType()); 711 } 712 inferredReturnTypes.assign({operands[0].getType()}); 713 return success(); 714 } 715 716 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 717 MLIRContext *context, Optional<Location> location, ValueRange operands, 718 DictionaryAttr attributes, RegionRange regions, 719 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 720 // Create return type consisting of the last element of the first operand. 721 auto operandType = *operands.getTypes().begin(); 722 auto sval = operandType.dyn_cast<ShapedType>(); 723 if (!sval) { 724 return emitOptionalError(location, "only shaped type operands allowed"); 725 } 726 int64_t dim = 727 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 728 auto type = IntegerType::get(17, context); 729 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 730 return success(); 731 } 732 733 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 734 OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) { 735 shapes = SmallVector<Value, 1>{ 736 builder.createOrFold<DimOp>(getLoc(), getOperand(0), 0)}; 737 return success(); 738 } 739 740 //===----------------------------------------------------------------------===// 741 // Test SideEffect interfaces 742 //===----------------------------------------------------------------------===// 743 744 namespace { 745 /// A test resource for side effects. 746 struct TestResource : public SideEffects::Resource::Base<TestResource> { 747 StringRef getName() final { return "<Test>"; } 748 }; 749 } // end anonymous namespace 750 751 void SideEffectOp::getEffects( 752 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 753 // Check for an effects attribute on the op instance. 754 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 755 if (!effectsAttr) 756 return; 757 758 // If there is one, it is an array of dictionary attributes that hold 759 // information on the effects of this operation. 760 for (Attribute element : effectsAttr) { 761 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 762 763 // Get the specific memory effect. 764 MemoryEffects::Effect *effect = 765 StringSwitch<MemoryEffects::Effect *>( 766 effectElement.get("effect").cast<StringAttr>().getValue()) 767 .Case("allocate", MemoryEffects::Allocate::get()) 768 .Case("free", MemoryEffects::Free::get()) 769 .Case("read", MemoryEffects::Read::get()) 770 .Case("write", MemoryEffects::Write::get()); 771 772 // Check for a non-default resource to use. 773 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 774 if (effectElement.get("test_resource")) 775 resource = TestResource::get(); 776 777 // Check for a result to affect. 778 if (effectElement.get("on_result")) 779 effects.emplace_back(effect, getResult(), resource); 780 else if (Attribute ref = effectElement.get("on_reference")) 781 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 782 else 783 effects.emplace_back(effect, resource); 784 } 785 } 786 787 void SideEffectOp::getEffects( 788 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 789 auto effectsAttr = (*this)->getAttrOfType<AffineMapAttr>("effect_parameter"); 790 if (!effectsAttr) 791 return; 792 793 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 794 } 795 796 //===----------------------------------------------------------------------===// 797 // StringAttrPrettyNameOp 798 //===----------------------------------------------------------------------===// 799 800 // This op has fancy handling of its SSA result name. 801 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 802 OperationState &result) { 803 // Add the result types. 804 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 805 result.addTypes(parser.getBuilder().getIntegerType(32)); 806 807 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 808 return failure(); 809 810 // If the attribute dictionary contains no 'names' attribute, infer it from 811 // the SSA name (if specified). 812 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 813 return attr.first == "names"; 814 }); 815 816 // If there was no name specified, check to see if there was a useful name 817 // specified in the asm file. 818 if (hadNames || parser.getNumResults() == 0) 819 return success(); 820 821 SmallVector<StringRef, 4> names; 822 auto *context = result.getContext(); 823 824 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 825 auto resultName = parser.getResultName(i); 826 StringRef nameStr; 827 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 828 nameStr = resultName.first; 829 830 names.push_back(nameStr); 831 } 832 833 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 834 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 835 return success(); 836 } 837 838 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 839 p << "test.string_attr_pretty_name"; 840 841 // Note that we only need to print the "name" attribute if the asmprinter 842 // result name disagrees with it. This can happen in strange cases, e.g. 843 // when there are conflicts. 844 bool namesDisagree = op.names().size() != op.getNumResults(); 845 846 SmallString<32> resultNameStr; 847 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 848 resultNameStr.clear(); 849 llvm::raw_svector_ostream tmpStream(resultNameStr); 850 p.printOperand(op.getResult(i), tmpStream); 851 852 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 853 if (!expectedName || 854 tmpStream.str().drop_front() != expectedName.getValue()) { 855 namesDisagree = true; 856 } 857 } 858 859 if (namesDisagree) 860 p.printOptionalAttrDictWithKeyword(op.getAttrs()); 861 else 862 p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"}); 863 } 864 865 // We set the SSA name in the asm syntax to the contents of the name 866 // attribute. 867 void StringAttrPrettyNameOp::getAsmResultNames( 868 function_ref<void(Value, StringRef)> setNameFn) { 869 870 auto value = names(); 871 for (size_t i = 0, e = value.size(); i != e; ++i) 872 if (auto str = value[i].dyn_cast<StringAttr>()) 873 if (!str.getValue().empty()) 874 setNameFn(getResult(i), str.getValue()); 875 } 876 877 //===----------------------------------------------------------------------===// 878 // RegionIfOp 879 //===----------------------------------------------------------------------===// 880 881 static void print(OpAsmPrinter &p, RegionIfOp op) { 882 p << RegionIfOp::getOperationName() << " "; 883 p.printOperands(op.getOperands()); 884 p << ": " << op.getOperandTypes(); 885 p.printArrowTypeList(op.getResultTypes()); 886 p << " then"; 887 p.printRegion(op.thenRegion(), 888 /*printEntryBlockArgs=*/true, 889 /*printBlockTerminators=*/true); 890 p << " else"; 891 p.printRegion(op.elseRegion(), 892 /*printEntryBlockArgs=*/true, 893 /*printBlockTerminators=*/true); 894 p << " join"; 895 p.printRegion(op.joinRegion(), 896 /*printEntryBlockArgs=*/true, 897 /*printBlockTerminators=*/true); 898 } 899 900 static ParseResult parseRegionIfOp(OpAsmParser &parser, 901 OperationState &result) { 902 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 903 SmallVector<Type, 2> operandTypes; 904 905 result.regions.reserve(3); 906 Region *thenRegion = result.addRegion(); 907 Region *elseRegion = result.addRegion(); 908 Region *joinRegion = result.addRegion(); 909 910 // Parse operand, type and arrow type lists. 911 if (parser.parseOperandList(operandInfos) || 912 parser.parseColonTypeList(operandTypes) || 913 parser.parseArrowTypeList(result.types)) 914 return failure(); 915 916 // Parse all attached regions. 917 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 918 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 919 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 920 return failure(); 921 922 return parser.resolveOperands(operandInfos, operandTypes, 923 parser.getCurrentLocation(), result.operands); 924 } 925 926 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 927 assert(index < 2 && "invalid region index"); 928 return getOperands(); 929 } 930 931 void RegionIfOp::getSuccessorRegions( 932 Optional<unsigned> index, ArrayRef<Attribute> operands, 933 SmallVectorImpl<RegionSuccessor> ®ions) { 934 // We always branch to the join region. 935 if (index.hasValue()) { 936 if (index.getValue() < 2) 937 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 938 else 939 regions.push_back(RegionSuccessor(getResults())); 940 return; 941 } 942 943 // The then and else regions are the entry regions of this op. 944 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 945 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 946 } 947 948 #include "TestOpEnums.cpp.inc" 949 #include "TestOpInterfaces.cpp.inc" 950 #include "TestOpStructs.cpp.inc" 951 #include "TestTypeInterfaces.cpp.inc" 952 953 #define GET_OP_CLASSES 954 #include "TestOps.cpp.inc" 955