1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestTypes.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/IR/DialectImplementation.h" 13 #include "mlir/IR/Function.h" 14 #include "mlir/IR/Module.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/IR/TypeUtilities.h" 17 #include "mlir/Transforms/FoldUtils.h" 18 #include "mlir/Transforms/InliningUtils.h" 19 #include "llvm/ADT/SetVector.h" 20 #include "llvm/ADT/StringSwitch.h" 21 22 using namespace mlir; 23 24 void mlir::registerTestDialect(DialectRegistry ®istry) { 25 registry.insert<TestDialect>(); 26 } 27 28 //===----------------------------------------------------------------------===// 29 // TestDialect Interfaces 30 //===----------------------------------------------------------------------===// 31 32 namespace { 33 34 // Test support for interacting with the AsmPrinter. 35 struct TestOpAsmInterface : public OpAsmDialectInterface { 36 using OpAsmDialectInterface::OpAsmDialectInterface; 37 38 void getAsmResultNames(Operation *op, 39 OpAsmSetValueNameFn setNameFn) const final { 40 if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 41 setNameFn(asmOp, "result"); 42 } 43 44 void getAsmBlockArgumentNames(Block *block, 45 OpAsmSetValueNameFn setNameFn) const final { 46 auto op = block->getParentOp(); 47 auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 48 if (!arrayAttr) 49 return; 50 auto args = block->getArguments(); 51 auto e = std::min(arrayAttr.size(), args.size()); 52 for (unsigned i = 0; i < e; ++i) { 53 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 54 setNameFn(args[i], strAttr.getValue()); 55 } 56 } 57 }; 58 59 struct TestDialectFoldInterface : public DialectFoldInterface { 60 using DialectFoldInterface::DialectFoldInterface; 61 62 /// Registered hook to check if the given region, which is attached to an 63 /// operation that is *not* isolated from above, should be used when 64 /// materializing constants. 65 bool shouldMaterializeInto(Region *region) const final { 66 // If this is a one region operation, then insert into it. 67 return isa<OneRegionOp>(region->getParentOp()); 68 } 69 }; 70 71 /// This class defines the interface for handling inlining with standard 72 /// operations. 73 struct TestInlinerInterface : public DialectInlinerInterface { 74 using DialectInlinerInterface::DialectInlinerInterface; 75 76 //===--------------------------------------------------------------------===// 77 // Analysis Hooks 78 //===--------------------------------------------------------------------===// 79 80 bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final { 81 // Inlining into test dialect regions is legal. 82 return true; 83 } 84 bool isLegalToInline(Operation *, Region *, 85 BlockAndValueMapping &) const final { 86 return true; 87 } 88 89 bool shouldAnalyzeRecursively(Operation *op) const final { 90 // Analyze recursively if this is not a functional region operation, it 91 // froms a separate functional scope. 92 return !isa<FunctionalRegionOp>(op); 93 } 94 95 //===--------------------------------------------------------------------===// 96 // Transformation Hooks 97 //===--------------------------------------------------------------------===// 98 99 /// Handle the given inlined terminator by replacing it with a new operation 100 /// as necessary. 101 void handleTerminator(Operation *op, 102 ArrayRef<Value> valuesToRepl) const final { 103 // Only handle "test.return" here. 104 auto returnOp = dyn_cast<TestReturnOp>(op); 105 if (!returnOp) 106 return; 107 108 // Replace the values directly with the return operands. 109 assert(returnOp.getNumOperands() == valuesToRepl.size()); 110 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 111 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 112 } 113 114 /// Attempt to materialize a conversion for a type mismatch between a call 115 /// from this dialect, and a callable region. This method should generate an 116 /// operation that takes 'input' as the only operand, and produces a single 117 /// result of 'resultType'. If a conversion can not be generated, nullptr 118 /// should be returned. 119 Operation *materializeCallConversion(OpBuilder &builder, Value input, 120 Type resultType, 121 Location conversionLoc) const final { 122 // Only allow conversion for i16/i32 types. 123 if (!(resultType.isSignlessInteger(16) || 124 resultType.isSignlessInteger(32)) || 125 !(input.getType().isSignlessInteger(16) || 126 input.getType().isSignlessInteger(32))) 127 return nullptr; 128 return builder.create<TestCastOp>(conversionLoc, resultType, input); 129 } 130 }; 131 } // end anonymous namespace 132 133 //===----------------------------------------------------------------------===// 134 // TestDialect 135 //===----------------------------------------------------------------------===// 136 137 void TestDialect::initialize() { 138 addOperations< 139 #define GET_OP_LIST 140 #include "TestOps.cpp.inc" 141 >(); 142 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 143 TestInlinerInterface>(); 144 addTypes<TestType, TestRecursiveType>(); 145 allowUnknownOperations(); 146 } 147 148 static Type parseTestType(DialectAsmParser &parser, 149 llvm::SetVector<Type> &stack) { 150 StringRef typeTag; 151 if (failed(parser.parseKeyword(&typeTag))) 152 return Type(); 153 154 if (typeTag == "test_type") 155 return TestType::get(parser.getBuilder().getContext()); 156 157 if (typeTag != "test_rec") 158 return Type(); 159 160 StringRef name; 161 if (parser.parseLess() || parser.parseKeyword(&name)) 162 return Type(); 163 auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name); 164 165 // If this type already has been parsed above in the stack, expect just the 166 // name. 167 if (stack.contains(rec)) { 168 if (failed(parser.parseGreater())) 169 return Type(); 170 return rec; 171 } 172 173 // Otherwise, parse the body and update the type. 174 if (failed(parser.parseComma())) 175 return Type(); 176 stack.insert(rec); 177 Type subtype = parseTestType(parser, stack); 178 stack.pop_back(); 179 if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype))) 180 return Type(); 181 182 return rec; 183 } 184 185 Type TestDialect::parseType(DialectAsmParser &parser) const { 186 llvm::SetVector<Type> stack; 187 return parseTestType(parser, stack); 188 } 189 190 static void printTestType(Type type, DialectAsmPrinter &printer, 191 llvm::SetVector<Type> &stack) { 192 if (type.isa<TestType>()) { 193 printer << "test_type"; 194 return; 195 } 196 197 auto rec = type.cast<TestRecursiveType>(); 198 printer << "test_rec<" << rec.getName(); 199 if (!stack.contains(rec)) { 200 printer << ", "; 201 stack.insert(rec); 202 printTestType(rec.getBody(), printer, stack); 203 stack.pop_back(); 204 } 205 printer << ">"; 206 } 207 208 void TestDialect::printType(Type type, DialectAsmPrinter &printer) const { 209 llvm::SetVector<Type> stack; 210 printTestType(type, printer, stack); 211 } 212 213 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 214 NamedAttribute namedAttr) { 215 if (namedAttr.first == "test.invalid_attr") 216 return op->emitError() << "invalid to use 'test.invalid_attr'"; 217 return success(); 218 } 219 220 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 221 unsigned regionIndex, 222 unsigned argIndex, 223 NamedAttribute namedAttr) { 224 if (namedAttr.first == "test.invalid_attr") 225 return op->emitError() << "invalid to use 'test.invalid_attr'"; 226 return success(); 227 } 228 229 LogicalResult 230 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 231 unsigned resultIndex, 232 NamedAttribute namedAttr) { 233 if (namedAttr.first == "test.invalid_attr") 234 return op->emitError() << "invalid to use 'test.invalid_attr'"; 235 return success(); 236 } 237 238 //===----------------------------------------------------------------------===// 239 // TestBranchOp 240 //===----------------------------------------------------------------------===// 241 242 Optional<MutableOperandRange> 243 TestBranchOp::getMutableSuccessorOperands(unsigned index) { 244 assert(index == 0 && "invalid successor index"); 245 return targetOperandsMutable(); 246 } 247 248 //===----------------------------------------------------------------------===// 249 // TestFoldToCallOp 250 //===----------------------------------------------------------------------===// 251 252 namespace { 253 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 254 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 255 256 LogicalResult matchAndRewrite(FoldToCallOp op, 257 PatternRewriter &rewriter) const override { 258 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(), 259 ValueRange()); 260 return success(); 261 } 262 }; 263 } // end anonymous namespace 264 265 void FoldToCallOp::getCanonicalizationPatterns( 266 OwningRewritePatternList &results, MLIRContext *context) { 267 results.insert<FoldToCallOpPattern>(context); 268 } 269 270 //===----------------------------------------------------------------------===// 271 // Test Format* operations 272 //===----------------------------------------------------------------------===// 273 274 //===----------------------------------------------------------------------===// 275 // Parsing 276 277 static ParseResult parseCustomDirectiveOperands( 278 OpAsmParser &parser, OpAsmParser::OperandType &operand, 279 Optional<OpAsmParser::OperandType> &optOperand, 280 SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 281 if (parser.parseOperand(operand)) 282 return failure(); 283 if (succeeded(parser.parseOptionalComma())) { 284 optOperand.emplace(); 285 if (parser.parseOperand(*optOperand)) 286 return failure(); 287 } 288 if (parser.parseArrow() || parser.parseLParen() || 289 parser.parseOperandList(varOperands) || parser.parseRParen()) 290 return failure(); 291 return success(); 292 } 293 static ParseResult 294 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 295 Type &optOperandType, 296 SmallVectorImpl<Type> &varOperandTypes) { 297 if (parser.parseColon()) 298 return failure(); 299 300 if (parser.parseType(operandType)) 301 return failure(); 302 if (succeeded(parser.parseOptionalComma())) { 303 if (parser.parseType(optOperandType)) 304 return failure(); 305 } 306 if (parser.parseArrow() || parser.parseLParen() || 307 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 308 return failure(); 309 return success(); 310 } 311 static ParseResult 312 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 313 Type optOperandType, 314 const SmallVectorImpl<Type> &varOperandTypes) { 315 if (parser.parseKeyword("type_refs_capture")) 316 return failure(); 317 318 Type operandType2, optOperandType2; 319 SmallVector<Type, 1> varOperandTypes2; 320 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 321 varOperandTypes2)) 322 return failure(); 323 324 if (operandType != operandType2 || optOperandType != optOperandType2 || 325 varOperandTypes != varOperandTypes2) 326 return failure(); 327 328 return success(); 329 } 330 static ParseResult parseCustomDirectiveOperandsAndTypes( 331 OpAsmParser &parser, OpAsmParser::OperandType &operand, 332 Optional<OpAsmParser::OperandType> &optOperand, 333 SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 334 Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 335 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 336 parseCustomDirectiveResults(parser, operandType, optOperandType, 337 varOperandTypes)) 338 return failure(); 339 return success(); 340 } 341 static ParseResult parseCustomDirectiveRegions( 342 OpAsmParser &parser, Region ®ion, 343 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 344 if (parser.parseRegion(region)) 345 return failure(); 346 if (failed(parser.parseOptionalComma())) 347 return success(); 348 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 349 if (parser.parseRegion(*varRegion)) 350 return failure(); 351 varRegions.emplace_back(std::move(varRegion)); 352 return success(); 353 } 354 static ParseResult 355 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 356 SmallVectorImpl<Block *> &varSuccessors) { 357 if (parser.parseSuccessor(successor)) 358 return failure(); 359 if (failed(parser.parseOptionalComma())) 360 return success(); 361 Block *varSuccessor; 362 if (parser.parseSuccessor(varSuccessor)) 363 return failure(); 364 varSuccessors.append(2, varSuccessor); 365 return success(); 366 } 367 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 368 IntegerAttr &attr, 369 IntegerAttr &optAttr) { 370 if (parser.parseAttribute(attr)) 371 return failure(); 372 if (succeeded(parser.parseOptionalComma())) { 373 if (parser.parseAttribute(optAttr)) 374 return failure(); 375 } 376 return success(); 377 } 378 379 //===----------------------------------------------------------------------===// 380 // Printing 381 382 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Value operand, 383 Value optOperand, 384 OperandRange varOperands) { 385 printer << operand; 386 if (optOperand) 387 printer << ", " << optOperand; 388 printer << " -> (" << varOperands << ")"; 389 } 390 static void printCustomDirectiveResults(OpAsmPrinter &printer, Type operandType, 391 Type optOperandType, 392 TypeRange varOperandTypes) { 393 printer << " : " << operandType; 394 if (optOperandType) 395 printer << ", " << optOperandType; 396 printer << " -> (" << varOperandTypes << ")"; 397 } 398 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 399 Type operandType, 400 Type optOperandType, 401 TypeRange varOperandTypes) { 402 printer << " type_refs_capture "; 403 printCustomDirectiveResults(printer, operandType, optOperandType, 404 varOperandTypes); 405 } 406 static void 407 printCustomDirectiveOperandsAndTypes(OpAsmPrinter &printer, Value operand, 408 Value optOperand, OperandRange varOperands, 409 Type operandType, Type optOperandType, 410 TypeRange varOperandTypes) { 411 printCustomDirectiveOperands(printer, operand, optOperand, varOperands); 412 printCustomDirectiveResults(printer, operandType, optOperandType, 413 varOperandTypes); 414 } 415 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Region ®ion, 416 MutableArrayRef<Region> varRegions) { 417 printer.printRegion(region); 418 if (!varRegions.empty()) { 419 printer << ", "; 420 for (Region ®ion : varRegions) 421 printer.printRegion(region); 422 } 423 } 424 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, 425 Block *successor, 426 SuccessorRange varSuccessors) { 427 printer << successor; 428 if (!varSuccessors.empty()) 429 printer << ", " << varSuccessors.front(); 430 } 431 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, 432 Attribute attribute, 433 Attribute optAttribute) { 434 printer << attribute; 435 if (optAttribute) 436 printer << ", " << optAttribute; 437 } 438 439 //===----------------------------------------------------------------------===// 440 // Test IsolatedRegionOp - parse passthrough region arguments. 441 //===----------------------------------------------------------------------===// 442 443 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 444 OperationState &result) { 445 OpAsmParser::OperandType argInfo; 446 Type argType = parser.getBuilder().getIndexType(); 447 448 // Parse the input operand. 449 if (parser.parseOperand(argInfo) || 450 parser.resolveOperand(argInfo, argType, result.operands)) 451 return failure(); 452 453 // Parse the body region, and reuse the operand info as the argument info. 454 Region *body = result.addRegion(); 455 return parser.parseRegion(*body, argInfo, argType, 456 /*enableNameShadowing=*/true); 457 } 458 459 static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 460 p << "test.isolated_region "; 461 p.printOperand(op.getOperand()); 462 p.shadowRegionArgs(op.region(), op.getOperand()); 463 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 464 } 465 466 //===----------------------------------------------------------------------===// 467 // Test SSACFGRegionOp 468 //===----------------------------------------------------------------------===// 469 470 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 471 return RegionKind::SSACFG; 472 } 473 474 //===----------------------------------------------------------------------===// 475 // Test GraphRegionOp 476 //===----------------------------------------------------------------------===// 477 478 static ParseResult parseGraphRegionOp(OpAsmParser &parser, 479 OperationState &result) { 480 // Parse the body region, and reuse the operand info as the argument info. 481 Region *body = result.addRegion(); 482 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 483 } 484 485 static void print(OpAsmPrinter &p, GraphRegionOp op) { 486 p << "test.graph_region "; 487 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 488 } 489 490 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 491 return RegionKind::Graph; 492 } 493 494 //===----------------------------------------------------------------------===// 495 // Test AffineScopeOp 496 //===----------------------------------------------------------------------===// 497 498 static ParseResult parseAffineScopeOp(OpAsmParser &parser, 499 OperationState &result) { 500 // Parse the body region, and reuse the operand info as the argument info. 501 Region *body = result.addRegion(); 502 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 503 } 504 505 static void print(OpAsmPrinter &p, AffineScopeOp op) { 506 p << "test.affine_scope "; 507 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 508 } 509 510 //===----------------------------------------------------------------------===// 511 // Test parser. 512 //===----------------------------------------------------------------------===// 513 514 static ParseResult parseWrappedKeywordOp(OpAsmParser &parser, 515 OperationState &result) { 516 StringRef keyword; 517 if (parser.parseKeyword(&keyword)) 518 return failure(); 519 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 520 return success(); 521 } 522 523 static void print(OpAsmPrinter &p, WrappedKeywordOp op) { 524 p << WrappedKeywordOp::getOperationName() << " " << op.keyword(); 525 } 526 527 //===----------------------------------------------------------------------===// 528 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 529 530 static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 531 OperationState &result) { 532 if (parser.parseKeyword("wraps")) 533 return failure(); 534 535 // Parse the wrapped op in a region 536 Region &body = *result.addRegion(); 537 body.push_back(new Block); 538 Block &block = body.back(); 539 Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 540 if (!wrapped_op) 541 return failure(); 542 543 // Create a return terminator in the inner region, pass as operand to the 544 // terminator the returned values from the wrapped operation. 545 SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 546 OpBuilder builder(parser.getBuilder().getContext()); 547 builder.setInsertionPointToEnd(&block); 548 builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 549 550 // Get the results type for the wrapping op from the terminator operands. 551 Operation &return_op = body.back().back(); 552 result.types.append(return_op.operand_type_begin(), 553 return_op.operand_type_end()); 554 555 // Use the location of the wrapped op for the "test.wrapping_region" op. 556 result.location = wrapped_op->getLoc(); 557 558 return success(); 559 } 560 561 static void print(OpAsmPrinter &p, WrappingRegionOp op) { 562 p << op.getOperationName() << " wraps "; 563 p.printGenericOp(&op.region().front().front()); 564 } 565 566 //===----------------------------------------------------------------------===// 567 // Test PolyForOp - parse list of region arguments. 568 //===----------------------------------------------------------------------===// 569 570 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 571 SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 572 // Parse list of region arguments without a delimiter. 573 if (parser.parseRegionArgumentList(ivsInfo)) 574 return failure(); 575 576 // Parse the body region. 577 Region *body = result.addRegion(); 578 auto &builder = parser.getBuilder(); 579 SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 580 return parser.parseRegion(*body, ivsInfo, argTypes); 581 } 582 583 //===----------------------------------------------------------------------===// 584 // Test removing op with inner ops. 585 //===----------------------------------------------------------------------===// 586 587 namespace { 588 struct TestRemoveOpWithInnerOps 589 : public OpRewritePattern<TestOpWithRegionPattern> { 590 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 591 592 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 593 PatternRewriter &rewriter) const override { 594 rewriter.eraseOp(op); 595 return success(); 596 } 597 }; 598 } // end anonymous namespace 599 600 void TestOpWithRegionPattern::getCanonicalizationPatterns( 601 OwningRewritePatternList &results, MLIRContext *context) { 602 results.insert<TestRemoveOpWithInnerOps>(context); 603 } 604 605 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 606 return operand(); 607 } 608 609 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 610 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 611 for (Value input : this->operands()) { 612 results.push_back(input); 613 } 614 return success(); 615 } 616 617 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 618 assert(operands.size() == 1); 619 if (operands.front()) { 620 setAttr("attr", operands.front()); 621 return getResult(); 622 } 623 return {}; 624 } 625 626 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 627 MLIRContext *, Optional<Location> location, ValueRange operands, 628 DictionaryAttr attributes, RegionRange regions, 629 SmallVectorImpl<Type> &inferredReturnTypes) { 630 if (operands[0].getType() != operands[1].getType()) { 631 return emitOptionalError(location, "operand type mismatch ", 632 operands[0].getType(), " vs ", 633 operands[1].getType()); 634 } 635 inferredReturnTypes.assign({operands[0].getType()}); 636 return success(); 637 } 638 639 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 640 MLIRContext *context, Optional<Location> location, ValueRange operands, 641 DictionaryAttr attributes, RegionRange regions, 642 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 643 // Create return type consisting of the last element of the first operand. 644 auto operandType = *operands.getTypes().begin(); 645 auto sval = operandType.dyn_cast<ShapedType>(); 646 if (!sval) { 647 return emitOptionalError(location, "only shaped type operands allowed"); 648 } 649 int64_t dim = 650 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 651 auto type = IntegerType::get(17, context); 652 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 653 return success(); 654 } 655 656 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 657 OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) { 658 shapes = SmallVector<Value, 1>{ 659 builder.createOrFold<DimOp>(getLoc(), getOperand(0), 0)}; 660 return success(); 661 } 662 663 //===----------------------------------------------------------------------===// 664 // Test SideEffect interfaces 665 //===----------------------------------------------------------------------===// 666 667 namespace { 668 /// A test resource for side effects. 669 struct TestResource : public SideEffects::Resource::Base<TestResource> { 670 StringRef getName() final { return "<Test>"; } 671 }; 672 } // end anonymous namespace 673 674 void SideEffectOp::getEffects( 675 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 676 // Check for an effects attribute on the op instance. 677 ArrayAttr effectsAttr = getAttrOfType<ArrayAttr>("effects"); 678 if (!effectsAttr) 679 return; 680 681 // If there is one, it is an array of dictionary attributes that hold 682 // information on the effects of this operation. 683 for (Attribute element : effectsAttr) { 684 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 685 686 // Get the specific memory effect. 687 MemoryEffects::Effect *effect = 688 llvm::StringSwitch<MemoryEffects::Effect *>( 689 effectElement.get("effect").cast<StringAttr>().getValue()) 690 .Case("allocate", MemoryEffects::Allocate::get()) 691 .Case("free", MemoryEffects::Free::get()) 692 .Case("read", MemoryEffects::Read::get()) 693 .Case("write", MemoryEffects::Write::get()); 694 695 // Check for a result to affect. 696 Value value; 697 if (effectElement.get("on_result")) 698 value = getResult(); 699 700 // Check for a non-default resource to use. 701 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 702 if (effectElement.get("test_resource")) 703 resource = TestResource::get(); 704 705 effects.emplace_back(effect, value, resource); 706 } 707 } 708 709 //===----------------------------------------------------------------------===// 710 // StringAttrPrettyNameOp 711 //===----------------------------------------------------------------------===// 712 713 // This op has fancy handling of its SSA result name. 714 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 715 OperationState &result) { 716 // Add the result types. 717 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 718 result.addTypes(parser.getBuilder().getIntegerType(32)); 719 720 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 721 return failure(); 722 723 // If the attribute dictionary contains no 'names' attribute, infer it from 724 // the SSA name (if specified). 725 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 726 return attr.first == "names"; 727 }); 728 729 // If there was no name specified, check to see if there was a useful name 730 // specified in the asm file. 731 if (hadNames || parser.getNumResults() == 0) 732 return success(); 733 734 SmallVector<StringRef, 4> names; 735 auto *context = result.getContext(); 736 737 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 738 auto resultName = parser.getResultName(i); 739 StringRef nameStr; 740 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 741 nameStr = resultName.first; 742 743 names.push_back(nameStr); 744 } 745 746 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 747 result.attributes.push_back({Identifier::get("names", context), namesAttr}); 748 return success(); 749 } 750 751 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 752 p << "test.string_attr_pretty_name"; 753 754 // Note that we only need to print the "name" attribute if the asmprinter 755 // result name disagrees with it. This can happen in strange cases, e.g. 756 // when there are conflicts. 757 bool namesDisagree = op.names().size() != op.getNumResults(); 758 759 SmallString<32> resultNameStr; 760 for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 761 resultNameStr.clear(); 762 llvm::raw_svector_ostream tmpStream(resultNameStr); 763 p.printOperand(op.getResult(i), tmpStream); 764 765 auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 766 if (!expectedName || 767 tmpStream.str().drop_front() != expectedName.getValue()) { 768 namesDisagree = true; 769 } 770 } 771 772 if (namesDisagree) 773 p.printOptionalAttrDictWithKeyword(op.getAttrs()); 774 else 775 p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"}); 776 } 777 778 // We set the SSA name in the asm syntax to the contents of the name 779 // attribute. 780 void StringAttrPrettyNameOp::getAsmResultNames( 781 function_ref<void(Value, StringRef)> setNameFn) { 782 783 auto value = names(); 784 for (size_t i = 0, e = value.size(); i != e; ++i) 785 if (auto str = value[i].dyn_cast<StringAttr>()) 786 if (!str.getValue().empty()) 787 setNameFn(getResult(i), str.getValue()); 788 } 789 790 //===----------------------------------------------------------------------===// 791 // RegionIfOp 792 //===----------------------------------------------------------------------===// 793 794 static void print(OpAsmPrinter &p, RegionIfOp op) { 795 p << RegionIfOp::getOperationName() << " "; 796 p.printOperands(op.getOperands()); 797 p << ": " << op.getOperandTypes(); 798 p.printArrowTypeList(op.getResultTypes()); 799 p << " then"; 800 p.printRegion(op.thenRegion(), 801 /*printEntryBlockArgs=*/true, 802 /*printBlockTerminators=*/true); 803 p << " else"; 804 p.printRegion(op.elseRegion(), 805 /*printEntryBlockArgs=*/true, 806 /*printBlockTerminators=*/true); 807 p << " join"; 808 p.printRegion(op.joinRegion(), 809 /*printEntryBlockArgs=*/true, 810 /*printBlockTerminators=*/true); 811 } 812 813 static ParseResult parseRegionIfOp(OpAsmParser &parser, 814 OperationState &result) { 815 SmallVector<OpAsmParser::OperandType, 2> operandInfos; 816 SmallVector<Type, 2> operandTypes; 817 818 result.regions.reserve(3); 819 Region *thenRegion = result.addRegion(); 820 Region *elseRegion = result.addRegion(); 821 Region *joinRegion = result.addRegion(); 822 823 // Parse operand, type and arrow type lists. 824 if (parser.parseOperandList(operandInfos) || 825 parser.parseColonTypeList(operandTypes) || 826 parser.parseArrowTypeList(result.types)) 827 return failure(); 828 829 // Parse all attached regions. 830 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 831 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 832 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 833 return failure(); 834 835 return parser.resolveOperands(operandInfos, operandTypes, 836 parser.getCurrentLocation(), result.operands); 837 } 838 839 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 840 assert(index < 2 && "invalid region index"); 841 return getOperands(); 842 } 843 844 void RegionIfOp::getSuccessorRegions( 845 Optional<unsigned> index, ArrayRef<Attribute> operands, 846 SmallVectorImpl<RegionSuccessor> ®ions) { 847 // We always branch to the join region. 848 if (index.hasValue()) { 849 if (index.getValue() < 2) 850 regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 851 else 852 regions.push_back(RegionSuccessor(getResults())); 853 return; 854 } 855 856 // The then and else regions are the entry regions of this op. 857 regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 858 regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 859 } 860 861 #include "TestOpEnums.cpp.inc" 862 #include "TestOpStructs.cpp.inc" 863 #include "TestTypeInterfaces.cpp.inc" 864 865 #define GET_OP_CLASSES 866 #include "TestOps.cpp.inc" 867