1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestInterfaces.h" 12 #include "TestTypes.h" 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/DLTI/DLTI.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/IR/BuiltinAttributes.h" 18 #include "mlir/IR/BuiltinOps.h" 19 #include "mlir/IR/Diagnostics.h" 20 #include "mlir/IR/DialectImplementation.h" 21 #include "mlir/IR/ExtensibleDialect.h" 22 #include "mlir/IR/MLIRContext.h" 23 #include "mlir/IR/OperationSupport.h" 24 #include "mlir/IR/PatternMatch.h" 25 #include "mlir/IR/TypeUtilities.h" 26 #include "mlir/IR/Verifier.h" 27 #include "mlir/Interfaces/InferIntRangeInterface.h" 28 #include "mlir/Reducer/ReductionPatternInterface.h" 29 #include "mlir/Transforms/FoldUtils.h" 30 #include "mlir/Transforms/InliningUtils.h" 31 #include "llvm/ADT/SmallString.h" 32 #include "llvm/ADT/StringExtras.h" 33 #include "llvm/ADT/StringSwitch.h" 34 35 // Include this before the using namespace lines below to 36 // test that we don't have namespace dependencies. 37 #include "TestOpsDialect.cpp.inc" 38 39 using namespace mlir; 40 using namespace test; 41 42 void test::registerTestDialect(DialectRegistry ®istry) { 43 registry.insert<TestDialect>(); 44 } 45 46 //===----------------------------------------------------------------------===// 47 // TestDialect Interfaces 48 //===----------------------------------------------------------------------===// 49 50 namespace { 51 52 /// Testing the correctness of some traits. 53 static_assert( 54 llvm::is_detected<OpTrait::has_implicit_terminator_t, 55 SingleBlockImplicitTerminatorOp>::value, 56 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); 57 static_assert(OpTrait::hasSingleBlockImplicitTerminator< 58 SingleBlockImplicitTerminatorOp>::value, 59 "hasSingleBlockImplicitTerminator does not match " 60 "SingleBlockImplicitTerminatorOp"); 61 62 // Test support for interacting with the AsmPrinter. 63 struct TestOpAsmInterface : public OpAsmDialectInterface { 64 using OpAsmDialectInterface::OpAsmDialectInterface; 65 66 AliasResult getAlias(Attribute attr, raw_ostream &os) const final { 67 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 68 if (!strAttr) 69 return AliasResult::NoAlias; 70 71 // Check the contents of the string attribute to see what the test alias 72 // should be named. 73 Optional<StringRef> aliasName = 74 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 75 .Case("alias_test:dot_in_name", StringRef("test.alias")) 76 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 77 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 78 .Case("alias_test:sanitize_conflict_a", 79 StringRef("test_alias_conflict0")) 80 .Case("alias_test:sanitize_conflict_b", 81 StringRef("test_alias_conflict0_")) 82 .Case("alias_test:tensor_encoding", StringRef("test_encoding")) 83 .Default(llvm::None); 84 if (!aliasName) 85 return AliasResult::NoAlias; 86 87 os << *aliasName; 88 return AliasResult::FinalAlias; 89 } 90 91 AliasResult getAlias(Type type, raw_ostream &os) const final { 92 if (auto tupleType = type.dyn_cast<TupleType>()) { 93 if (tupleType.size() > 0 && 94 llvm::all_of(tupleType.getTypes(), [](Type elemType) { 95 return elemType.isa<SimpleAType>(); 96 })) { 97 os << "test_tuple"; 98 return AliasResult::FinalAlias; 99 } 100 } 101 if (auto intType = type.dyn_cast<TestIntegerType>()) { 102 if (intType.getSignedness() == 103 TestIntegerType::SignednessSemantics::Unsigned && 104 intType.getWidth() == 8) { 105 os << "test_ui8"; 106 return AliasResult::FinalAlias; 107 } 108 } 109 return AliasResult::NoAlias; 110 } 111 }; 112 113 struct TestDialectFoldInterface : public DialectFoldInterface { 114 using DialectFoldInterface::DialectFoldInterface; 115 116 /// Registered hook to check if the given region, which is attached to an 117 /// operation that is *not* isolated from above, should be used when 118 /// materializing constants. 119 bool shouldMaterializeInto(Region *region) const final { 120 // If this is a one region operation, then insert into it. 121 return isa<OneRegionOp>(region->getParentOp()); 122 } 123 }; 124 125 /// This class defines the interface for handling inlining with standard 126 /// operations. 127 struct TestInlinerInterface : public DialectInlinerInterface { 128 using DialectInlinerInterface::DialectInlinerInterface; 129 130 //===--------------------------------------------------------------------===// 131 // Analysis Hooks 132 //===--------------------------------------------------------------------===// 133 134 bool isLegalToInline(Operation *call, Operation *callable, 135 bool wouldBeCloned) const final { 136 // Don't allow inlining calls that are marked `noinline`. 137 return !call->hasAttr("noinline"); 138 } 139 bool isLegalToInline(Region *, Region *, bool, 140 BlockAndValueMapping &) const final { 141 // Inlining into test dialect regions is legal. 142 return true; 143 } 144 bool isLegalToInline(Operation *, Region *, bool, 145 BlockAndValueMapping &) const final { 146 return true; 147 } 148 149 bool shouldAnalyzeRecursively(Operation *op) const final { 150 // Analyze recursively if this is not a functional region operation, it 151 // froms a separate functional scope. 152 return !isa<FunctionalRegionOp>(op); 153 } 154 155 //===--------------------------------------------------------------------===// 156 // Transformation Hooks 157 //===--------------------------------------------------------------------===// 158 159 /// Handle the given inlined terminator by replacing it with a new operation 160 /// as necessary. 161 void handleTerminator(Operation *op, 162 ArrayRef<Value> valuesToRepl) const final { 163 // Only handle "test.return" here. 164 auto returnOp = dyn_cast<TestReturnOp>(op); 165 if (!returnOp) 166 return; 167 168 // Replace the values directly with the return operands. 169 assert(returnOp.getNumOperands() == valuesToRepl.size()); 170 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 171 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 172 } 173 174 /// Attempt to materialize a conversion for a type mismatch between a call 175 /// from this dialect, and a callable region. This method should generate an 176 /// operation that takes 'input' as the only operand, and produces a single 177 /// result of 'resultType'. If a conversion can not be generated, nullptr 178 /// should be returned. 179 Operation *materializeCallConversion(OpBuilder &builder, Value input, 180 Type resultType, 181 Location conversionLoc) const final { 182 // Only allow conversion for i16/i32 types. 183 if (!(resultType.isSignlessInteger(16) || 184 resultType.isSignlessInteger(32)) || 185 !(input.getType().isSignlessInteger(16) || 186 input.getType().isSignlessInteger(32))) 187 return nullptr; 188 return builder.create<TestCastOp>(conversionLoc, resultType, input); 189 } 190 191 void processInlinedCallBlocks( 192 Operation *call, 193 iterator_range<Region::iterator> inlinedBlocks) const final { 194 if (!isa<ConversionCallOp>(call)) 195 return; 196 197 // Set attributed on all ops in the inlined blocks. 198 for (Block &block : inlinedBlocks) { 199 block.walk([&](Operation *op) { 200 op->setAttr("inlined_conversion", UnitAttr::get(call->getContext())); 201 }); 202 } 203 } 204 }; 205 206 struct TestReductionPatternInterface : public DialectReductionPatternInterface { 207 public: 208 TestReductionPatternInterface(Dialect *dialect) 209 : DialectReductionPatternInterface(dialect) {} 210 211 void populateReductionPatterns(RewritePatternSet &patterns) const final { 212 populateTestReductionPatterns(patterns); 213 } 214 }; 215 216 } // namespace 217 218 //===----------------------------------------------------------------------===// 219 // Dynamic operations 220 //===----------------------------------------------------------------------===// 221 222 std::unique_ptr<DynamicOpDefinition> getDynamicGenericOp(TestDialect *dialect) { 223 return DynamicOpDefinition::get( 224 "dynamic_generic", dialect, [](Operation *op) { return success(); }, 225 [](Operation *op) { return success(); }); 226 } 227 228 std::unique_ptr<DynamicOpDefinition> 229 getDynamicOneOperandTwoResultsOp(TestDialect *dialect) { 230 return DynamicOpDefinition::get( 231 "dynamic_one_operand_two_results", dialect, 232 [](Operation *op) { 233 if (op->getNumOperands() != 1) { 234 op->emitOpError() 235 << "expected 1 operand, but had " << op->getNumOperands(); 236 return failure(); 237 } 238 if (op->getNumResults() != 2) { 239 op->emitOpError() 240 << "expected 2 results, but had " << op->getNumResults(); 241 return failure(); 242 } 243 return success(); 244 }, 245 [](Operation *op) { return success(); }); 246 } 247 248 std::unique_ptr<DynamicOpDefinition> 249 getDynamicCustomParserPrinterOp(TestDialect *dialect) { 250 auto verifier = [](Operation *op) { 251 if (op->getNumOperands() == 0 && op->getNumResults() == 0) 252 return success(); 253 op->emitError() << "operation should have no operands and no results"; 254 return failure(); 255 }; 256 auto regionVerifier = [](Operation *op) { return success(); }; 257 258 auto parser = [](OpAsmParser &parser, OperationState &state) { 259 return parser.parseKeyword("custom_keyword"); 260 }; 261 262 auto printer = [](Operation *op, OpAsmPrinter &printer, llvm::StringRef) { 263 printer << op->getName() << " custom_keyword"; 264 }; 265 266 return DynamicOpDefinition::get("dynamic_custom_parser_printer", dialect, 267 verifier, regionVerifier, parser, printer); 268 } 269 270 //===----------------------------------------------------------------------===// 271 // TestDialect 272 //===----------------------------------------------------------------------===// 273 274 static void testSideEffectOpGetEffect( 275 Operation *op, 276 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects); 277 278 // This is the implementation of a dialect fallback for `TestEffectOpInterface`. 279 struct TestOpEffectInterfaceFallback 280 : public TestEffectOpInterface::FallbackModel< 281 TestOpEffectInterfaceFallback> { 282 static bool classof(Operation *op) { 283 bool isSupportedOp = 284 op->getName().getStringRef() == "test.unregistered_side_effect_op"; 285 assert(isSupportedOp && "Unexpected dispatch"); 286 return isSupportedOp; 287 } 288 289 void 290 getEffects(Operation *op, 291 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 292 &effects) const { 293 testSideEffectOpGetEffect(op, effects); 294 } 295 }; 296 297 void TestDialect::initialize() { 298 registerAttributes(); 299 registerTypes(); 300 addOperations< 301 #define GET_OP_LIST 302 #include "TestOps.cpp.inc" 303 >(); 304 registerDynamicOp(getDynamicGenericOp(this)); 305 registerDynamicOp(getDynamicOneOperandTwoResultsOp(this)); 306 registerDynamicOp(getDynamicCustomParserPrinterOp(this)); 307 308 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 309 TestInlinerInterface, TestReductionPatternInterface>(); 310 allowUnknownOperations(); 311 312 // Instantiate our fallback op interface that we'll use on specific 313 // unregistered op. 314 fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback; 315 } 316 TestDialect::~TestDialect() { 317 delete static_cast<TestOpEffectInterfaceFallback *>( 318 fallbackEffectOpInterfaces); 319 } 320 321 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 322 Type type, Location loc) { 323 return builder.create<TestOpConstant>(loc, type, value); 324 } 325 326 ::mlir::LogicalResult FormatInferType2Op::inferReturnTypes( 327 ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, 328 ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, 329 ::mlir::RegionRange regions, 330 ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { 331 inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)}); 332 return ::mlir::success(); 333 } 334 335 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, 336 OperationName opName) { 337 if (opName.getIdentifier() == "test.unregistered_side_effect_op" && 338 typeID == TypeID::get<TestEffectOpInterface>()) 339 return fallbackEffectOpInterfaces; 340 return nullptr; 341 } 342 343 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 344 NamedAttribute namedAttr) { 345 if (namedAttr.getName() == "test.invalid_attr") 346 return op->emitError() << "invalid to use 'test.invalid_attr'"; 347 return success(); 348 } 349 350 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 351 unsigned regionIndex, 352 unsigned argIndex, 353 NamedAttribute namedAttr) { 354 if (namedAttr.getName() == "test.invalid_attr") 355 return op->emitError() << "invalid to use 'test.invalid_attr'"; 356 return success(); 357 } 358 359 LogicalResult 360 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 361 unsigned resultIndex, 362 NamedAttribute namedAttr) { 363 if (namedAttr.getName() == "test.invalid_attr") 364 return op->emitError() << "invalid to use 'test.invalid_attr'"; 365 return success(); 366 } 367 368 Optional<Dialect::ParseOpHook> 369 TestDialect::getParseOperationHook(StringRef opName) const { 370 if (opName == "test.dialect_custom_printer") { 371 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 372 return parser.parseKeyword("custom_format"); 373 }}; 374 } 375 if (opName == "test.dialect_custom_format_fallback") { 376 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 377 return parser.parseKeyword("custom_format_fallback"); 378 }}; 379 } 380 if (opName == "test.dialect_custom_printer.with.dot") { 381 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 382 return ParseResult::success(); 383 }}; 384 } 385 return None; 386 } 387 388 llvm::unique_function<void(Operation *, OpAsmPrinter &)> 389 TestDialect::getOperationPrinter(Operation *op) const { 390 StringRef opName = op->getName().getStringRef(); 391 if (opName == "test.dialect_custom_printer") { 392 return [](Operation *op, OpAsmPrinter &printer) { 393 printer.getStream() << " custom_format"; 394 }; 395 } 396 if (opName == "test.dialect_custom_format_fallback") { 397 return [](Operation *op, OpAsmPrinter &printer) { 398 printer.getStream() << " custom_format_fallback"; 399 }; 400 } 401 return {}; 402 } 403 404 //===----------------------------------------------------------------------===// 405 // TestBranchOp 406 //===----------------------------------------------------------------------===// 407 408 SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) { 409 assert(index == 0 && "invalid successor index"); 410 return SuccessorOperands(getTargetOperandsMutable()); 411 } 412 413 //===----------------------------------------------------------------------===// 414 // TestProducingBranchOp 415 //===----------------------------------------------------------------------===// 416 417 SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) { 418 assert(index <= 1 && "invalid successor index"); 419 if (index == 1) 420 return SuccessorOperands(getFirstOperandsMutable()); 421 return SuccessorOperands(getSecondOperandsMutable()); 422 } 423 424 //===----------------------------------------------------------------------===// 425 // TestProducingBranchOp 426 //===----------------------------------------------------------------------===// 427 428 SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) { 429 assert(index <= 1 && "invalid successor index"); 430 if (index == 0) 431 return SuccessorOperands(0, getSuccessOperandsMutable()); 432 return SuccessorOperands(1, getErrorOperandsMutable()); 433 } 434 435 //===----------------------------------------------------------------------===// 436 // TestDialectCanonicalizerOp 437 //===----------------------------------------------------------------------===// 438 439 static LogicalResult 440 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, 441 PatternRewriter &rewriter) { 442 rewriter.replaceOpWithNewOp<arith::ConstantOp>( 443 op, rewriter.getI32IntegerAttr(42)); 444 return success(); 445 } 446 447 void TestDialect::getCanonicalizationPatterns( 448 RewritePatternSet &results) const { 449 results.add(&dialectCanonicalizationPattern); 450 } 451 452 //===----------------------------------------------------------------------===// 453 // TestCallOp 454 //===----------------------------------------------------------------------===// 455 456 LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 457 // Check that the callee attribute was specified. 458 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee"); 459 if (!fnAttr) 460 return emitOpError("requires a 'callee' symbol reference attribute"); 461 if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr)) 462 return emitOpError() << "'" << fnAttr.getValue() 463 << "' does not reference a valid function"; 464 return success(); 465 } 466 467 //===----------------------------------------------------------------------===// 468 // TestFoldToCallOp 469 //===----------------------------------------------------------------------===// 470 471 namespace { 472 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 473 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 474 475 LogicalResult matchAndRewrite(FoldToCallOp op, 476 PatternRewriter &rewriter) const override { 477 rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), 478 op.getCalleeAttr(), ValueRange()); 479 return success(); 480 } 481 }; 482 } // namespace 483 484 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 485 MLIRContext *context) { 486 results.add<FoldToCallOpPattern>(context); 487 } 488 489 //===----------------------------------------------------------------------===// 490 // Test Format* operations 491 //===----------------------------------------------------------------------===// 492 493 //===----------------------------------------------------------------------===// 494 // Parsing 495 496 static ParseResult parseCustomOptionalOperand( 497 OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) { 498 if (succeeded(parser.parseOptionalLParen())) { 499 optOperand.emplace(); 500 if (parser.parseOperand(*optOperand) || parser.parseRParen()) 501 return failure(); 502 } 503 return success(); 504 } 505 506 static ParseResult parseCustomDirectiveOperands( 507 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, 508 Optional<OpAsmParser::UnresolvedOperand> &optOperand, 509 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) { 510 if (parser.parseOperand(operand)) 511 return failure(); 512 if (succeeded(parser.parseOptionalComma())) { 513 optOperand.emplace(); 514 if (parser.parseOperand(*optOperand)) 515 return failure(); 516 } 517 if (parser.parseArrow() || parser.parseLParen() || 518 parser.parseOperandList(varOperands) || parser.parseRParen()) 519 return failure(); 520 return success(); 521 } 522 static ParseResult 523 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 524 Type &optOperandType, 525 SmallVectorImpl<Type> &varOperandTypes) { 526 if (parser.parseColon()) 527 return failure(); 528 529 if (parser.parseType(operandType)) 530 return failure(); 531 if (succeeded(parser.parseOptionalComma())) { 532 if (parser.parseType(optOperandType)) 533 return failure(); 534 } 535 if (parser.parseArrow() || parser.parseLParen() || 536 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 537 return failure(); 538 return success(); 539 } 540 static ParseResult 541 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 542 Type optOperandType, 543 const SmallVectorImpl<Type> &varOperandTypes) { 544 if (parser.parseKeyword("type_refs_capture")) 545 return failure(); 546 547 Type operandType2, optOperandType2; 548 SmallVector<Type, 1> varOperandTypes2; 549 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 550 varOperandTypes2)) 551 return failure(); 552 553 if (operandType != operandType2 || optOperandType != optOperandType2 || 554 varOperandTypes != varOperandTypes2) 555 return failure(); 556 557 return success(); 558 } 559 static ParseResult parseCustomDirectiveOperandsAndTypes( 560 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, 561 Optional<OpAsmParser::UnresolvedOperand> &optOperand, 562 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands, 563 Type &operandType, Type &optOperandType, 564 SmallVectorImpl<Type> &varOperandTypes) { 565 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 566 parseCustomDirectiveResults(parser, operandType, optOperandType, 567 varOperandTypes)) 568 return failure(); 569 return success(); 570 } 571 static ParseResult parseCustomDirectiveRegions( 572 OpAsmParser &parser, Region ®ion, 573 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 574 if (parser.parseRegion(region)) 575 return failure(); 576 if (failed(parser.parseOptionalComma())) 577 return success(); 578 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 579 if (parser.parseRegion(*varRegion)) 580 return failure(); 581 varRegions.emplace_back(std::move(varRegion)); 582 return success(); 583 } 584 static ParseResult 585 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 586 SmallVectorImpl<Block *> &varSuccessors) { 587 if (parser.parseSuccessor(successor)) 588 return failure(); 589 if (failed(parser.parseOptionalComma())) 590 return success(); 591 Block *varSuccessor; 592 if (parser.parseSuccessor(varSuccessor)) 593 return failure(); 594 varSuccessors.append(2, varSuccessor); 595 return success(); 596 } 597 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 598 IntegerAttr &attr, 599 IntegerAttr &optAttr) { 600 if (parser.parseAttribute(attr)) 601 return failure(); 602 if (succeeded(parser.parseOptionalComma())) { 603 if (parser.parseAttribute(optAttr)) 604 return failure(); 605 } 606 return success(); 607 } 608 609 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 610 NamedAttrList &attrs) { 611 return parser.parseOptionalAttrDict(attrs); 612 } 613 static ParseResult parseCustomDirectiveOptionalOperandRef( 614 OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) { 615 int64_t operandCount = 0; 616 if (parser.parseInteger(operandCount)) 617 return failure(); 618 bool expectedOptionalOperand = operandCount == 0; 619 return success(expectedOptionalOperand != optOperand.has_value()); 620 } 621 622 //===----------------------------------------------------------------------===// 623 // Printing 624 625 static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *, 626 Value optOperand) { 627 if (optOperand) 628 printer << "(" << optOperand << ") "; 629 } 630 631 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 632 Value operand, Value optOperand, 633 OperandRange varOperands) { 634 printer << operand; 635 if (optOperand) 636 printer << ", " << optOperand; 637 printer << " -> (" << varOperands << ")"; 638 } 639 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 640 Type operandType, Type optOperandType, 641 TypeRange varOperandTypes) { 642 printer << " : " << operandType; 643 if (optOperandType) 644 printer << ", " << optOperandType; 645 printer << " -> (" << varOperandTypes << ")"; 646 } 647 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 648 Operation *op, Type operandType, 649 Type optOperandType, 650 TypeRange varOperandTypes) { 651 printer << " type_refs_capture "; 652 printCustomDirectiveResults(printer, op, operandType, optOperandType, 653 varOperandTypes); 654 } 655 static void printCustomDirectiveOperandsAndTypes( 656 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 657 OperandRange varOperands, Type operandType, Type optOperandType, 658 TypeRange varOperandTypes) { 659 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 660 printCustomDirectiveResults(printer, op, operandType, optOperandType, 661 varOperandTypes); 662 } 663 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 664 Region ®ion, 665 MutableArrayRef<Region> varRegions) { 666 printer.printRegion(region); 667 if (!varRegions.empty()) { 668 printer << ", "; 669 for (Region ®ion : varRegions) 670 printer.printRegion(region); 671 } 672 } 673 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 674 Block *successor, 675 SuccessorRange varSuccessors) { 676 printer << successor; 677 if (!varSuccessors.empty()) 678 printer << ", " << varSuccessors.front(); 679 } 680 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 681 Attribute attribute, 682 Attribute optAttribute) { 683 printer << attribute; 684 if (optAttribute) 685 printer << ", " << optAttribute; 686 } 687 688 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 689 DictionaryAttr attrs) { 690 printer.printOptionalAttrDict(attrs.getValue()); 691 } 692 693 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 694 Operation *op, 695 Value optOperand) { 696 printer << (optOperand ? "1" : "0"); 697 } 698 699 //===----------------------------------------------------------------------===// 700 // Test IsolatedRegionOp - parse passthrough region arguments. 701 //===----------------------------------------------------------------------===// 702 703 ParseResult IsolatedRegionOp::parse(OpAsmParser &parser, 704 OperationState &result) { 705 // Parse the input operand. 706 OpAsmParser::Argument argInfo; 707 argInfo.type = parser.getBuilder().getIndexType(); 708 if (parser.parseOperand(argInfo.ssaName) || 709 parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands)) 710 return failure(); 711 712 // Parse the body region, and reuse the operand info as the argument info. 713 Region *body = result.addRegion(); 714 return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true); 715 } 716 717 void IsolatedRegionOp::print(OpAsmPrinter &p) { 718 p << "test.isolated_region "; 719 p.printOperand(getOperand()); 720 p.shadowRegionArgs(getRegion(), getOperand()); 721 p << ' '; 722 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 723 } 724 725 //===----------------------------------------------------------------------===// 726 // Test SSACFGRegionOp 727 //===----------------------------------------------------------------------===// 728 729 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 730 return RegionKind::SSACFG; 731 } 732 733 //===----------------------------------------------------------------------===// 734 // Test GraphRegionOp 735 //===----------------------------------------------------------------------===// 736 737 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 738 return RegionKind::Graph; 739 } 740 741 //===----------------------------------------------------------------------===// 742 // Test AffineScopeOp 743 //===----------------------------------------------------------------------===// 744 745 ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) { 746 // Parse the body region, and reuse the operand info as the argument info. 747 Region *body = result.addRegion(); 748 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 749 } 750 751 void AffineScopeOp::print(OpAsmPrinter &p) { 752 p << "test.affine_scope "; 753 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 754 } 755 756 //===----------------------------------------------------------------------===// 757 // Test parser. 758 //===----------------------------------------------------------------------===// 759 760 ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser, 761 OperationState &result) { 762 if (parser.parseOptionalColon()) 763 return success(); 764 uint64_t numResults; 765 if (parser.parseInteger(numResults)) 766 return failure(); 767 768 IndexType type = parser.getBuilder().getIndexType(); 769 for (unsigned i = 0; i < numResults; ++i) 770 result.addTypes(type); 771 return success(); 772 } 773 774 void ParseIntegerLiteralOp::print(OpAsmPrinter &p) { 775 if (unsigned numResults = getNumResults()) 776 p << " : " << numResults; 777 } 778 779 ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser, 780 OperationState &result) { 781 StringRef keyword; 782 if (parser.parseKeyword(&keyword)) 783 return failure(); 784 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 785 return success(); 786 } 787 788 void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); } 789 790 //===----------------------------------------------------------------------===// 791 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 792 793 ParseResult WrappingRegionOp::parse(OpAsmParser &parser, 794 OperationState &result) { 795 if (parser.parseKeyword("wraps")) 796 return failure(); 797 798 // Parse the wrapped op in a region 799 Region &body = *result.addRegion(); 800 body.push_back(new Block); 801 Block &block = body.back(); 802 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin()); 803 if (!wrappedOp) 804 return failure(); 805 806 // Create a return terminator in the inner region, pass as operand to the 807 // terminator the returned values from the wrapped operation. 808 SmallVector<Value, 8> returnOperands(wrappedOp->getResults()); 809 OpBuilder builder(parser.getContext()); 810 builder.setInsertionPointToEnd(&block); 811 builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands); 812 813 // Get the results type for the wrapping op from the terminator operands. 814 Operation &returnOp = body.back().back(); 815 result.types.append(returnOp.operand_type_begin(), 816 returnOp.operand_type_end()); 817 818 // Use the location of the wrapped op for the "test.wrapping_region" op. 819 result.location = wrappedOp->getLoc(); 820 821 return success(); 822 } 823 824 void WrappingRegionOp::print(OpAsmPrinter &p) { 825 p << " wraps "; 826 p.printGenericOp(&getRegion().front().front()); 827 } 828 829 //===----------------------------------------------------------------------===// 830 // Test PrettyPrintedRegionOp - exercising the following parser APIs 831 // parseGenericOperationAfterOpName 832 // parseCustomOperationName 833 //===----------------------------------------------------------------------===// 834 835 ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser, 836 OperationState &result) { 837 838 SMLoc loc = parser.getCurrentLocation(); 839 Location currLocation = parser.getEncodedSourceLoc(loc); 840 841 // Parse the operands. 842 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands; 843 if (parser.parseOperandList(operands)) 844 return failure(); 845 846 // Check if we are parsing the pretty-printed version 847 // test.pretty_printed_region start <inner-op> end : <functional-type> 848 // Else fallback to parsing the "non pretty-printed" version. 849 if (!succeeded(parser.parseOptionalKeyword("start"))) 850 return parser.parseGenericOperationAfterOpName( 851 result, llvm::makeArrayRef(operands)); 852 853 FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName(); 854 if (failed(parseOpNameInfo)) 855 return failure(); 856 857 StringAttr innerOpName = parseOpNameInfo->getIdentifier(); 858 859 FunctionType opFntype; 860 Optional<Location> explicitLoc; 861 if (parser.parseKeyword("end") || parser.parseColon() || 862 parser.parseType(opFntype) || 863 parser.parseOptionalLocationSpecifier(explicitLoc)) 864 return failure(); 865 866 // If location of the op is explicitly provided, then use it; Else use 867 // the parser's current location. 868 Location opLoc = explicitLoc.value_or(currLocation); 869 870 // Derive the SSA-values for op's operands. 871 if (parser.resolveOperands(operands, opFntype.getInputs(), loc, 872 result.operands)) 873 return failure(); 874 875 // Add a region for op. 876 Region ®ion = *result.addRegion(); 877 878 // Create a basic-block inside op's region. 879 Block &block = region.emplaceBlock(); 880 881 // Create and insert an "inner-op" operation in the block. 882 // Just for testing purposes, we can assume that inner op is a binary op with 883 // result and operand types all same as the test-op's first operand. 884 Type innerOpType = opFntype.getInput(0); 885 Value lhs = block.addArgument(innerOpType, opLoc); 886 Value rhs = block.addArgument(innerOpType, opLoc); 887 888 OpBuilder builder(parser.getBuilder().getContext()); 889 builder.setInsertionPointToStart(&block); 890 891 Operation *innerOp = 892 builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType); 893 894 // Insert a return statement in the block returning the inner-op's result. 895 builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults()); 896 897 // Populate the op operation-state with result-type and location. 898 result.addTypes(opFntype.getResults()); 899 result.location = innerOp->getLoc(); 900 901 return success(); 902 } 903 904 void PrettyPrintedRegionOp::print(OpAsmPrinter &p) { 905 p << ' '; 906 p.printOperands(getOperands()); 907 908 Operation &innerOp = getRegion().front().front(); 909 // Assuming that region has a single non-terminator inner-op, if the inner-op 910 // meets some criteria (which in this case is a simple one based on the name 911 // of inner-op), then we can print the entire region in a succinct way. 912 // Here we assume that the prototype of "special.op" can be trivially derived 913 // while parsing it back. 914 if (innerOp.getName().getStringRef().equals("special.op")) { 915 p << " start special.op end"; 916 } else { 917 p << " ("; 918 p.printRegion(getRegion()); 919 p << ")"; 920 } 921 922 p << " : "; 923 p.printFunctionalType(*this); 924 } 925 926 //===----------------------------------------------------------------------===// 927 // Test PolyForOp - parse list of region arguments. 928 //===----------------------------------------------------------------------===// 929 930 ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) { 931 SmallVector<OpAsmParser::Argument, 4> ivsInfo; 932 // Parse list of region arguments without a delimiter. 933 if (parser.parseArgumentList(ivsInfo, OpAsmParser::Delimiter::None)) 934 return failure(); 935 936 // Parse the body region. 937 Region *body = result.addRegion(); 938 for (auto &iv : ivsInfo) 939 iv.type = parser.getBuilder().getIndexType(); 940 return parser.parseRegion(*body, ivsInfo); 941 } 942 943 void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); } 944 945 void PolyForOp::getAsmBlockArgumentNames(Region ®ion, 946 OpAsmSetValueNameFn setNameFn) { 947 auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names"); 948 if (!arrayAttr) 949 return; 950 auto args = getRegion().front().getArguments(); 951 auto e = std::min(arrayAttr.size(), args.size()); 952 for (unsigned i = 0; i < e; ++i) { 953 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 954 setNameFn(args[i], strAttr.getValue()); 955 } 956 } 957 958 //===----------------------------------------------------------------------===// 959 // Test removing op with inner ops. 960 //===----------------------------------------------------------------------===// 961 962 namespace { 963 struct TestRemoveOpWithInnerOps 964 : public OpRewritePattern<TestOpWithRegionPattern> { 965 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 966 967 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } 968 969 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 970 PatternRewriter &rewriter) const override { 971 rewriter.eraseOp(op); 972 return success(); 973 } 974 }; 975 } // namespace 976 977 void TestOpWithRegionPattern::getCanonicalizationPatterns( 978 RewritePatternSet &results, MLIRContext *context) { 979 results.add<TestRemoveOpWithInnerOps>(context); 980 } 981 982 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 983 return getOperand(); 984 } 985 986 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 987 return getValue(); 988 } 989 990 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 991 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 992 for (Value input : this->getOperands()) { 993 results.push_back(input); 994 } 995 return success(); 996 } 997 998 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 999 assert(operands.size() == 1); 1000 if (operands.front()) { 1001 (*this)->setAttr("attr", operands.front()); 1002 return getResult(); 1003 } 1004 return {}; 1005 } 1006 1007 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 1008 return getOperand(); 1009 } 1010 1011 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 1012 MLIRContext *, Optional<Location> location, ValueRange operands, 1013 DictionaryAttr attributes, RegionRange regions, 1014 SmallVectorImpl<Type> &inferredReturnTypes) { 1015 if (operands[0].getType() != operands[1].getType()) { 1016 return emitOptionalError(location, "operand type mismatch ", 1017 operands[0].getType(), " vs ", 1018 operands[1].getType()); 1019 } 1020 inferredReturnTypes.assign({operands[0].getType()}); 1021 return success(); 1022 } 1023 1024 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 1025 MLIRContext *context, Optional<Location> location, ValueShapeRange operands, 1026 DictionaryAttr attributes, RegionRange regions, 1027 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1028 // Create return type consisting of the last element of the first operand. 1029 auto operandType = operands.front().getType(); 1030 auto sval = operandType.dyn_cast<ShapedType>(); 1031 if (!sval) { 1032 return emitOptionalError(location, "only shaped type operands allowed"); 1033 } 1034 int64_t dim = 1035 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 1036 auto type = IntegerType::get(context, 17); 1037 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 1038 return success(); 1039 } 1040 1041 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 1042 OpBuilder &builder, ValueRange operands, 1043 llvm::SmallVectorImpl<Value> &shapes) { 1044 shapes = SmallVector<Value, 1>{ 1045 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; 1046 return success(); 1047 } 1048 1049 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( 1050 OpBuilder &builder, ValueRange operands, 1051 llvm::SmallVectorImpl<Value> &shapes) { 1052 Location loc = getLoc(); 1053 shapes.reserve(operands.size()); 1054 for (Value operand : llvm::reverse(operands)) { 1055 auto rank = operand.getType().cast<RankedTensorType>().getRank(); 1056 auto currShape = llvm::to_vector<4>( 1057 llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value { 1058 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 1059 })); 1060 shapes.push_back(builder.create<tensor::FromElementsOp>( 1061 getLoc(), RankedTensorType::get({rank}, builder.getIndexType()), 1062 currShape)); 1063 } 1064 return success(); 1065 } 1066 1067 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( 1068 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { 1069 Location loc = getLoc(); 1070 shapes.reserve(getNumOperands()); 1071 for (Value operand : llvm::reverse(getOperands())) { 1072 auto currShape = llvm::to_vector<4>(llvm::map_range( 1073 llvm::seq<int64_t>( 1074 0, operand.getType().cast<RankedTensorType>().getRank()), 1075 [&](int64_t dim) -> Value { 1076 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 1077 })); 1078 shapes.emplace_back(std::move(currShape)); 1079 } 1080 return success(); 1081 } 1082 1083 //===----------------------------------------------------------------------===// 1084 // Test SideEffect interfaces 1085 //===----------------------------------------------------------------------===// 1086 1087 namespace { 1088 /// A test resource for side effects. 1089 struct TestResource : public SideEffects::Resource::Base<TestResource> { 1090 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource) 1091 1092 StringRef getName() final { return "<Test>"; } 1093 }; 1094 } // namespace 1095 1096 static void testSideEffectOpGetEffect( 1097 Operation *op, 1098 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 1099 &effects) { 1100 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); 1101 if (!effectsAttr) 1102 return; 1103 1104 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 1105 } 1106 1107 void SideEffectOp::getEffects( 1108 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1109 // Check for an effects attribute on the op instance. 1110 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 1111 if (!effectsAttr) 1112 return; 1113 1114 // If there is one, it is an array of dictionary attributes that hold 1115 // information on the effects of this operation. 1116 for (Attribute element : effectsAttr) { 1117 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 1118 1119 // Get the specific memory effect. 1120 MemoryEffects::Effect *effect = 1121 StringSwitch<MemoryEffects::Effect *>( 1122 effectElement.get("effect").cast<StringAttr>().getValue()) 1123 .Case("allocate", MemoryEffects::Allocate::get()) 1124 .Case("free", MemoryEffects::Free::get()) 1125 .Case("read", MemoryEffects::Read::get()) 1126 .Case("write", MemoryEffects::Write::get()); 1127 1128 // Check for a non-default resource to use. 1129 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 1130 if (effectElement.get("test_resource")) 1131 resource = TestResource::get(); 1132 1133 // Check for a result to affect. 1134 if (effectElement.get("on_result")) 1135 effects.emplace_back(effect, getResult(), resource); 1136 else if (Attribute ref = effectElement.get("on_reference")) 1137 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 1138 else 1139 effects.emplace_back(effect, resource); 1140 } 1141 } 1142 1143 void SideEffectOp::getEffects( 1144 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 1145 testSideEffectOpGetEffect(getOperation(), effects); 1146 } 1147 1148 //===----------------------------------------------------------------------===// 1149 // StringAttrPrettyNameOp 1150 //===----------------------------------------------------------------------===// 1151 1152 // This op has fancy handling of its SSA result name. 1153 ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser, 1154 OperationState &result) { 1155 // Add the result types. 1156 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 1157 result.addTypes(parser.getBuilder().getIntegerType(32)); 1158 1159 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1160 return failure(); 1161 1162 // If the attribute dictionary contains no 'names' attribute, infer it from 1163 // the SSA name (if specified). 1164 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 1165 return attr.getName() == "names"; 1166 }); 1167 1168 // If there was no name specified, check to see if there was a useful name 1169 // specified in the asm file. 1170 if (hadNames || parser.getNumResults() == 0) 1171 return success(); 1172 1173 SmallVector<StringRef, 4> names; 1174 auto *context = result.getContext(); 1175 1176 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 1177 auto resultName = parser.getResultName(i); 1178 StringRef nameStr; 1179 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 1180 nameStr = resultName.first; 1181 1182 names.push_back(nameStr); 1183 } 1184 1185 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 1186 result.attributes.push_back({StringAttr::get(context, "names"), namesAttr}); 1187 return success(); 1188 } 1189 1190 void StringAttrPrettyNameOp::print(OpAsmPrinter &p) { 1191 // Note that we only need to print the "name" attribute if the asmprinter 1192 // result name disagrees with it. This can happen in strange cases, e.g. 1193 // when there are conflicts. 1194 bool namesDisagree = getNames().size() != getNumResults(); 1195 1196 SmallString<32> resultNameStr; 1197 for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) { 1198 resultNameStr.clear(); 1199 llvm::raw_svector_ostream tmpStream(resultNameStr); 1200 p.printOperand(getResult(i), tmpStream); 1201 1202 auto expectedName = getNames()[i].dyn_cast<StringAttr>(); 1203 if (!expectedName || 1204 tmpStream.str().drop_front() != expectedName.getValue()) { 1205 namesDisagree = true; 1206 } 1207 } 1208 1209 if (namesDisagree) 1210 p.printOptionalAttrDictWithKeyword((*this)->getAttrs()); 1211 else 1212 p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"}); 1213 } 1214 1215 // We set the SSA name in the asm syntax to the contents of the name 1216 // attribute. 1217 void StringAttrPrettyNameOp::getAsmResultNames( 1218 function_ref<void(Value, StringRef)> setNameFn) { 1219 1220 auto value = getNames(); 1221 for (size_t i = 0, e = value.size(); i != e; ++i) 1222 if (auto str = value[i].dyn_cast<StringAttr>()) 1223 if (!str.getValue().empty()) 1224 setNameFn(getResult(i), str.getValue()); 1225 } 1226 1227 void CustomResultsNameOp::getAsmResultNames( 1228 function_ref<void(Value, StringRef)> setNameFn) { 1229 ArrayAttr value = getNames(); 1230 for (size_t i = 0, e = value.size(); i != e; ++i) 1231 if (auto str = value[i].dyn_cast<StringAttr>()) 1232 if (!str.getValue().empty()) 1233 setNameFn(getResult(i), str.getValue()); 1234 } 1235 1236 //===----------------------------------------------------------------------===// 1237 // ResultTypeWithTraitOp 1238 //===----------------------------------------------------------------------===// 1239 1240 LogicalResult ResultTypeWithTraitOp::verify() { 1241 if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>()) 1242 return success(); 1243 return emitError("result type should have trait 'TestTypeTrait'"); 1244 } 1245 1246 //===----------------------------------------------------------------------===// 1247 // AttrWithTraitOp 1248 //===----------------------------------------------------------------------===// 1249 1250 LogicalResult AttrWithTraitOp::verify() { 1251 if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>()) 1252 return success(); 1253 return emitError("'attr' attribute should have trait 'TestAttrTrait'"); 1254 } 1255 1256 //===----------------------------------------------------------------------===// 1257 // RegionIfOp 1258 //===----------------------------------------------------------------------===// 1259 1260 void RegionIfOp::print(OpAsmPrinter &p) { 1261 p << " "; 1262 p.printOperands(getOperands()); 1263 p << ": " << getOperandTypes(); 1264 p.printArrowTypeList(getResultTypes()); 1265 p << " then "; 1266 p.printRegion(getThenRegion(), 1267 /*printEntryBlockArgs=*/true, 1268 /*printBlockTerminators=*/true); 1269 p << " else "; 1270 p.printRegion(getElseRegion(), 1271 /*printEntryBlockArgs=*/true, 1272 /*printBlockTerminators=*/true); 1273 p << " join "; 1274 p.printRegion(getJoinRegion(), 1275 /*printEntryBlockArgs=*/true, 1276 /*printBlockTerminators=*/true); 1277 } 1278 1279 ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) { 1280 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos; 1281 SmallVector<Type, 2> operandTypes; 1282 1283 result.regions.reserve(3); 1284 Region *thenRegion = result.addRegion(); 1285 Region *elseRegion = result.addRegion(); 1286 Region *joinRegion = result.addRegion(); 1287 1288 // Parse operand, type and arrow type lists. 1289 if (parser.parseOperandList(operandInfos) || 1290 parser.parseColonTypeList(operandTypes) || 1291 parser.parseArrowTypeList(result.types)) 1292 return failure(); 1293 1294 // Parse all attached regions. 1295 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 1296 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 1297 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 1298 return failure(); 1299 1300 return parser.resolveOperands(operandInfos, operandTypes, 1301 parser.getCurrentLocation(), result.operands); 1302 } 1303 1304 OperandRange RegionIfOp::getSuccessorEntryOperands(Optional<unsigned> index) { 1305 assert(index && *index < 2 && "invalid region index"); 1306 return getOperands(); 1307 } 1308 1309 void RegionIfOp::getSuccessorRegions( 1310 Optional<unsigned> index, ArrayRef<Attribute> operands, 1311 SmallVectorImpl<RegionSuccessor> ®ions) { 1312 // We always branch to the join region. 1313 if (index) { 1314 if (index.value() < 2) 1315 regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); 1316 else 1317 regions.push_back(RegionSuccessor(getResults())); 1318 return; 1319 } 1320 1321 // The then and else regions are the entry regions of this op. 1322 regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs())); 1323 regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs())); 1324 } 1325 1326 void RegionIfOp::getRegionInvocationBounds( 1327 ArrayRef<Attribute> operands, 1328 SmallVectorImpl<InvocationBounds> &invocationBounds) { 1329 // Each region is invoked at most once. 1330 invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1}); 1331 } 1332 1333 //===----------------------------------------------------------------------===// 1334 // AnyCondOp 1335 //===----------------------------------------------------------------------===// 1336 1337 void AnyCondOp::getSuccessorRegions(Optional<unsigned> index, 1338 ArrayRef<Attribute> operands, 1339 SmallVectorImpl<RegionSuccessor> ®ions) { 1340 // The parent op branches into the only region, and the region branches back 1341 // to the parent op. 1342 if (!index) 1343 regions.emplace_back(&getRegion()); 1344 else 1345 regions.emplace_back(getResults()); 1346 } 1347 1348 void AnyCondOp::getRegionInvocationBounds( 1349 ArrayRef<Attribute> operands, 1350 SmallVectorImpl<InvocationBounds> &invocationBounds) { 1351 invocationBounds.emplace_back(1, 1); 1352 } 1353 1354 //===----------------------------------------------------------------------===// 1355 // SingleNoTerminatorCustomAsmOp 1356 //===----------------------------------------------------------------------===// 1357 1358 ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser, 1359 OperationState &state) { 1360 Region *body = state.addRegion(); 1361 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) 1362 return failure(); 1363 return success(); 1364 } 1365 1366 void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) { 1367 printer.printRegion( 1368 getRegion(), /*printEntryBlockArgs=*/false, 1369 // This op has a single block without terminators. But explicitly mark 1370 // as not printing block terminators for testing. 1371 /*printBlockTerminators=*/false); 1372 } 1373 1374 //===----------------------------------------------------------------------===// 1375 // TestVerifiersOp 1376 //===----------------------------------------------------------------------===// 1377 1378 LogicalResult TestVerifiersOp::verify() { 1379 if (!getRegion().hasOneBlock()) 1380 return emitOpError("`hasOneBlock` trait hasn't been verified"); 1381 1382 Operation *definingOp = getInput().getDefiningOp(); 1383 if (definingOp && failed(mlir::verify(definingOp))) 1384 return emitOpError("operand hasn't been verified"); 1385 1386 emitRemark("success run of verifier"); 1387 1388 return success(); 1389 } 1390 1391 LogicalResult TestVerifiersOp::verifyRegions() { 1392 if (!getRegion().hasOneBlock()) 1393 return emitOpError("`hasOneBlock` trait hasn't been verified"); 1394 1395 for (Block &block : getRegion()) 1396 for (Operation &op : block) 1397 if (failed(mlir::verify(&op))) 1398 return emitOpError("nested op hasn't been verified"); 1399 1400 emitRemark("success run of region verifier"); 1401 1402 return success(); 1403 } 1404 1405 //===----------------------------------------------------------------------===// 1406 // Test InferIntRangeInterface 1407 //===----------------------------------------------------------------------===// 1408 1409 void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 1410 SetIntRangeFn setResultRanges) { 1411 setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()}); 1412 } 1413 1414 ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser, 1415 OperationState &result) { 1416 if (parser.parseOptionalAttrDict(result.attributes)) 1417 return failure(); 1418 1419 // Parse the input argument 1420 OpAsmParser::Argument argInfo; 1421 argInfo.type = parser.getBuilder().getIndexType(); 1422 if (failed(parser.parseArgument(argInfo))) 1423 return failure(); 1424 1425 // Parse the body region, and reuse the operand info as the argument info. 1426 Region *body = result.addRegion(); 1427 return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false); 1428 } 1429 1430 void TestWithBoundsRegionOp::print(OpAsmPrinter &p) { 1431 p.printOptionalAttrDict((*this)->getAttrs()); 1432 p << ' '; 1433 p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{}, 1434 /*omitType=*/true); 1435 p << ' '; 1436 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 1437 } 1438 1439 void TestWithBoundsRegionOp::inferResultRanges( 1440 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) { 1441 Value arg = getRegion().getArgument(0); 1442 setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()}); 1443 } 1444 1445 void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 1446 SetIntRangeFn setResultRanges) { 1447 const ConstantIntRanges &range = argRanges[0]; 1448 APInt one(range.umin().getBitWidth(), 1); 1449 setResultRanges(getResult(), 1450 {range.umin().uadd_sat(one), range.umax().uadd_sat(one), 1451 range.smin().sadd_sat(one), range.smax().sadd_sat(one)}); 1452 } 1453 1454 void TestReflectBoundsOp::inferResultRanges( 1455 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) { 1456 const ConstantIntRanges &range = argRanges[0]; 1457 MLIRContext *ctx = getContext(); 1458 Builder b(ctx); 1459 setUminAttr(b.getIndexAttr(range.umin().getZExtValue())); 1460 setUmaxAttr(b.getIndexAttr(range.umax().getZExtValue())); 1461 setSminAttr(b.getIndexAttr(range.smin().getSExtValue())); 1462 setSmaxAttr(b.getIndexAttr(range.smax().getSExtValue())); 1463 setResultRanges(getResult(), range); 1464 } 1465 1466 #include "TestOpEnums.cpp.inc" 1467 #include "TestOpInterfaces.cpp.inc" 1468 #include "TestTypeInterfaces.cpp.inc" 1469 1470 #define GET_OP_CLASSES 1471 #include "TestOps.cpp.inc" 1472