1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestTypes.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/IR/DialectImplementation.h" 13 #include "mlir/IR/Function.h" 14 #include "mlir/IR/Module.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/IR/TypeUtilities.h" 17 #include "mlir/Transforms/FoldUtils.h" 18 #include "mlir/Transforms/InliningUtils.h" 19 #include "llvm/ADT/SetVector.h" 20 #include "llvm/ADT/StringSwitch.h" 21 22 using namespace mlir; 23 24 void mlir::registerTestDialect(DialectRegistry ®istry) { 25 registry.insert<TestDialect>(); 26 } 27 28 //===----------------------------------------------------------------------===// 29 // TestDialect Interfaces 30 //===----------------------------------------------------------------------===// 31 32 namespace { 33 34 // Test support for interacting with the AsmPrinter. 35 struct TestOpAsmInterface : public OpAsmDialectInterface { 36 using OpAsmDialectInterface::OpAsmDialectInterface; 37 38 void getAsmResultNames(Operation *op, 39 OpAsmSetValueNameFn setNameFn) const final { 40 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 41 setNameFn(asmOp, "result"); 42 } 43 44 void getAsmBlockArgumentNames(Block *block, 45 OpAsmSetValueNameFn setNameFn) const final { 46 auto op = block->getParentOp(); 47 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 48 if (!arrayAttr) 49 return; 50 auto args = block->getArguments(); 51 auto e = std::min(arrayAttr.size(), args.size()); 52 for (unsigned i = 0; i < e; ++i) { 53 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 54 setNameFn(args[i], strAttr.getValue()); 55 } 56 } 57 }; 58 59 struct TestDialectFoldInterface : public DialectFoldInterface { 60 using DialectFoldInterface::DialectFoldInterface; 61 62 /// Registered hook to check if the given region, which is attached to an 63 /// operation that is *not* isolated from above, should be used when 64 /// materializing constants. 65 bool shouldMaterializeInto(Region *region) const final { 66 // If this is a one region operation, then insert into it. 67 return isa<OneRegionOp>(region->getParentOp()); 68 } 69 }; 70 71 /// This class defines the interface for handling inlining with standard 72 /// operations. 73 struct TestInlinerInterface : public DialectInlinerInterface { 74 using DialectInlinerInterface::DialectInlinerInterface; 75 76 //===--------------------------------------------------------------------===// 77 // Analysis Hooks 78 //===--------------------------------------------------------------------===// 79 80 bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final { 81 // Inlining into test dialect regions is legal. 82 return true; 83 } 84 bool isLegalToInline(Operation *, Region *, 85 BlockAndValueMapping &) const final { 86 return true; 87 } 88 89 bool shouldAnalyzeRecursively(Operation *op) const final { 90 // Analyze recursively if this is not a functional region operation, it 91 // froms a separate functional scope. 92 return !isa<FunctionalRegionOp>(op); 93 } 94 95 //===--------------------------------------------------------------------===// 96 // Transformation Hooks 97 //===--------------------------------------------------------------------===// 98 99 /// Handle the given inlined terminator by replacing it with a new operation 100 /// as necessary. 101 void handleTerminator(Operation *op, 102 ArrayRef<Value> valuesToRepl) const final { 103 // Only handle "test.return" here. 104 auto returnOp = dyn_cast<TestReturnOp>(op); 105 if (!returnOp) 106 return; 107 108 // Replace the values directly with the return operands. 109 assert(returnOp.getNumOperands() == valuesToRepl.size()); 110 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 111 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 112 } 113 114 /// Attempt to materialize a conversion for a type mismatch between a call 115 /// from this dialect, and a callable region. This method should generate an 116 /// operation that takes 'input' as the only operand, and produces a single 117 /// result of 'resultType'. If a conversion can not be generated, nullptr 118 /// should be returned. 119 Operation *materializeCallConversion(OpBuilder &builder, Value input, 120 Type resultType, 121 Location conversionLoc) const final { 122 // Only allow conversion for i16/i32 types. 123 if (!(resultType.isSignlessInteger(16) || 124 resultType.isSignlessInteger(32)) || 125 !(input.getType().isSignlessInteger(16) || 126 input.getType().isSignlessInteger(32))) 127 return nullptr; 128 return builder.create<TestCastOp>(conversionLoc, resultType, input); 129 } 130 }; 131 } // end anonymous namespace 132 133 //===----------------------------------------------------------------------===// 134 // TestDialect 135 //===----------------------------------------------------------------------===// 136 137 void TestDialect::initialize() { 138 addOperations< 139 #define GET_OP_LIST 140 #include "TestOps.cpp.inc" 141 >(); 142 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 143 TestInlinerInterface>(); 144 addTypes<TestType, TestRecursiveType, 145 #define GET_TYPEDEF_LIST 146 #include "TestTypeDefs.cpp.inc" 147 >(); 148 allowUnknownOperations(); 149 } 150 151 static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser, 152 llvm::SetVector<Type> &stack) { 153 StringRef typeTag; 154 if (failed(parser.parseKeyword(&typeTag))) 155 return Type(); 156 157 auto genType = generatedTypeParser(ctxt, parser, typeTag); 158 if (genType != Type()) 159 return genType; 160 161 if (typeTag == "test_type") 162 return TestType::get(parser.getBuilder().getContext()); 163 164 if (typeTag != "test_rec") 165 return Type(); 166 167 StringRef name; 168 if (parser.parseLess() || parser.parseKeyword(&name)) 169 return Type(); 170 auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name); 171 172 // If this type already has been parsed above in the stack, expect just the 173 // name. 174 if (stack.contains(rec)) { 175 if (failed(parser.parseGreater())) 176 return Type(); 177 return rec; 178 } 179 180 // Otherwise, parse the body and update the type. 181 if (failed(parser.parseComma())) 182 return Type(); 183 stack.insert(rec); 184 Type subtype = parseTestType(ctxt, parser, stack); 185 stack.pop_back(); 186 if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype))) 187 return Type(); 188 189 return rec; 190 } 191 192 Type TestDialect::parseType(DialectAsmParser &parser) const { 193 llvm::SetVector<Type> stack; 194 return parseTestType(getContext(), parser, stack); 195 } 196 197 static void printTestType(Type type, DialectAsmPrinter &printer, 198 llvm::SetVector<Type> &stack) { 199 if (succeeded(generatedTypePrinter(type, printer))) 200 return; 201 if (type.isa<TestType>()) { 202 printer << "test_type"; 203 return; 204 } 205 206 auto rec = type.cast<TestRecursiveType>(); 207 printer << "test_rec<" << rec.getName(); 208 if (!stack.contains(rec)) { 209 printer << ", "; 210 stack.insert(rec); 211 printTestType(rec.getBody(), printer, stack); 212 stack.pop_back(); 213 } 214 printer << ">"; 215 } 216 217 void TestDialect::printType(Type type, DialectAsmPrinter &printer) const { 218 llvm::SetVector<Type> stack; 219 printTestType(type, printer, stack); 220 } 221 222 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 223 NamedAttribute namedAttr) { 224 if (namedAttr.first == "test.invalid_attr") 225 return op->emitError() << "invalid to use 'test.invalid_attr'"; 226 return success(); 227 } 228 229 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 230 unsigned regionIndex, 231 unsigned argIndex, 232 NamedAttribute namedAttr) { 233 if (namedAttr.first == "test.invalid_attr") 234 return op->emitError() << "invalid to use 'test.invalid_attr'"; 235 return success(); 236 } 237 238 LogicalResult 239 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 240 unsigned resultIndex, 241 NamedAttribute namedAttr) { 242 if (namedAttr.first == "test.invalid_attr") 243 return op->emitError() << "invalid to use 'test.invalid_attr'"; 244 return success(); 245 } 246 247 //===----------------------------------------------------------------------===// 248 // TestBranchOp 249 //===----------------------------------------------------------------------===// 250 251 Optional<MutableOperandRange> 252 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 253 assert(index == 0 && "invalid successor index"); 254 return targetOperandsMutable(); 255 } 256 257 //===----------------------------------------------------------------------===// 258 // TestFoldToCallOp 259 //===----------------------------------------------------------------------===// 260 261 namespace { 262 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 263 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 264 265 LogicalResult matchAndRewrite(FoldToCallOp op, 266 PatternRewriter &rewriter) const override { 267 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(), 268 ValueRange()); 269 return success(); 270 } 271 }; 272 } // end anonymous namespace 273 274 void FoldToCallOp::getCanonicalizationPatterns( 275 OwningRewritePatternList &results, MLIRContext *context) { 276 results.insert<FoldToCallOpPattern>(context); 277 } 278 279 //===----------------------------------------------------------------------===// 280 // Test Format* operations 281 //===----------------------------------------------------------------------===// 282 283 //===----------------------------------------------------------------------===// 284 // Parsing 285 286 static ParseResult parseCustomDirectiveOperands( 287 OpAsmParser &parser, OpAsmParser::OperandType &operand, 288 Optional<OpAsmParser::OperandType> &optOperand, 289 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 290 if (parser.parseOperand(operand)) 291 return failure(); 292 if (succeeded(parser.parseOptionalComma())) { 293 optOperand.emplace(); 294 if (parser.parseOperand(*optOperand)) 295 return failure(); 296 } 297 if (parser.parseArrow() || parser.parseLParen() || 298 parser.parseOperandList(varOperands) || parser.parseRParen()) 299 return failure(); 300 return success(); 301 } 302 static ParseResult 303 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 304 Type &optOperandType, 305 SmallVectorImpl<Type> &varOperandTypes) { 306 if (parser.parseColon()) 307 return failure(); 308 309 if (parser.parseType(operandType)) 310 return failure(); 311 if (succeeded(parser.parseOptionalComma())) { 312 if (parser.parseType(optOperandType)) 313 return failure(); 314 } 315 if (parser.parseArrow() || parser.parseLParen() || 316 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 317 return failure(); 318 return success(); 319 } 320 static ParseResult 321 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 322 Type optOperandType, 323 const SmallVectorImpl<Type> &varOperandTypes) { 324 if (parser.parseKeyword("type_refs_capture")) 325 return failure(); 326 327 Type operandType2, optOperandType2; 328 SmallVector<Type, 1> varOperandTypes2; 329 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 330 varOperandTypes2)) 331 return failure(); 332 333 if (operandType != operandType2 || optOperandType != optOperandType2 || 334 varOperandTypes != varOperandTypes2) 335 return failure(); 336 337 return success(); 338 } 339 static ParseResult parseCustomDirectiveOperandsAndTypes( 340 OpAsmParser &parser, OpAsmParser::OperandType &operand, 341 Optional<OpAsmParser::OperandType> &optOperand, 342 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 343 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 344 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 345 parseCustomDirectiveResults(parser, operandType, optOperandType, 346 varOperandTypes)) 347 return failure(); 348 return success(); 349 } 350 static ParseResult parseCustomDirectiveRegions( 351 OpAsmParser &parser, Region ®ion, 352 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 353 if (parser.parseRegion(region)) 354 return failure(); 355 if (failed(parser.parseOptionalComma())) 356 return success(); 357 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 358 if (parser.parseRegion(*varRegion)) 359 return failure(); 360 varRegions.emplace_back(std::move(varRegion)); 361 return success(); 362 } 363 static ParseResult 364 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 365 SmallVectorImpl<Block *> &varSuccessors) { 366 if (parser.parseSuccessor(successor)) 367 return failure(); 368 if (failed(parser.parseOptionalComma())) 369 return success(); 370 Block *varSuccessor; 371 if (parser.parseSuccessor(varSuccessor)) 372 return failure(); 373 varSuccessors.append(2, varSuccessor); 374 return success(); 375 } 376 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 377 IntegerAttr &attr, 378 IntegerAttr &optAttr) { 379 if (parser.parseAttribute(attr)) 380 return failure(); 381 if (succeeded(parser.parseOptionalComma())) { 382 if (parser.parseAttribute(optAttr)) 383 return failure(); 384 } 385 return success(); 386 } 387 388 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 389 NamedAttrList &attrs) { 390 return parser.parseOptionalAttrDict(attrs); 391 } 392 393 //===----------------------------------------------------------------------===// 394 // Printing 395 396 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 397 Value operand, Value optOperand, 398 OperandRange varOperands) { 399 printer << operand; 400 if (optOperand) 401 printer << ", " << optOperand; 402 printer << " -> (" << varOperands << ")"; 403 } 404 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 405 Type operandType, Type optOperandType, 406 TypeRange varOperandTypes) { 407 printer << " : " << operandType; 408 if (optOperandType) 409 printer << ", " << optOperandType; 410 printer << " -> (" << varOperandTypes << ")"; 411 } 412 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 413 Operation *op, Type operandType, 414 Type optOperandType, 415 TypeRange varOperandTypes) { 416 printer << " type_refs_capture "; 417 printCustomDirectiveResults(printer, op, operandType, optOperandType, 418 varOperandTypes); 419 } 420 static void printCustomDirectiveOperandsAndTypes( 421 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 422 OperandRange varOperands, Type operandType, Type optOperandType, 423 TypeRange varOperandTypes) { 424 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 425 printCustomDirectiveResults(printer, op, operandType, optOperandType, 426 varOperandTypes); 427 } 428 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 429 Region ®ion, 430 MutableArrayRef<Region> varRegions) { 431 printer.printRegion(region); 432 if (!varRegions.empty()) { 433 printer << ", "; 434 for (Region ®ion : varRegions) 435 printer.printRegion(region); 436 } 437 } 438 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 439 Block *successor, 440 SuccessorRange varSuccessors) { 441 printer << successor; 442 if (!varSuccessors.empty()) 443 printer << ", " << varSuccessors.front(); 444 } 445 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 446 Attribute attribute, 447 Attribute optAttribute) { 448 printer << attribute; 449 if (optAttribute) 450 printer << ", " << optAttribute; 451 } 452 453 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 454 MutableDictionaryAttr attrs) { 455 printer.printOptionalAttrDict(attrs.getAttrs()); 456 } 457 //===----------------------------------------------------------------------===// 458 // Test IsolatedRegionOp - parse passthrough region arguments. 459 //===----------------------------------------------------------------------===// 460 461 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 462 OperationState &result) { 463 OpAsmParser::OperandType argInfo; 464 Type argType = parser.getBuilder().getIndexType(); 465 466 // Parse the input operand. 467 if (parser.parseOperand(argInfo) || 468 parser.resolveOperand(argInfo, argType, result.operands)) 469 return failure(); 470 471 // Parse the body region, and reuse the operand info as the argument info. 472 Region *body = result.addRegion(); 473 return parser.parseRegion(*body, argInfo, argType, 474 /*enableNameShadowing=*/true); 475 } 476 477 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 478 p << "test.isolated_region "; 479 p.printOperand(op.getOperand()); 480 p.shadowRegionArgs(op.region(), op.getOperand()); 481 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 482 } 483 484 //===----------------------------------------------------------------------===// 485 // Test SSACFGRegionOp 486 //===----------------------------------------------------------------------===// 487 488 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 489 return RegionKind::SSACFG; 490 } 491 492 //===----------------------------------------------------------------------===// 493 // Test GraphRegionOp 494 //===----------------------------------------------------------------------===// 495 496 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 497 OperationState &result) { 498 // Parse the body region, and reuse the operand info as the argument info. 499 Region *body = result.addRegion(); 500 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 501 } 502 503 static void print(OpAsmPrinter &p, GraphRegionOp op) { 504 p << "test.graph_region "; 505 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 506 } 507 508 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 509 return RegionKind::Graph; 510 } 511 512 //===----------------------------------------------------------------------===// 513 // Test AffineScopeOp 514 //===----------------------------------------------------------------------===// 515 516 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 517 OperationState &result) { 518 // Parse the body region, and reuse the operand info as the argument info. 519 Region *body = result.addRegion(); 520 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 521 } 522 523 static void print(OpAsmPrinter &p, AffineScopeOp op) { 524 p << "test.affine_scope "; 525 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 526 } 527 528 //===----------------------------------------------------------------------===// 529 // Test parser. 530 //===----------------------------------------------------------------------===// 531 532 static ParseResult parseWrappedKeywordOp(OpAsmParser &parser, 533 OperationState &result) { 534 StringRef keyword; 535 if (parser.parseKeyword(&keyword)) 536 return failure(); 537 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 538 return success(); 539 } 540 541 static void print(OpAsmPrinter &p, WrappedKeywordOp op) { 542 p << WrappedKeywordOp::getOperationName() << " " << op.keyword(); 543 } 544 545 //===----------------------------------------------------------------------===// 546 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 547 548 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 549 OperationState &result) { 550 if (parser.parseKeyword("wraps")) 551 return failure(); 552 553 // Parse the wrapped op in a region 554 Region &body = *result.addRegion(); 555 body.push_back(new Block); 556 Block &block = body.back(); 557 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 558 if (!wrapped_op) 559 return failure(); 560 561 // Create a return terminator in the inner region, pass as operand to the 562 // terminator the returned values from the wrapped operation. 563 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 564 OpBuilder builder(parser.getBuilder().getContext()); 565 builder.setInsertionPointToEnd(&block); 566 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 567 568 // Get the results type for the wrapping op from the terminator operands. 569 Operation &return_op = body.back().back(); 570 result.types.append(return_op.operand_type_begin(), 571 return_op.operand_type_end()); 572 573 // Use the location of the wrapped op for the "test.wrapping_region" op. 574 result.location = wrapped_op->getLoc(); 575 576 return success(); 577 } 578 579 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 580 p << op.getOperationName() << " wraps "; 581 p.printGenericOp(&op.region().front().front()); 582 } 583 584 //===----------------------------------------------------------------------===// 585 // Test PolyForOp - parse list of region arguments. 586 //===----------------------------------------------------------------------===// 587 588 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 589 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 590 // Parse list of region arguments without a delimiter. 591 if (parser.parseRegionArgumentList(ivsInfo)) 592 return failure(); 593 594 // Parse the body region. 595 Region *body = result.addRegion(); 596 auto &builder = parser.getBuilder(); 597 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 598 return parser.parseRegion(*body, ivsInfo, argTypes); 599 } 600 601 //===----------------------------------------------------------------------===// 602 // Test removing op with inner ops. 603 //===----------------------------------------------------------------------===// 604 605 namespace { 606 struct TestRemoveOpWithInnerOps 607 : public OpRewritePattern<TestOpWithRegionPattern> { 608 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 609 610 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 611 PatternRewriter &rewriter) const override { 612 rewriter.eraseOp(op); 613 return success(); 614 } 615 }; 616 } // end anonymous namespace 617 618 void TestOpWithRegionPattern::getCanonicalizationPatterns( 619 OwningRewritePatternList &results, MLIRContext *context) { 620 results.insert<TestRemoveOpWithInnerOps>(context); 621 } 622 623 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 624 return operand(); 625 } 626 627 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 628 return getValue(); 629 } 630 631 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 632 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 633 for (Value input : this->operands()) { 634 results.push_back(input); 635 } 636 return success(); 637 } 638 639 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 640 assert(operands.size() == 1); 641 if (operands.front()) { 642 setAttr("attr", operands.front()); 643 return getResult(); 644 } 645 return {}; 646 } 647 648 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 649 MLIRContext *, Optional<Location> location, ValueRange operands, 650 DictionaryAttr attributes, RegionRange regions, 651 SmallVectorImpl<Type> &inferredReturnTypes) { 652 if (operands[0].getType() != operands[1].getType()) { 653 return emitOptionalError(location, "operand type mismatch ", 654 operands[0].getType(), " vs ", 655 operands[1].getType()); 656 } 657 inferredReturnTypes.assign({operands[0].getType()}); 658 return success(); 659 } 660 661 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 662 MLIRContext *context, Optional<Location> location, ValueRange operands, 663 DictionaryAttr attributes, RegionRange regions, 664 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 665 // Create return type consisting of the last element of the first operand. 666 auto operandType = *operands.getTypes().begin(); 667 auto sval = operandType.dyn_cast<ShapedType>(); 668 if (!sval) { 669 return emitOptionalError(location, "only shaped type operands allowed"); 670 } 671 int64_t dim = 672 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 673 auto type = IntegerType::get(17, context); 674 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 675 return success(); 676 } 677 678 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 679 OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) { 680 shapes = SmallVector<Value, 1>{ 681 builder.createOrFold<DimOp>(getLoc(), getOperand(0), 0)}; 682 return success(); 683 } 684 685 //===----------------------------------------------------------------------===// 686 // Test SideEffect interfaces 687 //===----------------------------------------------------------------------===// 688 689 namespace { 690 /// A test resource for side effects. 691 struct TestResource : public SideEffects::Resource::Base<TestResource> { 692 StringRef getName() final { return "<Test>"; } 693 }; 694 } // end anonymous namespace 695 696 void SideEffectOp::getEffects( 697 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 698 // Check for an effects attribute on the op instance. 699 ArrayAttr effectsAttr = getAttrOfType<ArrayAttr>("effects"); 700 if (!effectsAttr) 701 return; 702 703 // If there is one, it is an array of dictionary attributes that hold 704 // information on the effects of this operation. 705 for (Attribute element : effectsAttr) { 706 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 707 708 // Get the specific memory effect. 709 MemoryEffects::Effect *effect = 710 StringSwitch<MemoryEffects::Effect *>( 711 effectElement.get("effect").cast<StringAttr>().getValue()) 712 .Case("allocate", MemoryEffects::Allocate::get()) 713 .Case("free", MemoryEffects::Free::get()) 714 .Case("read", MemoryEffects::Read::get()) 715 .Case("write", MemoryEffects::Write::get()); 716 717 // Check for a result to affect. 718 Value value; 719 if (effectElement.get("on_result")) 720 value = getResult(); 721 722 // Check for a non-default resource to use. 723 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 724 if (effectElement.get("test_resource")) 725 resource = TestResource::get(); 726 727 effects.emplace_back(effect, value, resource); 728 } 729 } 730 731 //===----------------------------------------------------------------------===// 732 // StringAttrPrettyNameOp 733 //===----------------------------------------------------------------------===// 734 735 // This op has fancy handling of its SSA result name. 736 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 737 OperationState &result) { 738 // Add the result types. 739 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 740 result.addTypes(parser.getBuilder().getIntegerType(32)); 741 742 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 743 return failure(); 744 745 // If the attribute dictionary contains no 'names' attribute, infer it from 746 // the SSA name (if specified). 747 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 748 return attr.first == "names"; 749 }); 750 751 // If there was no name specified, check to see if there was a useful name 752 // specified in the asm file. 753 if (hadNames || parser.getNumResults() == 0) 754 return success(); 755 756 SmallVector<StringRef, 4> names; 757 auto *context = result.getContext(); 758 759 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 760 auto resultName = parser.getResultName(i); 761 StringRef nameStr; 762 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 763 nameStr = resultName.first; 764 765 names.push_back(nameStr); 766 } 767 768 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 769 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 770 return success(); 771 } 772 773 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 774 p << "test.string_attr_pretty_name"; 775 776 // Note that we only need to print the "name" attribute if the asmprinter 777 // result name disagrees with it. This can happen in strange cases, e.g. 778 // when there are conflicts. 779 bool namesDisagree = op.names().size() != op.getNumResults(); 780 781 SmallString<32> resultNameStr; 782 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 783 resultNameStr.clear(); 784 llvm::raw_svector_ostream tmpStream(resultNameStr); 785 p.printOperand(op.getResult(i), tmpStream); 786 787 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 788 if (!expectedName || 789 tmpStream.str().drop_front() != expectedName.getValue()) { 790 namesDisagree = true; 791 } 792 } 793 794 if (namesDisagree) 795 p.printOptionalAttrDictWithKeyword(op.getAttrs()); 796 else 797 p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"}); 798 } 799 800 // We set the SSA name in the asm syntax to the contents of the name 801 // attribute. 802 void StringAttrPrettyNameOp::getAsmResultNames( 803 function_ref<void(Value, StringRef)> setNameFn) { 804 805 auto value = names(); 806 for (size_t i = 0, e = value.size(); i != e; ++i) 807 if (auto str = value[i].dyn_cast<StringAttr>()) 808 if (!str.getValue().empty()) 809 setNameFn(getResult(i), str.getValue()); 810 } 811 812 //===----------------------------------------------------------------------===// 813 // RegionIfOp 814 //===----------------------------------------------------------------------===// 815 816 static void print(OpAsmPrinter &p, RegionIfOp op) { 817 p << RegionIfOp::getOperationName() << " "; 818 p.printOperands(op.getOperands()); 819 p << ": " << op.getOperandTypes(); 820 p.printArrowTypeList(op.getResultTypes()); 821 p << " then"; 822 p.printRegion(op.thenRegion(), 823 /*printEntryBlockArgs=*/true, 824 /*printBlockTerminators=*/true); 825 p << " else"; 826 p.printRegion(op.elseRegion(), 827 /*printEntryBlockArgs=*/true, 828 /*printBlockTerminators=*/true); 829 p << " join"; 830 p.printRegion(op.joinRegion(), 831 /*printEntryBlockArgs=*/true, 832 /*printBlockTerminators=*/true); 833 } 834 835 static ParseResult parseRegionIfOp(OpAsmParser &parser, 836 OperationState &result) { 837 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 838 SmallVector<Type, 2> operandTypes; 839 840 result.regions.reserve(3); 841 Region *thenRegion = result.addRegion(); 842 Region *elseRegion = result.addRegion(); 843 Region *joinRegion = result.addRegion(); 844 845 // Parse operand, type and arrow type lists. 846 if (parser.parseOperandList(operandInfos) || 847 parser.parseColonTypeList(operandTypes) || 848 parser.parseArrowTypeList(result.types)) 849 return failure(); 850 851 // Parse all attached regions. 852 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 853 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 854 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 855 return failure(); 856 857 return parser.resolveOperands(operandInfos, operandTypes, 858 parser.getCurrentLocation(), result.operands); 859 } 860 861 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 862 assert(index < 2 && "invalid region index"); 863 return getOperands(); 864 } 865 866 void RegionIfOp::getSuccessorRegions( 867 Optional<unsigned> index, ArrayRef<Attribute> operands, 868 SmallVectorImpl<RegionSuccessor> ®ions) { 869 // We always branch to the join region. 870 if (index.hasValue()) { 871 if (index.getValue() < 2) 872 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 873 else 874 regions.push_back(RegionSuccessor(getResults())); 875 return; 876 } 877 878 // The then and else regions are the entry regions of this op. 879 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 880 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 881 } 882 883 #include "TestOpEnums.cpp.inc" 884 #include "TestOpStructs.cpp.inc" 885 #include "TestTypeInterfaces.cpp.inc" 886 887 #define GET_OP_CLASSES 888 #include "TestOps.cpp.inc" 889