1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestTypes.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/IR/DialectImplementation.h" 13 #include "mlir/IR/Function.h" 14 #include "mlir/IR/Module.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/IR/TypeUtilities.h" 17 #include "mlir/Transforms/FoldUtils.h" 18 #include "mlir/Transforms/InliningUtils.h" 19 #include "llvm/ADT/SetVector.h" 20 #include "llvm/ADT/StringSwitch.h" 21 22 using namespace mlir; 23 24 void mlir::registerTestDialect(DialectRegistry ®istry) { 25 registry.insert<TestDialect>(); 26 } 27 28 //===----------------------------------------------------------------------===// 29 // TestDialect Interfaces 30 //===----------------------------------------------------------------------===// 31 32 namespace { 33 34 // Test support for interacting with the AsmPrinter. 35 struct TestOpAsmInterface : public OpAsmDialectInterface { 36 using OpAsmDialectInterface::OpAsmDialectInterface; 37 38 void getAsmResultNames(Operation *op, 39 OpAsmSetValueNameFn setNameFn) const final { 40 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 41 setNameFn(asmOp, "result"); 42 } 43 44 void getAsmBlockArgumentNames(Block *block, 45 OpAsmSetValueNameFn setNameFn) const final { 46 auto op = block->getParentOp(); 47 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 48 if (!arrayAttr) 49 return; 50 auto args = block->getArguments(); 51 auto e = std::min(arrayAttr.size(), args.size()); 52 for (unsigned i = 0; i < e; ++i) { 53 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 54 setNameFn(args[i], strAttr.getValue()); 55 } 56 } 57 }; 58 59 struct TestDialectFoldInterface : public DialectFoldInterface { 60 using DialectFoldInterface::DialectFoldInterface; 61 62 /// Registered hook to check if the given region, which is attached to an 63 /// operation that is *not* isolated from above, should be used when 64 /// materializing constants. 65 bool shouldMaterializeInto(Region *region) const final { 66 // If this is a one region operation, then insert into it. 67 return isa<OneRegionOp>(region->getParentOp()); 68 } 69 }; 70 71 /// This class defines the interface for handling inlining with standard 72 /// operations. 73 struct TestInlinerInterface : public DialectInlinerInterface { 74 using DialectInlinerInterface::DialectInlinerInterface; 75 76 //===--------------------------------------------------------------------===// 77 // Analysis Hooks 78 //===--------------------------------------------------------------------===// 79 80 bool isLegalToInline(Operation *call, Operation *callable, 81 bool wouldBeCloned) const final { 82 // Don't allow inlining calls that are marked `noinline`. 83 return !call->hasAttr("noinline"); 84 } 85 bool isLegalToInline(Region *, Region *, bool, 86 BlockAndValueMapping &) const final { 87 // Inlining into test dialect regions is legal. 88 return true; 89 } 90 bool isLegalToInline(Operation *, Region *, bool, 91 BlockAndValueMapping &) const final { 92 return true; 93 } 94 95 bool shouldAnalyzeRecursively(Operation *op) const final { 96 // Analyze recursively if this is not a functional region operation, it 97 // froms a separate functional scope. 98 return !isa<FunctionalRegionOp>(op); 99 } 100 101 //===--------------------------------------------------------------------===// 102 // Transformation Hooks 103 //===--------------------------------------------------------------------===// 104 105 /// Handle the given inlined terminator by replacing it with a new operation 106 /// as necessary. 107 void handleTerminator(Operation *op, 108 ArrayRef<Value> valuesToRepl) const final { 109 // Only handle "test.return" here. 110 auto returnOp = dyn_cast<TestReturnOp>(op); 111 if (!returnOp) 112 return; 113 114 // Replace the values directly with the return operands. 115 assert(returnOp.getNumOperands() == valuesToRepl.size()); 116 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 117 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 118 } 119 120 /// Attempt to materialize a conversion for a type mismatch between a call 121 /// from this dialect, and a callable region. This method should generate an 122 /// operation that takes 'input' as the only operand, and produces a single 123 /// result of 'resultType'. If a conversion can not be generated, nullptr 124 /// should be returned. 125 Operation *materializeCallConversion(OpBuilder &builder, Value input, 126 Type resultType, 127 Location conversionLoc) const final { 128 // Only allow conversion for i16/i32 types. 129 if (!(resultType.isSignlessInteger(16) || 130 resultType.isSignlessInteger(32)) || 131 !(input.getType().isSignlessInteger(16) || 132 input.getType().isSignlessInteger(32))) 133 return nullptr; 134 return builder.create<TestCastOp>(conversionLoc, resultType, input); 135 } 136 }; 137 } // end anonymous namespace 138 139 //===----------------------------------------------------------------------===// 140 // TestDialect 141 //===----------------------------------------------------------------------===// 142 143 void TestDialect::initialize() { 144 addOperations< 145 #define GET_OP_LIST 146 #include "TestOps.cpp.inc" 147 >(); 148 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 149 TestInlinerInterface>(); 150 addTypes<TestType, TestRecursiveType, 151 #define GET_TYPEDEF_LIST 152 #include "TestTypeDefs.cpp.inc" 153 >(); 154 allowUnknownOperations(); 155 } 156 157 static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser, 158 llvm::SetVector<Type> &stack) { 159 StringRef typeTag; 160 if (failed(parser.parseKeyword(&typeTag))) 161 return Type(); 162 163 auto genType = generatedTypeParser(ctxt, parser, typeTag); 164 if (genType != Type()) 165 return genType; 166 167 if (typeTag == "test_type") 168 return TestType::get(parser.getBuilder().getContext()); 169 170 if (typeTag != "test_rec") 171 return Type(); 172 173 StringRef name; 174 if (parser.parseLess() || parser.parseKeyword(&name)) 175 return Type(); 176 auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name); 177 178 // If this type already has been parsed above in the stack, expect just the 179 // name. 180 if (stack.contains(rec)) { 181 if (failed(parser.parseGreater())) 182 return Type(); 183 return rec; 184 } 185 186 // Otherwise, parse the body and update the type. 187 if (failed(parser.parseComma())) 188 return Type(); 189 stack.insert(rec); 190 Type subtype = parseTestType(ctxt, parser, stack); 191 stack.pop_back(); 192 if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype))) 193 return Type(); 194 195 return rec; 196 } 197 198 Type TestDialect::parseType(DialectAsmParser &parser) const { 199 llvm::SetVector<Type> stack; 200 return parseTestType(getContext(), parser, stack); 201 } 202 203 static void printTestType(Type type, DialectAsmPrinter &printer, 204 llvm::SetVector<Type> &stack) { 205 if (succeeded(generatedTypePrinter(type, printer))) 206 return; 207 if (type.isa<TestType>()) { 208 printer << "test_type"; 209 return; 210 } 211 212 auto rec = type.cast<TestRecursiveType>(); 213 printer << "test_rec<" << rec.getName(); 214 if (!stack.contains(rec)) { 215 printer << ", "; 216 stack.insert(rec); 217 printTestType(rec.getBody(), printer, stack); 218 stack.pop_back(); 219 } 220 printer << ">"; 221 } 222 223 void TestDialect::printType(Type type, DialectAsmPrinter &printer) const { 224 llvm::SetVector<Type> stack; 225 printTestType(type, printer, stack); 226 } 227 228 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 229 NamedAttribute namedAttr) { 230 if (namedAttr.first == "test.invalid_attr") 231 return op->emitError() << "invalid to use 'test.invalid_attr'"; 232 return success(); 233 } 234 235 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 236 unsigned regionIndex, 237 unsigned argIndex, 238 NamedAttribute namedAttr) { 239 if (namedAttr.first == "test.invalid_attr") 240 return op->emitError() << "invalid to use 'test.invalid_attr'"; 241 return success(); 242 } 243 244 LogicalResult 245 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 246 unsigned resultIndex, 247 NamedAttribute namedAttr) { 248 if (namedAttr.first == "test.invalid_attr") 249 return op->emitError() << "invalid to use 'test.invalid_attr'"; 250 return success(); 251 } 252 253 //===----------------------------------------------------------------------===// 254 // TestBranchOp 255 //===----------------------------------------------------------------------===// 256 257 Optional<MutableOperandRange> 258 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 259 assert(index == 0 && "invalid successor index"); 260 return targetOperandsMutable(); 261 } 262 263 //===----------------------------------------------------------------------===// 264 // TestFoldToCallOp 265 //===----------------------------------------------------------------------===// 266 267 namespace { 268 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 269 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 270 271 LogicalResult matchAndRewrite(FoldToCallOp op, 272 PatternRewriter &rewriter) const override { 273 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(), 274 ValueRange()); 275 return success(); 276 } 277 }; 278 } // end anonymous namespace 279 280 void FoldToCallOp::getCanonicalizationPatterns( 281 OwningRewritePatternList &results, MLIRContext *context) { 282 results.insert<FoldToCallOpPattern>(context); 283 } 284 285 //===----------------------------------------------------------------------===// 286 // Test Format* operations 287 //===----------------------------------------------------------------------===// 288 289 //===----------------------------------------------------------------------===// 290 // Parsing 291 292 static ParseResult parseCustomDirectiveOperands( 293 OpAsmParser &parser, OpAsmParser::OperandType &operand, 294 Optional<OpAsmParser::OperandType> &optOperand, 295 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 296 if (parser.parseOperand(operand)) 297 return failure(); 298 if (succeeded(parser.parseOptionalComma())) { 299 optOperand.emplace(); 300 if (parser.parseOperand(*optOperand)) 301 return failure(); 302 } 303 if (parser.parseArrow() || parser.parseLParen() || 304 parser.parseOperandList(varOperands) || parser.parseRParen()) 305 return failure(); 306 return success(); 307 } 308 static ParseResult 309 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 310 Type &optOperandType, 311 SmallVectorImpl<Type> &varOperandTypes) { 312 if (parser.parseColon()) 313 return failure(); 314 315 if (parser.parseType(operandType)) 316 return failure(); 317 if (succeeded(parser.parseOptionalComma())) { 318 if (parser.parseType(optOperandType)) 319 return failure(); 320 } 321 if (parser.parseArrow() || parser.parseLParen() || 322 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 323 return failure(); 324 return success(); 325 } 326 static ParseResult 327 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 328 Type optOperandType, 329 const SmallVectorImpl<Type> &varOperandTypes) { 330 if (parser.parseKeyword("type_refs_capture")) 331 return failure(); 332 333 Type operandType2, optOperandType2; 334 SmallVector<Type, 1> varOperandTypes2; 335 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 336 varOperandTypes2)) 337 return failure(); 338 339 if (operandType != operandType2 || optOperandType != optOperandType2 || 340 varOperandTypes != varOperandTypes2) 341 return failure(); 342 343 return success(); 344 } 345 static ParseResult parseCustomDirectiveOperandsAndTypes( 346 OpAsmParser &parser, OpAsmParser::OperandType &operand, 347 Optional<OpAsmParser::OperandType> &optOperand, 348 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 349 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 350 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 351 parseCustomDirectiveResults(parser, operandType, optOperandType, 352 varOperandTypes)) 353 return failure(); 354 return success(); 355 } 356 static ParseResult parseCustomDirectiveRegions( 357 OpAsmParser &parser, Region ®ion, 358 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 359 if (parser.parseRegion(region)) 360 return failure(); 361 if (failed(parser.parseOptionalComma())) 362 return success(); 363 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 364 if (parser.parseRegion(*varRegion)) 365 return failure(); 366 varRegions.emplace_back(std::move(varRegion)); 367 return success(); 368 } 369 static ParseResult 370 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 371 SmallVectorImpl<Block *> &varSuccessors) { 372 if (parser.parseSuccessor(successor)) 373 return failure(); 374 if (failed(parser.parseOptionalComma())) 375 return success(); 376 Block *varSuccessor; 377 if (parser.parseSuccessor(varSuccessor)) 378 return failure(); 379 varSuccessors.append(2, varSuccessor); 380 return success(); 381 } 382 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 383 IntegerAttr &attr, 384 IntegerAttr &optAttr) { 385 if (parser.parseAttribute(attr)) 386 return failure(); 387 if (succeeded(parser.parseOptionalComma())) { 388 if (parser.parseAttribute(optAttr)) 389 return failure(); 390 } 391 return success(); 392 } 393 394 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 395 NamedAttrList &attrs) { 396 return parser.parseOptionalAttrDict(attrs); 397 } 398 399 //===----------------------------------------------------------------------===// 400 // Printing 401 402 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 403 Value operand, Value optOperand, 404 OperandRange varOperands) { 405 printer << operand; 406 if (optOperand) 407 printer << ", " << optOperand; 408 printer << " -> (" << varOperands << ")"; 409 } 410 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 411 Type operandType, Type optOperandType, 412 TypeRange varOperandTypes) { 413 printer << " : " << operandType; 414 if (optOperandType) 415 printer << ", " << optOperandType; 416 printer << " -> (" << varOperandTypes << ")"; 417 } 418 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 419 Operation *op, Type operandType, 420 Type optOperandType, 421 TypeRange varOperandTypes) { 422 printer << " type_refs_capture "; 423 printCustomDirectiveResults(printer, op, operandType, optOperandType, 424 varOperandTypes); 425 } 426 static void printCustomDirectiveOperandsAndTypes( 427 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 428 OperandRange varOperands, Type operandType, Type optOperandType, 429 TypeRange varOperandTypes) { 430 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 431 printCustomDirectiveResults(printer, op, operandType, optOperandType, 432 varOperandTypes); 433 } 434 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 435 Region ®ion, 436 MutableArrayRef<Region> varRegions) { 437 printer.printRegion(region); 438 if (!varRegions.empty()) { 439 printer << ", "; 440 for (Region ®ion : varRegions) 441 printer.printRegion(region); 442 } 443 } 444 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 445 Block *successor, 446 SuccessorRange varSuccessors) { 447 printer << successor; 448 if (!varSuccessors.empty()) 449 printer << ", " << varSuccessors.front(); 450 } 451 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 452 Attribute attribute, 453 Attribute optAttribute) { 454 printer << attribute; 455 if (optAttribute) 456 printer << ", " << optAttribute; 457 } 458 459 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 460 MutableDictionaryAttr attrs) { 461 printer.printOptionalAttrDict(attrs.getAttrs()); 462 } 463 //===----------------------------------------------------------------------===// 464 // Test IsolatedRegionOp - parse passthrough region arguments. 465 //===----------------------------------------------------------------------===// 466 467 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 468 OperationState &result) { 469 OpAsmParser::OperandType argInfo; 470 Type argType = parser.getBuilder().getIndexType(); 471 472 // Parse the input operand. 473 if (parser.parseOperand(argInfo) || 474 parser.resolveOperand(argInfo, argType, result.operands)) 475 return failure(); 476 477 // Parse the body region, and reuse the operand info as the argument info. 478 Region *body = result.addRegion(); 479 return parser.parseRegion(*body, argInfo, argType, 480 /*enableNameShadowing=*/true); 481 } 482 483 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 484 p << "test.isolated_region "; 485 p.printOperand(op.getOperand()); 486 p.shadowRegionArgs(op.region(), op.getOperand()); 487 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 488 } 489 490 //===----------------------------------------------------------------------===// 491 // Test SSACFGRegionOp 492 //===----------------------------------------------------------------------===// 493 494 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 495 return RegionKind::SSACFG; 496 } 497 498 //===----------------------------------------------------------------------===// 499 // Test GraphRegionOp 500 //===----------------------------------------------------------------------===// 501 502 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 503 OperationState &result) { 504 // Parse the body region, and reuse the operand info as the argument info. 505 Region *body = result.addRegion(); 506 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 507 } 508 509 static void print(OpAsmPrinter &p, GraphRegionOp op) { 510 p << "test.graph_region "; 511 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 512 } 513 514 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 515 return RegionKind::Graph; 516 } 517 518 //===----------------------------------------------------------------------===// 519 // Test AffineScopeOp 520 //===----------------------------------------------------------------------===// 521 522 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 523 OperationState &result) { 524 // Parse the body region, and reuse the operand info as the argument info. 525 Region *body = result.addRegion(); 526 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 527 } 528 529 static void print(OpAsmPrinter &p, AffineScopeOp op) { 530 p << "test.affine_scope "; 531 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 532 } 533 534 //===----------------------------------------------------------------------===// 535 // Test parser. 536 //===----------------------------------------------------------------------===// 537 538 static ParseResult parseWrappedKeywordOp(OpAsmParser &parser, 539 OperationState &result) { 540 StringRef keyword; 541 if (parser.parseKeyword(&keyword)) 542 return failure(); 543 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 544 return success(); 545 } 546 547 static void print(OpAsmPrinter &p, WrappedKeywordOp op) { 548 p << WrappedKeywordOp::getOperationName() << " " << op.keyword(); 549 } 550 551 //===----------------------------------------------------------------------===// 552 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 553 554 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 555 OperationState &result) { 556 if (parser.parseKeyword("wraps")) 557 return failure(); 558 559 // Parse the wrapped op in a region 560 Region &body = *result.addRegion(); 561 body.push_back(new Block); 562 Block &block = body.back(); 563 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 564 if (!wrapped_op) 565 return failure(); 566 567 // Create a return terminator in the inner region, pass as operand to the 568 // terminator the returned values from the wrapped operation. 569 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 570 OpBuilder builder(parser.getBuilder().getContext()); 571 builder.setInsertionPointToEnd(&block); 572 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 573 574 // Get the results type for the wrapping op from the terminator operands. 575 Operation &return_op = body.back().back(); 576 result.types.append(return_op.operand_type_begin(), 577 return_op.operand_type_end()); 578 579 // Use the location of the wrapped op for the "test.wrapping_region" op. 580 result.location = wrapped_op->getLoc(); 581 582 return success(); 583 } 584 585 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 586 p << op.getOperationName() << " wraps "; 587 p.printGenericOp(&op.region().front().front()); 588 } 589 590 //===----------------------------------------------------------------------===// 591 // Test PolyForOp - parse list of region arguments. 592 //===----------------------------------------------------------------------===// 593 594 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 595 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 596 // Parse list of region arguments without a delimiter. 597 if (parser.parseRegionArgumentList(ivsInfo)) 598 return failure(); 599 600 // Parse the body region. 601 Region *body = result.addRegion(); 602 auto &builder = parser.getBuilder(); 603 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 604 return parser.parseRegion(*body, ivsInfo, argTypes); 605 } 606 607 //===----------------------------------------------------------------------===// 608 // Test removing op with inner ops. 609 //===----------------------------------------------------------------------===// 610 611 namespace { 612 struct TestRemoveOpWithInnerOps 613 : public OpRewritePattern<TestOpWithRegionPattern> { 614 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 615 616 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 617 PatternRewriter &rewriter) const override { 618 rewriter.eraseOp(op); 619 return success(); 620 } 621 }; 622 } // end anonymous namespace 623 624 void TestOpWithRegionPattern::getCanonicalizationPatterns( 625 OwningRewritePatternList &results, MLIRContext *context) { 626 results.insert<TestRemoveOpWithInnerOps>(context); 627 } 628 629 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 630 return operand(); 631 } 632 633 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 634 return getValue(); 635 } 636 637 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 638 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 639 for (Value input : this->operands()) { 640 results.push_back(input); 641 } 642 return success(); 643 } 644 645 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 646 assert(operands.size() == 1); 647 if (operands.front()) { 648 setAttr("attr", operands.front()); 649 return getResult(); 650 } 651 return {}; 652 } 653 654 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 655 MLIRContext *, Optional<Location> location, ValueRange operands, 656 DictionaryAttr attributes, RegionRange regions, 657 SmallVectorImpl<Type> &inferredReturnTypes) { 658 if (operands[0].getType() != operands[1].getType()) { 659 return emitOptionalError(location, "operand type mismatch ", 660 operands[0].getType(), " vs ", 661 operands[1].getType()); 662 } 663 inferredReturnTypes.assign({operands[0].getType()}); 664 return success(); 665 } 666 667 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 668 MLIRContext *context, Optional<Location> location, ValueRange operands, 669 DictionaryAttr attributes, RegionRange regions, 670 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 671 // Create return type consisting of the last element of the first operand. 672 auto operandType = *operands.getTypes().begin(); 673 auto sval = operandType.dyn_cast<ShapedType>(); 674 if (!sval) { 675 return emitOptionalError(location, "only shaped type operands allowed"); 676 } 677 int64_t dim = 678 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 679 auto type = IntegerType::get(17, context); 680 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 681 return success(); 682 } 683 684 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 685 OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) { 686 shapes = SmallVector<Value, 1>{ 687 builder.createOrFold<DimOp>(getLoc(), getOperand(0), 0)}; 688 return success(); 689 } 690 691 //===----------------------------------------------------------------------===// 692 // Test SideEffect interfaces 693 //===----------------------------------------------------------------------===// 694 695 namespace { 696 /// A test resource for side effects. 697 struct TestResource : public SideEffects::Resource::Base<TestResource> { 698 StringRef getName() final { return "<Test>"; } 699 }; 700 } // end anonymous namespace 701 702 void SideEffectOp::getEffects( 703 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 704 // Check for an effects attribute on the op instance. 705 ArrayAttr effectsAttr = getAttrOfType<ArrayAttr>("effects"); 706 if (!effectsAttr) 707 return; 708 709 // If there is one, it is an array of dictionary attributes that hold 710 // information on the effects of this operation. 711 for (Attribute element : effectsAttr) { 712 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 713 714 // Get the specific memory effect. 715 MemoryEffects::Effect *effect = 716 StringSwitch<MemoryEffects::Effect *>( 717 effectElement.get("effect").cast<StringAttr>().getValue()) 718 .Case("allocate", MemoryEffects::Allocate::get()) 719 .Case("free", MemoryEffects::Free::get()) 720 .Case("read", MemoryEffects::Read::get()) 721 .Case("write", MemoryEffects::Write::get()); 722 723 // Check for a result to affect. 724 Value value; 725 if (effectElement.get("on_result")) 726 value = getResult(); 727 728 // Check for a non-default resource to use. 729 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 730 if (effectElement.get("test_resource")) 731 resource = TestResource::get(); 732 733 effects.emplace_back(effect, value, resource); 734 } 735 } 736 737 //===----------------------------------------------------------------------===// 738 // StringAttrPrettyNameOp 739 //===----------------------------------------------------------------------===// 740 741 // This op has fancy handling of its SSA result name. 742 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 743 OperationState &result) { 744 // Add the result types. 745 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 746 result.addTypes(parser.getBuilder().getIntegerType(32)); 747 748 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 749 return failure(); 750 751 // If the attribute dictionary contains no 'names' attribute, infer it from 752 // the SSA name (if specified). 753 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 754 return attr.first == "names"; 755 }); 756 757 // If there was no name specified, check to see if there was a useful name 758 // specified in the asm file. 759 if (hadNames || parser.getNumResults() == 0) 760 return success(); 761 762 SmallVector<StringRef, 4> names; 763 auto *context = result.getContext(); 764 765 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 766 auto resultName = parser.getResultName(i); 767 StringRef nameStr; 768 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 769 nameStr = resultName.first; 770 771 names.push_back(nameStr); 772 } 773 774 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 775 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 776 return success(); 777 } 778 779 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 780 p << "test.string_attr_pretty_name"; 781 782 // Note that we only need to print the "name" attribute if the asmprinter 783 // result name disagrees with it. This can happen in strange cases, e.g. 784 // when there are conflicts. 785 bool namesDisagree = op.names().size() != op.getNumResults(); 786 787 SmallString<32> resultNameStr; 788 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 789 resultNameStr.clear(); 790 llvm::raw_svector_ostream tmpStream(resultNameStr); 791 p.printOperand(op.getResult(i), tmpStream); 792 793 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 794 if (!expectedName || 795 tmpStream.str().drop_front() != expectedName.getValue()) { 796 namesDisagree = true; 797 } 798 } 799 800 if (namesDisagree) 801 p.printOptionalAttrDictWithKeyword(op.getAttrs()); 802 else 803 p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"}); 804 } 805 806 // We set the SSA name in the asm syntax to the contents of the name 807 // attribute. 808 void StringAttrPrettyNameOp::getAsmResultNames( 809 function_ref<void(Value, StringRef)> setNameFn) { 810 811 auto value = names(); 812 for (size_t i = 0, e = value.size(); i != e; ++i) 813 if (auto str = value[i].dyn_cast<StringAttr>()) 814 if (!str.getValue().empty()) 815 setNameFn(getResult(i), str.getValue()); 816 } 817 818 //===----------------------------------------------------------------------===// 819 // RegionIfOp 820 //===----------------------------------------------------------------------===// 821 822 static void print(OpAsmPrinter &p, RegionIfOp op) { 823 p << RegionIfOp::getOperationName() << " "; 824 p.printOperands(op.getOperands()); 825 p << ": " << op.getOperandTypes(); 826 p.printArrowTypeList(op.getResultTypes()); 827 p << " then"; 828 p.printRegion(op.thenRegion(), 829 /*printEntryBlockArgs=*/true, 830 /*printBlockTerminators=*/true); 831 p << " else"; 832 p.printRegion(op.elseRegion(), 833 /*printEntryBlockArgs=*/true, 834 /*printBlockTerminators=*/true); 835 p << " join"; 836 p.printRegion(op.joinRegion(), 837 /*printEntryBlockArgs=*/true, 838 /*printBlockTerminators=*/true); 839 } 840 841 static ParseResult parseRegionIfOp(OpAsmParser &parser, 842 OperationState &result) { 843 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 844 SmallVector<Type, 2> operandTypes; 845 846 result.regions.reserve(3); 847 Region *thenRegion = result.addRegion(); 848 Region *elseRegion = result.addRegion(); 849 Region *joinRegion = result.addRegion(); 850 851 // Parse operand, type and arrow type lists. 852 if (parser.parseOperandList(operandInfos) || 853 parser.parseColonTypeList(operandTypes) || 854 parser.parseArrowTypeList(result.types)) 855 return failure(); 856 857 // Parse all attached regions. 858 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 859 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 860 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 861 return failure(); 862 863 return parser.resolveOperands(operandInfos, operandTypes, 864 parser.getCurrentLocation(), result.operands); 865 } 866 867 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 868 assert(index < 2 && "invalid region index"); 869 return getOperands(); 870 } 871 872 void RegionIfOp::getSuccessorRegions( 873 Optional<unsigned> index, ArrayRef<Attribute> operands, 874 SmallVectorImpl<RegionSuccessor> ®ions) { 875 // We always branch to the join region. 876 if (index.hasValue()) { 877 if (index.getValue() < 2) 878 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 879 else 880 regions.push_back(RegionSuccessor(getResults())); 881 return; 882 } 883 884 // The then and else regions are the entry regions of this op. 885 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 886 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 887 } 888 889 #include "TestOpEnums.cpp.inc" 890 #include "TestOpStructs.cpp.inc" 891 #include "TestTypeInterfaces.cpp.inc" 892 893 #define GET_OP_CLASSES 894 #include "TestOps.cpp.inc" 895