1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestInterfaces.h" 12 #include "TestTypes.h" 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/DLTI/DLTI.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/IR/AsmState.h" 18 #include "mlir/IR/BuiltinAttributes.h" 19 #include "mlir/IR/BuiltinOps.h" 20 #include "mlir/IR/Diagnostics.h" 21 #include "mlir/IR/DialectImplementation.h" 22 #include "mlir/IR/ExtensibleDialect.h" 23 #include "mlir/IR/MLIRContext.h" 24 #include "mlir/IR/OperationSupport.h" 25 #include "mlir/IR/PatternMatch.h" 26 #include "mlir/IR/TypeUtilities.h" 27 #include "mlir/IR/Verifier.h" 28 #include "mlir/Interfaces/InferIntRangeInterface.h" 29 #include "mlir/Reducer/ReductionPatternInterface.h" 30 #include "mlir/Transforms/FoldUtils.h" 31 #include "mlir/Transforms/InliningUtils.h" 32 #include "llvm/ADT/SmallString.h" 33 #include "llvm/ADT/StringExtras.h" 34 #include "llvm/ADT/StringSwitch.h" 35 36 // Include this before the using namespace lines below to 37 // test that we don't have namespace dependencies. 38 #include "TestOpsDialect.cpp.inc" 39 40 using namespace mlir; 41 using namespace test; 42 43 void test::registerTestDialect(DialectRegistry ®istry) { 44 registry.insert<TestDialect>(); 45 } 46 47 //===----------------------------------------------------------------------===// 48 // External Elements Data 49 //===----------------------------------------------------------------------===// 50 51 ArrayRef<uint64_t> TestExternalElementsData::getData() const { 52 ArrayRef<char> data = AsmResourceBlob::getData(); 53 return ArrayRef<uint64_t>((const uint64_t *)data.data(), 54 data.size() / sizeof(uint64_t)); 55 } 56 57 TestExternalElementsData 58 TestExternalElementsData::allocate(size_t numElements) { 59 return TestExternalElementsData( 60 llvm::ArrayRef<uint64_t>(new uint64_t[numElements], numElements), 61 [](const uint64_t *data, size_t) { delete[] data; }, 62 /*dataIsMutable=*/true); 63 } 64 65 const TestExternalElementsData * 66 TestExternalElementsDataManager::getData(StringRef name) const { 67 auto it = dataMap.find(name); 68 return it != dataMap.end() ? &*it->second : nullptr; 69 } 70 71 std::pair<TestExternalElementsDataManager::DataMap::iterator, bool> 72 TestExternalElementsDataManager::insert(StringRef name) { 73 auto it = dataMap.try_emplace(name, nullptr); 74 if (it.second) 75 return it; 76 77 llvm::SmallString<32> nameStorage(name); 78 nameStorage.push_back('_'); 79 size_t nameCounter = 1; 80 do { 81 nameStorage += std::to_string(nameCounter++); 82 auto it = dataMap.try_emplace(nameStorage, nullptr); 83 if (it.second) 84 return it; 85 nameStorage.resize(name.size() + 1); 86 } while (true); 87 } 88 89 void TestExternalElementsDataManager::setData(StringRef name, 90 TestExternalElementsData &&data) { 91 auto it = dataMap.find(name); 92 assert(it != dataMap.end() && "data not registered"); 93 it->second = std::make_unique<TestExternalElementsData>(std::move(data)); 94 } 95 96 //===----------------------------------------------------------------------===// 97 // TestDialect Interfaces 98 //===----------------------------------------------------------------------===// 99 100 namespace { 101 102 /// Testing the correctness of some traits. 103 static_assert( 104 llvm::is_detected<OpTrait::has_implicit_terminator_t, 105 SingleBlockImplicitTerminatorOp>::value, 106 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); 107 static_assert(OpTrait::hasSingleBlockImplicitTerminator< 108 SingleBlockImplicitTerminatorOp>::value, 109 "hasSingleBlockImplicitTerminator does not match " 110 "SingleBlockImplicitTerminatorOp"); 111 112 // Test support for interacting with the AsmPrinter. 113 struct TestOpAsmInterface : public OpAsmDialectInterface { 114 using OpAsmDialectInterface::OpAsmDialectInterface; 115 116 //===------------------------------------------------------------------===// 117 // Aliases 118 //===------------------------------------------------------------------===// 119 120 AliasResult getAlias(Attribute attr, raw_ostream &os) const final { 121 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 122 if (!strAttr) 123 return AliasResult::NoAlias; 124 125 // Check the contents of the string attribute to see what the test alias 126 // should be named. 127 Optional<StringRef> aliasName = 128 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 129 .Case("alias_test:dot_in_name", StringRef("test.alias")) 130 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 131 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 132 .Case("alias_test:sanitize_conflict_a", 133 StringRef("test_alias_conflict0")) 134 .Case("alias_test:sanitize_conflict_b", 135 StringRef("test_alias_conflict0_")) 136 .Case("alias_test:tensor_encoding", StringRef("test_encoding")) 137 .Default(llvm::None); 138 if (!aliasName) 139 return AliasResult::NoAlias; 140 141 os << *aliasName; 142 return AliasResult::FinalAlias; 143 } 144 145 AliasResult getAlias(Type type, raw_ostream &os) const final { 146 if (auto tupleType = type.dyn_cast<TupleType>()) { 147 if (tupleType.size() > 0 && 148 llvm::all_of(tupleType.getTypes(), [](Type elemType) { 149 return elemType.isa<SimpleAType>(); 150 })) { 151 os << "test_tuple"; 152 return AliasResult::FinalAlias; 153 } 154 } 155 if (auto intType = type.dyn_cast<TestIntegerType>()) { 156 if (intType.getSignedness() == 157 TestIntegerType::SignednessSemantics::Unsigned && 158 intType.getWidth() == 8) { 159 os << "test_ui8"; 160 return AliasResult::FinalAlias; 161 } 162 } 163 return AliasResult::NoAlias; 164 } 165 166 //===------------------------------------------------------------------===// 167 // Resources 168 //===------------------------------------------------------------------===// 169 170 std::string 171 getResourceKey(const AsmDialectResourceHandle &handle) const override { 172 return cast<TestExternalElementsDataHandle>(handle).getKey().str(); 173 } 174 175 FailureOr<AsmDialectResourceHandle> 176 declareResource(StringRef key) const final { 177 TestDialect *dialect = cast<TestDialect>(getDialect()); 178 TestExternalElementsDataManager &mgr = dialect->getExternalDataManager(); 179 180 // Resolve the reference by inserting a new entry into the manager. 181 auto it = mgr.insert(key).first; 182 return TestExternalElementsDataHandle(&*it, dialect); 183 } 184 185 LogicalResult parseResource(AsmParsedResourceEntry &entry) const final { 186 TestDialect *dialect = cast<TestDialect>(getDialect()); 187 TestExternalElementsDataManager &mgr = dialect->getExternalDataManager(); 188 189 // The resource entries are external constant data. 190 auto blobAllocFn = [](unsigned size, unsigned align) { 191 assert(align == alignof(uint64_t) && "unexpected data alignment"); 192 return TestExternalElementsData::allocate(size / sizeof(uint64_t)); 193 }; 194 FailureOr<AsmResourceBlob> blob = entry.parseAsBlob(blobAllocFn); 195 if (failed(blob)) 196 return failure(); 197 198 mgr.setData(entry.getKey(), std::move(*blob)); 199 return success(); 200 } 201 202 void 203 buildResources(Operation *op, 204 const SetVector<AsmDialectResourceHandle> &referencedResources, 205 AsmResourceBuilder &provider) const final { 206 for (const AsmDialectResourceHandle &handle : referencedResources) { 207 const auto &testHandle = cast<TestExternalElementsDataHandle>(handle); 208 provider.buildBlob(testHandle.getKey(), testHandle.getData()->getData()); 209 } 210 } 211 }; 212 213 struct TestDialectFoldInterface : public DialectFoldInterface { 214 using DialectFoldInterface::DialectFoldInterface; 215 216 /// Registered hook to check if the given region, which is attached to an 217 /// operation that is *not* isolated from above, should be used when 218 /// materializing constants. 219 bool shouldMaterializeInto(Region *region) const final { 220 // If this is a one region operation, then insert into it. 221 return isa<OneRegionOp>(region->getParentOp()); 222 } 223 }; 224 225 /// This class defines the interface for handling inlining with standard 226 /// operations. 227 struct TestInlinerInterface : public DialectInlinerInterface { 228 using DialectInlinerInterface::DialectInlinerInterface; 229 230 //===--------------------------------------------------------------------===// 231 // Analysis Hooks 232 //===--------------------------------------------------------------------===// 233 234 bool isLegalToInline(Operation *call, Operation *callable, 235 bool wouldBeCloned) const final { 236 // Don't allow inlining calls that are marked `noinline`. 237 return !call->hasAttr("noinline"); 238 } 239 bool isLegalToInline(Region *, Region *, bool, 240 BlockAndValueMapping &) const final { 241 // Inlining into test dialect regions is legal. 242 return true; 243 } 244 bool isLegalToInline(Operation *, Region *, bool, 245 BlockAndValueMapping &) const final { 246 return true; 247 } 248 249 bool shouldAnalyzeRecursively(Operation *op) const final { 250 // Analyze recursively if this is not a functional region operation, it 251 // froms a separate functional scope. 252 return !isa<FunctionalRegionOp>(op); 253 } 254 255 //===--------------------------------------------------------------------===// 256 // Transformation Hooks 257 //===--------------------------------------------------------------------===// 258 259 /// Handle the given inlined terminator by replacing it with a new operation 260 /// as necessary. 261 void handleTerminator(Operation *op, 262 ArrayRef<Value> valuesToRepl) const final { 263 // Only handle "test.return" here. 264 auto returnOp = dyn_cast<TestReturnOp>(op); 265 if (!returnOp) 266 return; 267 268 // Replace the values directly with the return operands. 269 assert(returnOp.getNumOperands() == valuesToRepl.size()); 270 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 271 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 272 } 273 274 /// Attempt to materialize a conversion for a type mismatch between a call 275 /// from this dialect, and a callable region. This method should generate an 276 /// operation that takes 'input' as the only operand, and produces a single 277 /// result of 'resultType'. If a conversion can not be generated, nullptr 278 /// should be returned. 279 Operation *materializeCallConversion(OpBuilder &builder, Value input, 280 Type resultType, 281 Location conversionLoc) const final { 282 // Only allow conversion for i16/i32 types. 283 if (!(resultType.isSignlessInteger(16) || 284 resultType.isSignlessInteger(32)) || 285 !(input.getType().isSignlessInteger(16) || 286 input.getType().isSignlessInteger(32))) 287 return nullptr; 288 return builder.create<TestCastOp>(conversionLoc, resultType, input); 289 } 290 291 void processInlinedCallBlocks( 292 Operation *call, 293 iterator_range<Region::iterator> inlinedBlocks) const final { 294 if (!isa<ConversionCallOp>(call)) 295 return; 296 297 // Set attributed on all ops in the inlined blocks. 298 for (Block &block : inlinedBlocks) { 299 block.walk([&](Operation *op) { 300 op->setAttr("inlined_conversion", UnitAttr::get(call->getContext())); 301 }); 302 } 303 } 304 }; 305 306 struct TestReductionPatternInterface : public DialectReductionPatternInterface { 307 public: 308 TestReductionPatternInterface(Dialect *dialect) 309 : DialectReductionPatternInterface(dialect) {} 310 311 void populateReductionPatterns(RewritePatternSet &patterns) const final { 312 populateTestReductionPatterns(patterns); 313 } 314 }; 315 316 } // namespace 317 318 //===----------------------------------------------------------------------===// 319 // Dynamic operations 320 //===----------------------------------------------------------------------===// 321 322 std::unique_ptr<DynamicOpDefinition> getDynamicGenericOp(TestDialect *dialect) { 323 return DynamicOpDefinition::get( 324 "dynamic_generic", dialect, [](Operation *op) { return success(); }, 325 [](Operation *op) { return success(); }); 326 } 327 328 std::unique_ptr<DynamicOpDefinition> 329 getDynamicOneOperandTwoResultsOp(TestDialect *dialect) { 330 return DynamicOpDefinition::get( 331 "dynamic_one_operand_two_results", dialect, 332 [](Operation *op) { 333 if (op->getNumOperands() != 1) { 334 op->emitOpError() 335 << "expected 1 operand, but had " << op->getNumOperands(); 336 return failure(); 337 } 338 if (op->getNumResults() != 2) { 339 op->emitOpError() 340 << "expected 2 results, but had " << op->getNumResults(); 341 return failure(); 342 } 343 return success(); 344 }, 345 [](Operation *op) { return success(); }); 346 } 347 348 std::unique_ptr<DynamicOpDefinition> 349 getDynamicCustomParserPrinterOp(TestDialect *dialect) { 350 auto verifier = [](Operation *op) { 351 if (op->getNumOperands() == 0 && op->getNumResults() == 0) 352 return success(); 353 op->emitError() << "operation should have no operands and no results"; 354 return failure(); 355 }; 356 auto regionVerifier = [](Operation *op) { return success(); }; 357 358 auto parser = [](OpAsmParser &parser, OperationState &state) { 359 return parser.parseKeyword("custom_keyword"); 360 }; 361 362 auto printer = [](Operation *op, OpAsmPrinter &printer, llvm::StringRef) { 363 printer << op->getName() << " custom_keyword"; 364 }; 365 366 return DynamicOpDefinition::get("dynamic_custom_parser_printer", dialect, 367 verifier, regionVerifier, parser, printer); 368 } 369 370 //===----------------------------------------------------------------------===// 371 // TestDialect 372 //===----------------------------------------------------------------------===// 373 374 static void testSideEffectOpGetEffect( 375 Operation *op, 376 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects); 377 378 // This is the implementation of a dialect fallback for `TestEffectOpInterface`. 379 struct TestOpEffectInterfaceFallback 380 : public TestEffectOpInterface::FallbackModel< 381 TestOpEffectInterfaceFallback> { 382 static bool classof(Operation *op) { 383 bool isSupportedOp = 384 op->getName().getStringRef() == "test.unregistered_side_effect_op"; 385 assert(isSupportedOp && "Unexpected dispatch"); 386 return isSupportedOp; 387 } 388 389 void 390 getEffects(Operation *op, 391 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 392 &effects) const { 393 testSideEffectOpGetEffect(op, effects); 394 } 395 }; 396 397 void TestDialect::initialize() { 398 registerAttributes(); 399 registerTypes(); 400 addOperations< 401 #define GET_OP_LIST 402 #include "TestOps.cpp.inc" 403 >(); 404 registerDynamicOp(getDynamicGenericOp(this)); 405 registerDynamicOp(getDynamicOneOperandTwoResultsOp(this)); 406 registerDynamicOp(getDynamicCustomParserPrinterOp(this)); 407 408 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 409 TestInlinerInterface, TestReductionPatternInterface>(); 410 allowUnknownOperations(); 411 412 // Instantiate our fallback op interface that we'll use on specific 413 // unregistered op. 414 fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback; 415 } 416 TestDialect::~TestDialect() { 417 delete static_cast<TestOpEffectInterfaceFallback *>( 418 fallbackEffectOpInterfaces); 419 } 420 421 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 422 Type type, Location loc) { 423 return builder.create<TestOpConstant>(loc, type, value); 424 } 425 426 ::mlir::LogicalResult FormatInferType2Op::inferReturnTypes( 427 ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, 428 ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, 429 ::mlir::RegionRange regions, 430 ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { 431 inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)}); 432 return ::mlir::success(); 433 } 434 435 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, 436 OperationName opName) { 437 if (opName.getIdentifier() == "test.unregistered_side_effect_op" && 438 typeID == TypeID::get<TestEffectOpInterface>()) 439 return fallbackEffectOpInterfaces; 440 return nullptr; 441 } 442 443 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 444 NamedAttribute namedAttr) { 445 if (namedAttr.getName() == "test.invalid_attr") 446 return op->emitError() << "invalid to use 'test.invalid_attr'"; 447 return success(); 448 } 449 450 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 451 unsigned regionIndex, 452 unsigned argIndex, 453 NamedAttribute namedAttr) { 454 if (namedAttr.getName() == "test.invalid_attr") 455 return op->emitError() << "invalid to use 'test.invalid_attr'"; 456 return success(); 457 } 458 459 LogicalResult 460 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 461 unsigned resultIndex, 462 NamedAttribute namedAttr) { 463 if (namedAttr.getName() == "test.invalid_attr") 464 return op->emitError() << "invalid to use 'test.invalid_attr'"; 465 return success(); 466 } 467 468 Optional<Dialect::ParseOpHook> 469 TestDialect::getParseOperationHook(StringRef opName) const { 470 if (opName == "test.dialect_custom_printer") { 471 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 472 return parser.parseKeyword("custom_format"); 473 }}; 474 } 475 if (opName == "test.dialect_custom_format_fallback") { 476 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 477 return parser.parseKeyword("custom_format_fallback"); 478 }}; 479 } 480 if (opName == "test.dialect_custom_printer.with.dot") { 481 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 482 return ParseResult::success(); 483 }}; 484 } 485 return None; 486 } 487 488 llvm::unique_function<void(Operation *, OpAsmPrinter &)> 489 TestDialect::getOperationPrinter(Operation *op) const { 490 StringRef opName = op->getName().getStringRef(); 491 if (opName == "test.dialect_custom_printer") { 492 return [](Operation *op, OpAsmPrinter &printer) { 493 printer.getStream() << " custom_format"; 494 }; 495 } 496 if (opName == "test.dialect_custom_format_fallback") { 497 return [](Operation *op, OpAsmPrinter &printer) { 498 printer.getStream() << " custom_format_fallback"; 499 }; 500 } 501 return {}; 502 } 503 504 //===----------------------------------------------------------------------===// 505 // TestBranchOp 506 //===----------------------------------------------------------------------===// 507 508 SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) { 509 assert(index == 0 && "invalid successor index"); 510 return SuccessorOperands(getTargetOperandsMutable()); 511 } 512 513 //===----------------------------------------------------------------------===// 514 // TestProducingBranchOp 515 //===----------------------------------------------------------------------===// 516 517 SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) { 518 assert(index <= 1 && "invalid successor index"); 519 if (index == 1) 520 return SuccessorOperands(getFirstOperandsMutable()); 521 return SuccessorOperands(getSecondOperandsMutable()); 522 } 523 524 //===----------------------------------------------------------------------===// 525 // TestProducingBranchOp 526 //===----------------------------------------------------------------------===// 527 528 SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) { 529 assert(index <= 1 && "invalid successor index"); 530 if (index == 0) 531 return SuccessorOperands(0, getSuccessOperandsMutable()); 532 return SuccessorOperands(1, getErrorOperandsMutable()); 533 } 534 535 //===----------------------------------------------------------------------===// 536 // TestDialectCanonicalizerOp 537 //===----------------------------------------------------------------------===// 538 539 static LogicalResult 540 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, 541 PatternRewriter &rewriter) { 542 rewriter.replaceOpWithNewOp<arith::ConstantOp>( 543 op, rewriter.getI32IntegerAttr(42)); 544 return success(); 545 } 546 547 void TestDialect::getCanonicalizationPatterns( 548 RewritePatternSet &results) const { 549 results.add(&dialectCanonicalizationPattern); 550 } 551 552 //===----------------------------------------------------------------------===// 553 // TestCallOp 554 //===----------------------------------------------------------------------===// 555 556 LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 557 // Check that the callee attribute was specified. 558 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee"); 559 if (!fnAttr) 560 return emitOpError("requires a 'callee' symbol reference attribute"); 561 if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr)) 562 return emitOpError() << "'" << fnAttr.getValue() 563 << "' does not reference a valid function"; 564 return success(); 565 } 566 567 //===----------------------------------------------------------------------===// 568 // TestFoldToCallOp 569 //===----------------------------------------------------------------------===// 570 571 namespace { 572 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 573 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 574 575 LogicalResult matchAndRewrite(FoldToCallOp op, 576 PatternRewriter &rewriter) const override { 577 rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), 578 op.getCalleeAttr(), ValueRange()); 579 return success(); 580 } 581 }; 582 } // namespace 583 584 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 585 MLIRContext *context) { 586 results.add<FoldToCallOpPattern>(context); 587 } 588 589 //===----------------------------------------------------------------------===// 590 // Test Format* operations 591 //===----------------------------------------------------------------------===// 592 593 //===----------------------------------------------------------------------===// 594 // Parsing 595 596 static ParseResult parseCustomOptionalOperand( 597 OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) { 598 if (succeeded(parser.parseOptionalLParen())) { 599 optOperand.emplace(); 600 if (parser.parseOperand(*optOperand) || parser.parseRParen()) 601 return failure(); 602 } 603 return success(); 604 } 605 606 static ParseResult parseCustomDirectiveOperands( 607 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, 608 Optional<OpAsmParser::UnresolvedOperand> &optOperand, 609 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) { 610 if (parser.parseOperand(operand)) 611 return failure(); 612 if (succeeded(parser.parseOptionalComma())) { 613 optOperand.emplace(); 614 if (parser.parseOperand(*optOperand)) 615 return failure(); 616 } 617 if (parser.parseArrow() || parser.parseLParen() || 618 parser.parseOperandList(varOperands) || parser.parseRParen()) 619 return failure(); 620 return success(); 621 } 622 static ParseResult 623 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 624 Type &optOperandType, 625 SmallVectorImpl<Type> &varOperandTypes) { 626 if (parser.parseColon()) 627 return failure(); 628 629 if (parser.parseType(operandType)) 630 return failure(); 631 if (succeeded(parser.parseOptionalComma())) { 632 if (parser.parseType(optOperandType)) 633 return failure(); 634 } 635 if (parser.parseArrow() || parser.parseLParen() || 636 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 637 return failure(); 638 return success(); 639 } 640 static ParseResult 641 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 642 Type optOperandType, 643 const SmallVectorImpl<Type> &varOperandTypes) { 644 if (parser.parseKeyword("type_refs_capture")) 645 return failure(); 646 647 Type operandType2, optOperandType2; 648 SmallVector<Type, 1> varOperandTypes2; 649 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 650 varOperandTypes2)) 651 return failure(); 652 653 if (operandType != operandType2 || optOperandType != optOperandType2 || 654 varOperandTypes != varOperandTypes2) 655 return failure(); 656 657 return success(); 658 } 659 static ParseResult parseCustomDirectiveOperandsAndTypes( 660 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, 661 Optional<OpAsmParser::UnresolvedOperand> &optOperand, 662 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands, 663 Type &operandType, Type &optOperandType, 664 SmallVectorImpl<Type> &varOperandTypes) { 665 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 666 parseCustomDirectiveResults(parser, operandType, optOperandType, 667 varOperandTypes)) 668 return failure(); 669 return success(); 670 } 671 static ParseResult parseCustomDirectiveRegions( 672 OpAsmParser &parser, Region ®ion, 673 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 674 if (parser.parseRegion(region)) 675 return failure(); 676 if (failed(parser.parseOptionalComma())) 677 return success(); 678 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 679 if (parser.parseRegion(*varRegion)) 680 return failure(); 681 varRegions.emplace_back(std::move(varRegion)); 682 return success(); 683 } 684 static ParseResult 685 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 686 SmallVectorImpl<Block *> &varSuccessors) { 687 if (parser.parseSuccessor(successor)) 688 return failure(); 689 if (failed(parser.parseOptionalComma())) 690 return success(); 691 Block *varSuccessor; 692 if (parser.parseSuccessor(varSuccessor)) 693 return failure(); 694 varSuccessors.append(2, varSuccessor); 695 return success(); 696 } 697 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 698 IntegerAttr &attr, 699 IntegerAttr &optAttr) { 700 if (parser.parseAttribute(attr)) 701 return failure(); 702 if (succeeded(parser.parseOptionalComma())) { 703 if (parser.parseAttribute(optAttr)) 704 return failure(); 705 } 706 return success(); 707 } 708 709 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 710 NamedAttrList &attrs) { 711 return parser.parseOptionalAttrDict(attrs); 712 } 713 static ParseResult parseCustomDirectiveOptionalOperandRef( 714 OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) { 715 int64_t operandCount = 0; 716 if (parser.parseInteger(operandCount)) 717 return failure(); 718 bool expectedOptionalOperand = operandCount == 0; 719 return success(expectedOptionalOperand != optOperand.has_value()); 720 } 721 722 //===----------------------------------------------------------------------===// 723 // Printing 724 725 static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *, 726 Value optOperand) { 727 if (optOperand) 728 printer << "(" << optOperand << ") "; 729 } 730 731 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 732 Value operand, Value optOperand, 733 OperandRange varOperands) { 734 printer << operand; 735 if (optOperand) 736 printer << ", " << optOperand; 737 printer << " -> (" << varOperands << ")"; 738 } 739 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 740 Type operandType, Type optOperandType, 741 TypeRange varOperandTypes) { 742 printer << " : " << operandType; 743 if (optOperandType) 744 printer << ", " << optOperandType; 745 printer << " -> (" << varOperandTypes << ")"; 746 } 747 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 748 Operation *op, Type operandType, 749 Type optOperandType, 750 TypeRange varOperandTypes) { 751 printer << " type_refs_capture "; 752 printCustomDirectiveResults(printer, op, operandType, optOperandType, 753 varOperandTypes); 754 } 755 static void printCustomDirectiveOperandsAndTypes( 756 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 757 OperandRange varOperands, Type operandType, Type optOperandType, 758 TypeRange varOperandTypes) { 759 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 760 printCustomDirectiveResults(printer, op, operandType, optOperandType, 761 varOperandTypes); 762 } 763 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 764 Region ®ion, 765 MutableArrayRef<Region> varRegions) { 766 printer.printRegion(region); 767 if (!varRegions.empty()) { 768 printer << ", "; 769 for (Region ®ion : varRegions) 770 printer.printRegion(region); 771 } 772 } 773 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 774 Block *successor, 775 SuccessorRange varSuccessors) { 776 printer << successor; 777 if (!varSuccessors.empty()) 778 printer << ", " << varSuccessors.front(); 779 } 780 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 781 Attribute attribute, 782 Attribute optAttribute) { 783 printer << attribute; 784 if (optAttribute) 785 printer << ", " << optAttribute; 786 } 787 788 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 789 DictionaryAttr attrs) { 790 printer.printOptionalAttrDict(attrs.getValue()); 791 } 792 793 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 794 Operation *op, 795 Value optOperand) { 796 printer << (optOperand ? "1" : "0"); 797 } 798 799 //===----------------------------------------------------------------------===// 800 // Test IsolatedRegionOp - parse passthrough region arguments. 801 //===----------------------------------------------------------------------===// 802 803 ParseResult IsolatedRegionOp::parse(OpAsmParser &parser, 804 OperationState &result) { 805 // Parse the input operand. 806 OpAsmParser::Argument argInfo; 807 argInfo.type = parser.getBuilder().getIndexType(); 808 if (parser.parseOperand(argInfo.ssaName) || 809 parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands)) 810 return failure(); 811 812 // Parse the body region, and reuse the operand info as the argument info. 813 Region *body = result.addRegion(); 814 return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true); 815 } 816 817 void IsolatedRegionOp::print(OpAsmPrinter &p) { 818 p << "test.isolated_region "; 819 p.printOperand(getOperand()); 820 p.shadowRegionArgs(getRegion(), getOperand()); 821 p << ' '; 822 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 823 } 824 825 //===----------------------------------------------------------------------===// 826 // Test SSACFGRegionOp 827 //===----------------------------------------------------------------------===// 828 829 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 830 return RegionKind::SSACFG; 831 } 832 833 //===----------------------------------------------------------------------===// 834 // Test GraphRegionOp 835 //===----------------------------------------------------------------------===// 836 837 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 838 return RegionKind::Graph; 839 } 840 841 //===----------------------------------------------------------------------===// 842 // Test AffineScopeOp 843 //===----------------------------------------------------------------------===// 844 845 ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) { 846 // Parse the body region, and reuse the operand info as the argument info. 847 Region *body = result.addRegion(); 848 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 849 } 850 851 void AffineScopeOp::print(OpAsmPrinter &p) { 852 p << "test.affine_scope "; 853 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 854 } 855 856 //===----------------------------------------------------------------------===// 857 // Test parser. 858 //===----------------------------------------------------------------------===// 859 860 ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser, 861 OperationState &result) { 862 if (parser.parseOptionalColon()) 863 return success(); 864 uint64_t numResults; 865 if (parser.parseInteger(numResults)) 866 return failure(); 867 868 IndexType type = parser.getBuilder().getIndexType(); 869 for (unsigned i = 0; i < numResults; ++i) 870 result.addTypes(type); 871 return success(); 872 } 873 874 void ParseIntegerLiteralOp::print(OpAsmPrinter &p) { 875 if (unsigned numResults = getNumResults()) 876 p << " : " << numResults; 877 } 878 879 ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser, 880 OperationState &result) { 881 StringRef keyword; 882 if (parser.parseKeyword(&keyword)) 883 return failure(); 884 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 885 return success(); 886 } 887 888 void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); } 889 890 //===----------------------------------------------------------------------===// 891 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 892 893 ParseResult WrappingRegionOp::parse(OpAsmParser &parser, 894 OperationState &result) { 895 if (parser.parseKeyword("wraps")) 896 return failure(); 897 898 // Parse the wrapped op in a region 899 Region &body = *result.addRegion(); 900 body.push_back(new Block); 901 Block &block = body.back(); 902 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin()); 903 if (!wrappedOp) 904 return failure(); 905 906 // Create a return terminator in the inner region, pass as operand to the 907 // terminator the returned values from the wrapped operation. 908 SmallVector<Value, 8> returnOperands(wrappedOp->getResults()); 909 OpBuilder builder(parser.getContext()); 910 builder.setInsertionPointToEnd(&block); 911 builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands); 912 913 // Get the results type for the wrapping op from the terminator operands. 914 Operation &returnOp = body.back().back(); 915 result.types.append(returnOp.operand_type_begin(), 916 returnOp.operand_type_end()); 917 918 // Use the location of the wrapped op for the "test.wrapping_region" op. 919 result.location = wrappedOp->getLoc(); 920 921 return success(); 922 } 923 924 void WrappingRegionOp::print(OpAsmPrinter &p) { 925 p << " wraps "; 926 p.printGenericOp(&getRegion().front().front()); 927 } 928 929 //===----------------------------------------------------------------------===// 930 // Test PrettyPrintedRegionOp - exercising the following parser APIs 931 // parseGenericOperationAfterOpName 932 // parseCustomOperationName 933 //===----------------------------------------------------------------------===// 934 935 ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser, 936 OperationState &result) { 937 938 SMLoc loc = parser.getCurrentLocation(); 939 Location currLocation = parser.getEncodedSourceLoc(loc); 940 941 // Parse the operands. 942 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands; 943 if (parser.parseOperandList(operands)) 944 return failure(); 945 946 // Check if we are parsing the pretty-printed version 947 // test.pretty_printed_region start <inner-op> end : <functional-type> 948 // Else fallback to parsing the "non pretty-printed" version. 949 if (!succeeded(parser.parseOptionalKeyword("start"))) 950 return parser.parseGenericOperationAfterOpName( 951 result, llvm::makeArrayRef(operands)); 952 953 FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName(); 954 if (failed(parseOpNameInfo)) 955 return failure(); 956 957 StringAttr innerOpName = parseOpNameInfo->getIdentifier(); 958 959 FunctionType opFntype; 960 Optional<Location> explicitLoc; 961 if (parser.parseKeyword("end") || parser.parseColon() || 962 parser.parseType(opFntype) || 963 parser.parseOptionalLocationSpecifier(explicitLoc)) 964 return failure(); 965 966 // If location of the op is explicitly provided, then use it; Else use 967 // the parser's current location. 968 Location opLoc = explicitLoc.value_or(currLocation); 969 970 // Derive the SSA-values for op's operands. 971 if (parser.resolveOperands(operands, opFntype.getInputs(), loc, 972 result.operands)) 973 return failure(); 974 975 // Add a region for op. 976 Region ®ion = *result.addRegion(); 977 978 // Create a basic-block inside op's region. 979 Block &block = region.emplaceBlock(); 980 981 // Create and insert an "inner-op" operation in the block. 982 // Just for testing purposes, we can assume that inner op is a binary op with 983 // result and operand types all same as the test-op's first operand. 984 Type innerOpType = opFntype.getInput(0); 985 Value lhs = block.addArgument(innerOpType, opLoc); 986 Value rhs = block.addArgument(innerOpType, opLoc); 987 988 OpBuilder builder(parser.getBuilder().getContext()); 989 builder.setInsertionPointToStart(&block); 990 991 Operation *innerOp = 992 builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType); 993 994 // Insert a return statement in the block returning the inner-op's result. 995 builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults()); 996 997 // Populate the op operation-state with result-type and location. 998 result.addTypes(opFntype.getResults()); 999 result.location = innerOp->getLoc(); 1000 1001 return success(); 1002 } 1003 1004 void PrettyPrintedRegionOp::print(OpAsmPrinter &p) { 1005 p << ' '; 1006 p.printOperands(getOperands()); 1007 1008 Operation &innerOp = getRegion().front().front(); 1009 // Assuming that region has a single non-terminator inner-op, if the inner-op 1010 // meets some criteria (which in this case is a simple one based on the name 1011 // of inner-op), then we can print the entire region in a succinct way. 1012 // Here we assume that the prototype of "special.op" can be trivially derived 1013 // while parsing it back. 1014 if (innerOp.getName().getStringRef().equals("special.op")) { 1015 p << " start special.op end"; 1016 } else { 1017 p << " ("; 1018 p.printRegion(getRegion()); 1019 p << ")"; 1020 } 1021 1022 p << " : "; 1023 p.printFunctionalType(*this); 1024 } 1025 1026 //===----------------------------------------------------------------------===// 1027 // Test PolyForOp - parse list of region arguments. 1028 //===----------------------------------------------------------------------===// 1029 1030 ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) { 1031 SmallVector<OpAsmParser::Argument, 4> ivsInfo; 1032 // Parse list of region arguments without a delimiter. 1033 if (parser.parseArgumentList(ivsInfo, OpAsmParser::Delimiter::None)) 1034 return failure(); 1035 1036 // Parse the body region. 1037 Region *body = result.addRegion(); 1038 for (auto &iv : ivsInfo) 1039 iv.type = parser.getBuilder().getIndexType(); 1040 return parser.parseRegion(*body, ivsInfo); 1041 } 1042 1043 void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); } 1044 1045 void PolyForOp::getAsmBlockArgumentNames(Region ®ion, 1046 OpAsmSetValueNameFn setNameFn) { 1047 auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names"); 1048 if (!arrayAttr) 1049 return; 1050 auto args = getRegion().front().getArguments(); 1051 auto e = std::min(arrayAttr.size(), args.size()); 1052 for (unsigned i = 0; i < e; ++i) { 1053 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 1054 setNameFn(args[i], strAttr.getValue()); 1055 } 1056 } 1057 1058 //===----------------------------------------------------------------------===// 1059 // Test removing op with inner ops. 1060 //===----------------------------------------------------------------------===// 1061 1062 namespace { 1063 struct TestRemoveOpWithInnerOps 1064 : public OpRewritePattern<TestOpWithRegionPattern> { 1065 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 1066 1067 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } 1068 1069 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 1070 PatternRewriter &rewriter) const override { 1071 rewriter.eraseOp(op); 1072 return success(); 1073 } 1074 }; 1075 } // namespace 1076 1077 void TestOpWithRegionPattern::getCanonicalizationPatterns( 1078 RewritePatternSet &results, MLIRContext *context) { 1079 results.add<TestRemoveOpWithInnerOps>(context); 1080 } 1081 1082 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 1083 return getOperand(); 1084 } 1085 1086 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 1087 return getValue(); 1088 } 1089 1090 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 1091 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 1092 for (Value input : this->getOperands()) { 1093 results.push_back(input); 1094 } 1095 return success(); 1096 } 1097 1098 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 1099 assert(operands.size() == 1); 1100 if (operands.front()) { 1101 (*this)->setAttr("attr", operands.front()); 1102 return getResult(); 1103 } 1104 return {}; 1105 } 1106 1107 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 1108 return getOperand(); 1109 } 1110 1111 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 1112 MLIRContext *, Optional<Location> location, ValueRange operands, 1113 DictionaryAttr attributes, RegionRange regions, 1114 SmallVectorImpl<Type> &inferredReturnTypes) { 1115 if (operands[0].getType() != operands[1].getType()) { 1116 return emitOptionalError(location, "operand type mismatch ", 1117 operands[0].getType(), " vs ", 1118 operands[1].getType()); 1119 } 1120 inferredReturnTypes.assign({operands[0].getType()}); 1121 return success(); 1122 } 1123 1124 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 1125 MLIRContext *context, Optional<Location> location, ValueShapeRange operands, 1126 DictionaryAttr attributes, RegionRange regions, 1127 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1128 // Create return type consisting of the last element of the first operand. 1129 auto operandType = operands.front().getType(); 1130 auto sval = operandType.dyn_cast<ShapedType>(); 1131 if (!sval) { 1132 return emitOptionalError(location, "only shaped type operands allowed"); 1133 } 1134 int64_t dim = 1135 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 1136 auto type = IntegerType::get(context, 17); 1137 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 1138 return success(); 1139 } 1140 1141 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 1142 OpBuilder &builder, ValueRange operands, 1143 llvm::SmallVectorImpl<Value> &shapes) { 1144 shapes = SmallVector<Value, 1>{ 1145 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; 1146 return success(); 1147 } 1148 1149 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( 1150 OpBuilder &builder, ValueRange operands, 1151 llvm::SmallVectorImpl<Value> &shapes) { 1152 Location loc = getLoc(); 1153 shapes.reserve(operands.size()); 1154 for (Value operand : llvm::reverse(operands)) { 1155 auto rank = operand.getType().cast<RankedTensorType>().getRank(); 1156 auto currShape = llvm::to_vector<4>( 1157 llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value { 1158 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 1159 })); 1160 shapes.push_back(builder.create<tensor::FromElementsOp>( 1161 getLoc(), RankedTensorType::get({rank}, builder.getIndexType()), 1162 currShape)); 1163 } 1164 return success(); 1165 } 1166 1167 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( 1168 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { 1169 Location loc = getLoc(); 1170 shapes.reserve(getNumOperands()); 1171 for (Value operand : llvm::reverse(getOperands())) { 1172 auto currShape = llvm::to_vector<4>(llvm::map_range( 1173 llvm::seq<int64_t>( 1174 0, operand.getType().cast<RankedTensorType>().getRank()), 1175 [&](int64_t dim) -> Value { 1176 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 1177 })); 1178 shapes.emplace_back(std::move(currShape)); 1179 } 1180 return success(); 1181 } 1182 1183 //===----------------------------------------------------------------------===// 1184 // Test SideEffect interfaces 1185 //===----------------------------------------------------------------------===// 1186 1187 namespace { 1188 /// A test resource for side effects. 1189 struct TestResource : public SideEffects::Resource::Base<TestResource> { 1190 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource) 1191 1192 StringRef getName() final { return "<Test>"; } 1193 }; 1194 } // namespace 1195 1196 static void testSideEffectOpGetEffect( 1197 Operation *op, 1198 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 1199 &effects) { 1200 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); 1201 if (!effectsAttr) 1202 return; 1203 1204 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 1205 } 1206 1207 void SideEffectOp::getEffects( 1208 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1209 // Check for an effects attribute on the op instance. 1210 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 1211 if (!effectsAttr) 1212 return; 1213 1214 // If there is one, it is an array of dictionary attributes that hold 1215 // information on the effects of this operation. 1216 for (Attribute element : effectsAttr) { 1217 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 1218 1219 // Get the specific memory effect. 1220 MemoryEffects::Effect *effect = 1221 StringSwitch<MemoryEffects::Effect *>( 1222 effectElement.get("effect").cast<StringAttr>().getValue()) 1223 .Case("allocate", MemoryEffects::Allocate::get()) 1224 .Case("free", MemoryEffects::Free::get()) 1225 .Case("read", MemoryEffects::Read::get()) 1226 .Case("write", MemoryEffects::Write::get()); 1227 1228 // Check for a non-default resource to use. 1229 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 1230 if (effectElement.get("test_resource")) 1231 resource = TestResource::get(); 1232 1233 // Check for a result to affect. 1234 if (effectElement.get("on_result")) 1235 effects.emplace_back(effect, getResult(), resource); 1236 else if (Attribute ref = effectElement.get("on_reference")) 1237 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 1238 else 1239 effects.emplace_back(effect, resource); 1240 } 1241 } 1242 1243 void SideEffectOp::getEffects( 1244 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 1245 testSideEffectOpGetEffect(getOperation(), effects); 1246 } 1247 1248 //===----------------------------------------------------------------------===// 1249 // StringAttrPrettyNameOp 1250 //===----------------------------------------------------------------------===// 1251 1252 // This op has fancy handling of its SSA result name. 1253 ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser, 1254 OperationState &result) { 1255 // Add the result types. 1256 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 1257 result.addTypes(parser.getBuilder().getIntegerType(32)); 1258 1259 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1260 return failure(); 1261 1262 // If the attribute dictionary contains no 'names' attribute, infer it from 1263 // the SSA name (if specified). 1264 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 1265 return attr.getName() == "names"; 1266 }); 1267 1268 // If there was no name specified, check to see if there was a useful name 1269 // specified in the asm file. 1270 if (hadNames || parser.getNumResults() == 0) 1271 return success(); 1272 1273 SmallVector<StringRef, 4> names; 1274 auto *context = result.getContext(); 1275 1276 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 1277 auto resultName = parser.getResultName(i); 1278 StringRef nameStr; 1279 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 1280 nameStr = resultName.first; 1281 1282 names.push_back(nameStr); 1283 } 1284 1285 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 1286 result.attributes.push_back({StringAttr::get(context, "names"), namesAttr}); 1287 return success(); 1288 } 1289 1290 void StringAttrPrettyNameOp::print(OpAsmPrinter &p) { 1291 // Note that we only need to print the "name" attribute if the asmprinter 1292 // result name disagrees with it. This can happen in strange cases, e.g. 1293 // when there are conflicts. 1294 bool namesDisagree = getNames().size() != getNumResults(); 1295 1296 SmallString<32> resultNameStr; 1297 for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) { 1298 resultNameStr.clear(); 1299 llvm::raw_svector_ostream tmpStream(resultNameStr); 1300 p.printOperand(getResult(i), tmpStream); 1301 1302 auto expectedName = getNames()[i].dyn_cast<StringAttr>(); 1303 if (!expectedName || 1304 tmpStream.str().drop_front() != expectedName.getValue()) { 1305 namesDisagree = true; 1306 } 1307 } 1308 1309 if (namesDisagree) 1310 p.printOptionalAttrDictWithKeyword((*this)->getAttrs()); 1311 else 1312 p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"}); 1313 } 1314 1315 // We set the SSA name in the asm syntax to the contents of the name 1316 // attribute. 1317 void StringAttrPrettyNameOp::getAsmResultNames( 1318 function_ref<void(Value, StringRef)> setNameFn) { 1319 1320 auto value = getNames(); 1321 for (size_t i = 0, e = value.size(); i != e; ++i) 1322 if (auto str = value[i].dyn_cast<StringAttr>()) 1323 if (!str.getValue().empty()) 1324 setNameFn(getResult(i), str.getValue()); 1325 } 1326 1327 void CustomResultsNameOp::getAsmResultNames( 1328 function_ref<void(Value, StringRef)> setNameFn) { 1329 ArrayAttr value = getNames(); 1330 for (size_t i = 0, e = value.size(); i != e; ++i) 1331 if (auto str = value[i].dyn_cast<StringAttr>()) 1332 if (!str.getValue().empty()) 1333 setNameFn(getResult(i), str.getValue()); 1334 } 1335 1336 //===----------------------------------------------------------------------===// 1337 // ResultTypeWithTraitOp 1338 //===----------------------------------------------------------------------===// 1339 1340 LogicalResult ResultTypeWithTraitOp::verify() { 1341 if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>()) 1342 return success(); 1343 return emitError("result type should have trait 'TestTypeTrait'"); 1344 } 1345 1346 //===----------------------------------------------------------------------===// 1347 // AttrWithTraitOp 1348 //===----------------------------------------------------------------------===// 1349 1350 LogicalResult AttrWithTraitOp::verify() { 1351 if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>()) 1352 return success(); 1353 return emitError("'attr' attribute should have trait 'TestAttrTrait'"); 1354 } 1355 1356 //===----------------------------------------------------------------------===// 1357 // RegionIfOp 1358 //===----------------------------------------------------------------------===// 1359 1360 void RegionIfOp::print(OpAsmPrinter &p) { 1361 p << " "; 1362 p.printOperands(getOperands()); 1363 p << ": " << getOperandTypes(); 1364 p.printArrowTypeList(getResultTypes()); 1365 p << " then "; 1366 p.printRegion(getThenRegion(), 1367 /*printEntryBlockArgs=*/true, 1368 /*printBlockTerminators=*/true); 1369 p << " else "; 1370 p.printRegion(getElseRegion(), 1371 /*printEntryBlockArgs=*/true, 1372 /*printBlockTerminators=*/true); 1373 p << " join "; 1374 p.printRegion(getJoinRegion(), 1375 /*printEntryBlockArgs=*/true, 1376 /*printBlockTerminators=*/true); 1377 } 1378 1379 ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) { 1380 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos; 1381 SmallVector<Type, 2> operandTypes; 1382 1383 result.regions.reserve(3); 1384 Region *thenRegion = result.addRegion(); 1385 Region *elseRegion = result.addRegion(); 1386 Region *joinRegion = result.addRegion(); 1387 1388 // Parse operand, type and arrow type lists. 1389 if (parser.parseOperandList(operandInfos) || 1390 parser.parseColonTypeList(operandTypes) || 1391 parser.parseArrowTypeList(result.types)) 1392 return failure(); 1393 1394 // Parse all attached regions. 1395 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 1396 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 1397 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 1398 return failure(); 1399 1400 return parser.resolveOperands(operandInfos, operandTypes, 1401 parser.getCurrentLocation(), result.operands); 1402 } 1403 1404 OperandRange RegionIfOp::getSuccessorEntryOperands(Optional<unsigned> index) { 1405 assert(index && *index < 2 && "invalid region index"); 1406 return getOperands(); 1407 } 1408 1409 void RegionIfOp::getSuccessorRegions( 1410 Optional<unsigned> index, ArrayRef<Attribute> operands, 1411 SmallVectorImpl<RegionSuccessor> ®ions) { 1412 // We always branch to the join region. 1413 if (index.hasValue()) { 1414 if (index.getValue() < 2) 1415 regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); 1416 else 1417 regions.push_back(RegionSuccessor(getResults())); 1418 return; 1419 } 1420 1421 // The then and else regions are the entry regions of this op. 1422 regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs())); 1423 regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs())); 1424 } 1425 1426 void RegionIfOp::getRegionInvocationBounds( 1427 ArrayRef<Attribute> operands, 1428 SmallVectorImpl<InvocationBounds> &invocationBounds) { 1429 // Each region is invoked at most once. 1430 invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1}); 1431 } 1432 1433 //===----------------------------------------------------------------------===// 1434 // AnyCondOp 1435 //===----------------------------------------------------------------------===// 1436 1437 void AnyCondOp::getSuccessorRegions(Optional<unsigned> index, 1438 ArrayRef<Attribute> operands, 1439 SmallVectorImpl<RegionSuccessor> ®ions) { 1440 // The parent op branches into the only region, and the region branches back 1441 // to the parent op. 1442 if (!index) 1443 regions.emplace_back(&getRegion()); 1444 else 1445 regions.emplace_back(getResults()); 1446 } 1447 1448 void AnyCondOp::getRegionInvocationBounds( 1449 ArrayRef<Attribute> operands, 1450 SmallVectorImpl<InvocationBounds> &invocationBounds) { 1451 invocationBounds.emplace_back(1, 1); 1452 } 1453 1454 //===----------------------------------------------------------------------===// 1455 // SingleNoTerminatorCustomAsmOp 1456 //===----------------------------------------------------------------------===// 1457 1458 ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser, 1459 OperationState &state) { 1460 Region *body = state.addRegion(); 1461 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) 1462 return failure(); 1463 return success(); 1464 } 1465 1466 void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) { 1467 printer.printRegion( 1468 getRegion(), /*printEntryBlockArgs=*/false, 1469 // This op has a single block without terminators. But explicitly mark 1470 // as not printing block terminators for testing. 1471 /*printBlockTerminators=*/false); 1472 } 1473 1474 //===----------------------------------------------------------------------===// 1475 // TestVerifiersOp 1476 //===----------------------------------------------------------------------===// 1477 1478 LogicalResult TestVerifiersOp::verify() { 1479 if (!getRegion().hasOneBlock()) 1480 return emitOpError("`hasOneBlock` trait hasn't been verified"); 1481 1482 Operation *definingOp = getInput().getDefiningOp(); 1483 if (definingOp && failed(mlir::verify(definingOp))) 1484 return emitOpError("operand hasn't been verified"); 1485 1486 emitRemark("success run of verifier"); 1487 1488 return success(); 1489 } 1490 1491 LogicalResult TestVerifiersOp::verifyRegions() { 1492 if (!getRegion().hasOneBlock()) 1493 return emitOpError("`hasOneBlock` trait hasn't been verified"); 1494 1495 for (Block &block : getRegion()) 1496 for (Operation &op : block) 1497 if (failed(mlir::verify(&op))) 1498 return emitOpError("nested op hasn't been verified"); 1499 1500 emitRemark("success run of region verifier"); 1501 1502 return success(); 1503 } 1504 1505 //===----------------------------------------------------------------------===// 1506 // Test InferIntRangeInterface 1507 //===----------------------------------------------------------------------===// 1508 1509 void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 1510 SetIntRangeFn setResultRanges) { 1511 setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()}); 1512 } 1513 1514 ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser, 1515 OperationState &result) { 1516 if (parser.parseOptionalAttrDict(result.attributes)) 1517 return failure(); 1518 1519 // Parse the input argument 1520 OpAsmParser::Argument argInfo; 1521 argInfo.type = parser.getBuilder().getIndexType(); 1522 if (failed(parser.parseArgument(argInfo))) 1523 return failure(); 1524 1525 // Parse the body region, and reuse the operand info as the argument info. 1526 Region *body = result.addRegion(); 1527 return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false); 1528 } 1529 1530 void TestWithBoundsRegionOp::print(OpAsmPrinter &p) { 1531 p.printOptionalAttrDict((*this)->getAttrs()); 1532 p << ' '; 1533 p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{}, 1534 /*omitType=*/true); 1535 p << ' '; 1536 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 1537 } 1538 1539 void TestWithBoundsRegionOp::inferResultRanges( 1540 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) { 1541 Value arg = getRegion().getArgument(0); 1542 setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()}); 1543 } 1544 1545 void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 1546 SetIntRangeFn setResultRanges) { 1547 const ConstantIntRanges &range = argRanges[0]; 1548 APInt one(range.umin().getBitWidth(), 1); 1549 setResultRanges(getResult(), 1550 {range.umin().uadd_sat(one), range.umax().uadd_sat(one), 1551 range.smin().sadd_sat(one), range.smax().sadd_sat(one)}); 1552 } 1553 1554 void TestReflectBoundsOp::inferResultRanges( 1555 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) { 1556 const ConstantIntRanges &range = argRanges[0]; 1557 MLIRContext *ctx = getContext(); 1558 Builder b(ctx); 1559 setUminAttr(b.getIndexAttr(range.umin().getZExtValue())); 1560 setUmaxAttr(b.getIndexAttr(range.umax().getZExtValue())); 1561 setSminAttr(b.getIndexAttr(range.smin().getSExtValue())); 1562 setSmaxAttr(b.getIndexAttr(range.smax().getSExtValue())); 1563 setResultRanges(getResult(), range); 1564 } 1565 1566 #include "TestOpEnums.cpp.inc" 1567 #include "TestOpInterfaces.cpp.inc" 1568 #include "TestTypeInterfaces.cpp.inc" 1569 1570 #define GET_OP_CLASSES 1571 #include "TestOps.cpp.inc" 1572