1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestAttributes.h" 11 #include "TestInterfaces.h" 12 #include "TestTypes.h" 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/DLTI/DLTI.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/IR/AsmState.h" 18 #include "mlir/IR/BuiltinAttributes.h" 19 #include "mlir/IR/BuiltinOps.h" 20 #include "mlir/IR/Diagnostics.h" 21 #include "mlir/IR/DialectImplementation.h" 22 #include "mlir/IR/ExtensibleDialect.h" 23 #include "mlir/IR/MLIRContext.h" 24 #include "mlir/IR/OperationSupport.h" 25 #include "mlir/IR/PatternMatch.h" 26 #include "mlir/IR/TypeUtilities.h" 27 #include "mlir/IR/Verifier.h" 28 #include "mlir/Interfaces/InferIntRangeInterface.h" 29 #include "mlir/Reducer/ReductionPatternInterface.h" 30 #include "mlir/Transforms/FoldUtils.h" 31 #include "mlir/Transforms/InliningUtils.h" 32 #include "llvm/ADT/SmallString.h" 33 #include "llvm/ADT/StringExtras.h" 34 #include "llvm/ADT/StringSwitch.h" 35 36 // Include this before the using namespace lines below to 37 // test that we don't have namespace dependencies. 38 #include "TestOpsDialect.cpp.inc" 39 40 using namespace mlir; 41 using namespace test; 42 43 void test::registerTestDialect(DialectRegistry ®istry) { 44 registry.insert<TestDialect>(); 45 } 46 47 //===----------------------------------------------------------------------===// 48 // External Elements Data 49 //===----------------------------------------------------------------------===// 50 51 ArrayRef<uint64_t> TestExternalElementsData::getData() const { 52 ArrayRef<char> data = AsmResourceBlob::getData(); 53 return ArrayRef<uint64_t>((const uint64_t *)data.data(), 54 data.size() / sizeof(uint64_t)); 55 } 56 57 TestExternalElementsData 58 TestExternalElementsData::allocate(size_t numElements) { 59 return TestExternalElementsData( 60 llvm::ArrayRef<uint64_t>(new uint64_t[numElements], numElements), 61 [](const uint64_t *data, size_t) { delete[] data; }, 62 /*dataIsMutable=*/true); 63 } 64 65 const TestExternalElementsData * 66 TestExternalElementsDataManager::getData(StringRef name) const { 67 auto it = dataMap.find(name); 68 return it != dataMap.end() ? &*it->second : nullptr; 69 } 70 71 std::pair<TestExternalElementsDataManager::DataMap::iterator, bool> 72 TestExternalElementsDataManager::insert(StringRef name) { 73 auto it = dataMap.try_emplace(name, nullptr); 74 if (it.second) 75 return it; 76 77 llvm::SmallString<32> nameStorage(name); 78 nameStorage.push_back('_'); 79 size_t nameCounter = 1; 80 do { 81 nameStorage += std::to_string(nameCounter++); 82 auto it = dataMap.try_emplace(nameStorage, nullptr); 83 if (it.second) 84 return it; 85 nameStorage.resize(name.size() + 1); 86 } while (true); 87 } 88 89 void TestExternalElementsDataManager::setData(StringRef name, 90 TestExternalElementsData &&data) { 91 auto it = dataMap.find(name); 92 assert(it != dataMap.end() && "data not registered"); 93 it->second = std::make_unique<TestExternalElementsData>(std::move(data)); 94 } 95 96 //===----------------------------------------------------------------------===// 97 // TestDialect Interfaces 98 //===----------------------------------------------------------------------===// 99 100 namespace { 101 102 /// Testing the correctness of some traits. 103 static_assert( 104 llvm::is_detected<OpTrait::has_implicit_terminator_t, 105 SingleBlockImplicitTerminatorOp>::value, 106 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); 107 static_assert(OpTrait::hasSingleBlockImplicitTerminator< 108 SingleBlockImplicitTerminatorOp>::value, 109 "hasSingleBlockImplicitTerminator does not match " 110 "SingleBlockImplicitTerminatorOp"); 111 112 // Test support for interacting with the AsmPrinter. 113 struct TestOpAsmInterface : public OpAsmDialectInterface { 114 using OpAsmDialectInterface::OpAsmDialectInterface; 115 116 //===------------------------------------------------------------------===// 117 // Aliases 118 //===------------------------------------------------------------------===// 119 120 AliasResult getAlias(Attribute attr, raw_ostream &os) const final { 121 StringAttr strAttr = attr.dyn_cast<StringAttr>(); 122 if (!strAttr) 123 return AliasResult::NoAlias; 124 125 // Check the contents of the string attribute to see what the test alias 126 // should be named. 127 Optional<StringRef> aliasName = 128 StringSwitch<Optional<StringRef>>(strAttr.getValue()) 129 .Case("alias_test:dot_in_name", StringRef("test.alias")) 130 .Case("alias_test:trailing_digit", StringRef("test_alias0")) 131 .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 132 .Case("alias_test:sanitize_conflict_a", 133 StringRef("test_alias_conflict0")) 134 .Case("alias_test:sanitize_conflict_b", 135 StringRef("test_alias_conflict0_")) 136 .Case("alias_test:tensor_encoding", StringRef("test_encoding")) 137 .Default(llvm::None); 138 if (!aliasName) 139 return AliasResult::NoAlias; 140 141 os << *aliasName; 142 return AliasResult::FinalAlias; 143 } 144 145 AliasResult getAlias(Type type, raw_ostream &os) const final { 146 if (auto tupleType = type.dyn_cast<TupleType>()) { 147 if (tupleType.size() > 0 && 148 llvm::all_of(tupleType.getTypes(), [](Type elemType) { 149 return elemType.isa<SimpleAType>(); 150 })) { 151 os << "test_tuple"; 152 return AliasResult::FinalAlias; 153 } 154 } 155 if (auto intType = type.dyn_cast<TestIntegerType>()) { 156 if (intType.getSignedness() == 157 TestIntegerType::SignednessSemantics::Unsigned && 158 intType.getWidth() == 8) { 159 os << "test_ui8"; 160 return AliasResult::FinalAlias; 161 } 162 } 163 if (auto recType = type.dyn_cast<TestRecursiveType>()) { 164 if (recType.getName() == "type_to_alias") { 165 // We only make alias for a specific recursive type. 166 os << "testrec"; 167 return AliasResult::FinalAlias; 168 } 169 } 170 return AliasResult::NoAlias; 171 } 172 173 //===------------------------------------------------------------------===// 174 // Resources 175 //===------------------------------------------------------------------===// 176 177 std::string 178 getResourceKey(const AsmDialectResourceHandle &handle) const override { 179 return cast<TestExternalElementsDataHandle>(handle).getKey().str(); 180 } 181 182 FailureOr<AsmDialectResourceHandle> 183 declareResource(StringRef key) const final { 184 TestDialect *dialect = cast<TestDialect>(getDialect()); 185 TestExternalElementsDataManager &mgr = dialect->getExternalDataManager(); 186 187 // Resolve the reference by inserting a new entry into the manager. 188 auto it = mgr.insert(key).first; 189 return TestExternalElementsDataHandle(&*it, dialect); 190 } 191 192 LogicalResult parseResource(AsmParsedResourceEntry &entry) const final { 193 TestDialect *dialect = cast<TestDialect>(getDialect()); 194 TestExternalElementsDataManager &mgr = dialect->getExternalDataManager(); 195 196 // The resource entries are external constant data. 197 auto blobAllocFn = [](unsigned size, unsigned align) { 198 assert(align == alignof(uint64_t) && "unexpected data alignment"); 199 return TestExternalElementsData::allocate(size / sizeof(uint64_t)); 200 }; 201 FailureOr<AsmResourceBlob> blob = entry.parseAsBlob(blobAllocFn); 202 if (failed(blob)) 203 return failure(); 204 205 mgr.setData(entry.getKey(), std::move(*blob)); 206 return success(); 207 } 208 209 void 210 buildResources(Operation *op, 211 const SetVector<AsmDialectResourceHandle> &referencedResources, 212 AsmResourceBuilder &provider) const final { 213 for (const AsmDialectResourceHandle &handle : referencedResources) { 214 const auto &testHandle = cast<TestExternalElementsDataHandle>(handle); 215 provider.buildBlob(testHandle.getKey(), testHandle.getData()->getData()); 216 } 217 } 218 }; 219 220 struct TestDialectFoldInterface : public DialectFoldInterface { 221 using DialectFoldInterface::DialectFoldInterface; 222 223 /// Registered hook to check if the given region, which is attached to an 224 /// operation that is *not* isolated from above, should be used when 225 /// materializing constants. 226 bool shouldMaterializeInto(Region *region) const final { 227 // If this is a one region operation, then insert into it. 228 return isa<OneRegionOp>(region->getParentOp()); 229 } 230 }; 231 232 /// This class defines the interface for handling inlining with standard 233 /// operations. 234 struct TestInlinerInterface : public DialectInlinerInterface { 235 using DialectInlinerInterface::DialectInlinerInterface; 236 237 //===--------------------------------------------------------------------===// 238 // Analysis Hooks 239 //===--------------------------------------------------------------------===// 240 241 bool isLegalToInline(Operation *call, Operation *callable, 242 bool wouldBeCloned) const final { 243 // Don't allow inlining calls that are marked `noinline`. 244 return !call->hasAttr("noinline"); 245 } 246 bool isLegalToInline(Region *, Region *, bool, 247 BlockAndValueMapping &) const final { 248 // Inlining into test dialect regions is legal. 249 return true; 250 } 251 bool isLegalToInline(Operation *, Region *, bool, 252 BlockAndValueMapping &) const final { 253 return true; 254 } 255 256 bool shouldAnalyzeRecursively(Operation *op) const final { 257 // Analyze recursively if this is not a functional region operation, it 258 // froms a separate functional scope. 259 return !isa<FunctionalRegionOp>(op); 260 } 261 262 //===--------------------------------------------------------------------===// 263 // Transformation Hooks 264 //===--------------------------------------------------------------------===// 265 266 /// Handle the given inlined terminator by replacing it with a new operation 267 /// as necessary. 268 void handleTerminator(Operation *op, 269 ArrayRef<Value> valuesToRepl) const final { 270 // Only handle "test.return" here. 271 auto returnOp = dyn_cast<TestReturnOp>(op); 272 if (!returnOp) 273 return; 274 275 // Replace the values directly with the return operands. 276 assert(returnOp.getNumOperands() == valuesToRepl.size()); 277 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 278 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 279 } 280 281 /// Attempt to materialize a conversion for a type mismatch between a call 282 /// from this dialect, and a callable region. This method should generate an 283 /// operation that takes 'input' as the only operand, and produces a single 284 /// result of 'resultType'. If a conversion can not be generated, nullptr 285 /// should be returned. 286 Operation *materializeCallConversion(OpBuilder &builder, Value input, 287 Type resultType, 288 Location conversionLoc) const final { 289 // Only allow conversion for i16/i32 types. 290 if (!(resultType.isSignlessInteger(16) || 291 resultType.isSignlessInteger(32)) || 292 !(input.getType().isSignlessInteger(16) || 293 input.getType().isSignlessInteger(32))) 294 return nullptr; 295 return builder.create<TestCastOp>(conversionLoc, resultType, input); 296 } 297 298 void processInlinedCallBlocks( 299 Operation *call, 300 iterator_range<Region::iterator> inlinedBlocks) const final { 301 if (!isa<ConversionCallOp>(call)) 302 return; 303 304 // Set attributed on all ops in the inlined blocks. 305 for (Block &block : inlinedBlocks) { 306 block.walk([&](Operation *op) { 307 op->setAttr("inlined_conversion", UnitAttr::get(call->getContext())); 308 }); 309 } 310 } 311 }; 312 313 struct TestReductionPatternInterface : public DialectReductionPatternInterface { 314 public: 315 TestReductionPatternInterface(Dialect *dialect) 316 : DialectReductionPatternInterface(dialect) {} 317 318 void populateReductionPatterns(RewritePatternSet &patterns) const final { 319 populateTestReductionPatterns(patterns); 320 } 321 }; 322 323 } // namespace 324 325 //===----------------------------------------------------------------------===// 326 // Dynamic operations 327 //===----------------------------------------------------------------------===// 328 329 std::unique_ptr<DynamicOpDefinition> getDynamicGenericOp(TestDialect *dialect) { 330 return DynamicOpDefinition::get( 331 "dynamic_generic", dialect, [](Operation *op) { return success(); }, 332 [](Operation *op) { return success(); }); 333 } 334 335 std::unique_ptr<DynamicOpDefinition> 336 getDynamicOneOperandTwoResultsOp(TestDialect *dialect) { 337 return DynamicOpDefinition::get( 338 "dynamic_one_operand_two_results", dialect, 339 [](Operation *op) { 340 if (op->getNumOperands() != 1) { 341 op->emitOpError() 342 << "expected 1 operand, but had " << op->getNumOperands(); 343 return failure(); 344 } 345 if (op->getNumResults() != 2) { 346 op->emitOpError() 347 << "expected 2 results, but had " << op->getNumResults(); 348 return failure(); 349 } 350 return success(); 351 }, 352 [](Operation *op) { return success(); }); 353 } 354 355 std::unique_ptr<DynamicOpDefinition> 356 getDynamicCustomParserPrinterOp(TestDialect *dialect) { 357 auto verifier = [](Operation *op) { 358 if (op->getNumOperands() == 0 && op->getNumResults() == 0) 359 return success(); 360 op->emitError() << "operation should have no operands and no results"; 361 return failure(); 362 }; 363 auto regionVerifier = [](Operation *op) { return success(); }; 364 365 auto parser = [](OpAsmParser &parser, OperationState &state) { 366 return parser.parseKeyword("custom_keyword"); 367 }; 368 369 auto printer = [](Operation *op, OpAsmPrinter &printer, llvm::StringRef) { 370 printer << op->getName() << " custom_keyword"; 371 }; 372 373 return DynamicOpDefinition::get("dynamic_custom_parser_printer", dialect, 374 verifier, regionVerifier, parser, printer); 375 } 376 377 //===----------------------------------------------------------------------===// 378 // TestDialect 379 //===----------------------------------------------------------------------===// 380 381 static void testSideEffectOpGetEffect( 382 Operation *op, 383 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects); 384 385 // This is the implementation of a dialect fallback for `TestEffectOpInterface`. 386 struct TestOpEffectInterfaceFallback 387 : public TestEffectOpInterface::FallbackModel< 388 TestOpEffectInterfaceFallback> { 389 static bool classof(Operation *op) { 390 bool isSupportedOp = 391 op->getName().getStringRef() == "test.unregistered_side_effect_op"; 392 assert(isSupportedOp && "Unexpected dispatch"); 393 return isSupportedOp; 394 } 395 396 void 397 getEffects(Operation *op, 398 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 399 &effects) const { 400 testSideEffectOpGetEffect(op, effects); 401 } 402 }; 403 404 void TestDialect::initialize() { 405 registerAttributes(); 406 registerTypes(); 407 addOperations< 408 #define GET_OP_LIST 409 #include "TestOps.cpp.inc" 410 >(); 411 registerDynamicOp(getDynamicGenericOp(this)); 412 registerDynamicOp(getDynamicOneOperandTwoResultsOp(this)); 413 registerDynamicOp(getDynamicCustomParserPrinterOp(this)); 414 415 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 416 TestInlinerInterface, TestReductionPatternInterface>(); 417 allowUnknownOperations(); 418 419 // Instantiate our fallback op interface that we'll use on specific 420 // unregistered op. 421 fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback; 422 } 423 TestDialect::~TestDialect() { 424 delete static_cast<TestOpEffectInterfaceFallback *>( 425 fallbackEffectOpInterfaces); 426 } 427 428 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, 429 Type type, Location loc) { 430 return builder.create<TestOpConstant>(loc, type, value); 431 } 432 433 ::mlir::LogicalResult FormatInferType2Op::inferReturnTypes( 434 ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, 435 ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, 436 ::mlir::RegionRange regions, 437 ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { 438 inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)}); 439 return ::mlir::success(); 440 } 441 442 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, 443 OperationName opName) { 444 if (opName.getIdentifier() == "test.unregistered_side_effect_op" && 445 typeID == TypeID::get<TestEffectOpInterface>()) 446 return fallbackEffectOpInterfaces; 447 return nullptr; 448 } 449 450 LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 451 NamedAttribute namedAttr) { 452 if (namedAttr.getName() == "test.invalid_attr") 453 return op->emitError() << "invalid to use 'test.invalid_attr'"; 454 return success(); 455 } 456 457 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 458 unsigned regionIndex, 459 unsigned argIndex, 460 NamedAttribute namedAttr) { 461 if (namedAttr.getName() == "test.invalid_attr") 462 return op->emitError() << "invalid to use 'test.invalid_attr'"; 463 return success(); 464 } 465 466 LogicalResult 467 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 468 unsigned resultIndex, 469 NamedAttribute namedAttr) { 470 if (namedAttr.getName() == "test.invalid_attr") 471 return op->emitError() << "invalid to use 'test.invalid_attr'"; 472 return success(); 473 } 474 475 Optional<Dialect::ParseOpHook> 476 TestDialect::getParseOperationHook(StringRef opName) const { 477 if (opName == "test.dialect_custom_printer") { 478 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 479 return parser.parseKeyword("custom_format"); 480 }}; 481 } 482 if (opName == "test.dialect_custom_format_fallback") { 483 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 484 return parser.parseKeyword("custom_format_fallback"); 485 }}; 486 } 487 if (opName == "test.dialect_custom_printer.with.dot") { 488 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { 489 return ParseResult::success(); 490 }}; 491 } 492 return None; 493 } 494 495 llvm::unique_function<void(Operation *, OpAsmPrinter &)> 496 TestDialect::getOperationPrinter(Operation *op) const { 497 StringRef opName = op->getName().getStringRef(); 498 if (opName == "test.dialect_custom_printer") { 499 return [](Operation *op, OpAsmPrinter &printer) { 500 printer.getStream() << " custom_format"; 501 }; 502 } 503 if (opName == "test.dialect_custom_format_fallback") { 504 return [](Operation *op, OpAsmPrinter &printer) { 505 printer.getStream() << " custom_format_fallback"; 506 }; 507 } 508 return {}; 509 } 510 511 //===----------------------------------------------------------------------===// 512 // TestBranchOp 513 //===----------------------------------------------------------------------===// 514 515 SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) { 516 assert(index == 0 && "invalid successor index"); 517 return SuccessorOperands(getTargetOperandsMutable()); 518 } 519 520 //===----------------------------------------------------------------------===// 521 // TestProducingBranchOp 522 //===----------------------------------------------------------------------===// 523 524 SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) { 525 assert(index <= 1 && "invalid successor index"); 526 if (index == 1) 527 return SuccessorOperands(getFirstOperandsMutable()); 528 return SuccessorOperands(getSecondOperandsMutable()); 529 } 530 531 //===----------------------------------------------------------------------===// 532 // TestProducingBranchOp 533 //===----------------------------------------------------------------------===// 534 535 SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) { 536 assert(index <= 1 && "invalid successor index"); 537 if (index == 0) 538 return SuccessorOperands(0, getSuccessOperandsMutable()); 539 return SuccessorOperands(1, getErrorOperandsMutable()); 540 } 541 542 //===----------------------------------------------------------------------===// 543 // TestDialectCanonicalizerOp 544 //===----------------------------------------------------------------------===// 545 546 static LogicalResult 547 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, 548 PatternRewriter &rewriter) { 549 rewriter.replaceOpWithNewOp<arith::ConstantOp>( 550 op, rewriter.getI32IntegerAttr(42)); 551 return success(); 552 } 553 554 void TestDialect::getCanonicalizationPatterns( 555 RewritePatternSet &results) const { 556 results.add(&dialectCanonicalizationPattern); 557 } 558 559 //===----------------------------------------------------------------------===// 560 // TestCallOp 561 //===----------------------------------------------------------------------===// 562 563 LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 564 // Check that the callee attribute was specified. 565 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee"); 566 if (!fnAttr) 567 return emitOpError("requires a 'callee' symbol reference attribute"); 568 if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr)) 569 return emitOpError() << "'" << fnAttr.getValue() 570 << "' does not reference a valid function"; 571 return success(); 572 } 573 574 //===----------------------------------------------------------------------===// 575 // TestFoldToCallOp 576 //===----------------------------------------------------------------------===// 577 578 namespace { 579 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 580 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 581 582 LogicalResult matchAndRewrite(FoldToCallOp op, 583 PatternRewriter &rewriter) const override { 584 rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), 585 op.getCalleeAttr(), ValueRange()); 586 return success(); 587 } 588 }; 589 } // namespace 590 591 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 592 MLIRContext *context) { 593 results.add<FoldToCallOpPattern>(context); 594 } 595 596 //===----------------------------------------------------------------------===// 597 // Test Format* operations 598 //===----------------------------------------------------------------------===// 599 600 //===----------------------------------------------------------------------===// 601 // Parsing 602 603 static ParseResult parseCustomOptionalOperand( 604 OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) { 605 if (succeeded(parser.parseOptionalLParen())) { 606 optOperand.emplace(); 607 if (parser.parseOperand(*optOperand) || parser.parseRParen()) 608 return failure(); 609 } 610 return success(); 611 } 612 613 static ParseResult parseCustomDirectiveOperands( 614 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, 615 Optional<OpAsmParser::UnresolvedOperand> &optOperand, 616 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) { 617 if (parser.parseOperand(operand)) 618 return failure(); 619 if (succeeded(parser.parseOptionalComma())) { 620 optOperand.emplace(); 621 if (parser.parseOperand(*optOperand)) 622 return failure(); 623 } 624 if (parser.parseArrow() || parser.parseLParen() || 625 parser.parseOperandList(varOperands) || parser.parseRParen()) 626 return failure(); 627 return success(); 628 } 629 static ParseResult 630 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 631 Type &optOperandType, 632 SmallVectorImpl<Type> &varOperandTypes) { 633 if (parser.parseColon()) 634 return failure(); 635 636 if (parser.parseType(operandType)) 637 return failure(); 638 if (succeeded(parser.parseOptionalComma())) { 639 if (parser.parseType(optOperandType)) 640 return failure(); 641 } 642 if (parser.parseArrow() || parser.parseLParen() || 643 parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 644 return failure(); 645 return success(); 646 } 647 static ParseResult 648 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 649 Type optOperandType, 650 const SmallVectorImpl<Type> &varOperandTypes) { 651 if (parser.parseKeyword("type_refs_capture")) 652 return failure(); 653 654 Type operandType2, optOperandType2; 655 SmallVector<Type, 1> varOperandTypes2; 656 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 657 varOperandTypes2)) 658 return failure(); 659 660 if (operandType != operandType2 || optOperandType != optOperandType2 || 661 varOperandTypes != varOperandTypes2) 662 return failure(); 663 664 return success(); 665 } 666 static ParseResult parseCustomDirectiveOperandsAndTypes( 667 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, 668 Optional<OpAsmParser::UnresolvedOperand> &optOperand, 669 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands, 670 Type &operandType, Type &optOperandType, 671 SmallVectorImpl<Type> &varOperandTypes) { 672 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 673 parseCustomDirectiveResults(parser, operandType, optOperandType, 674 varOperandTypes)) 675 return failure(); 676 return success(); 677 } 678 static ParseResult parseCustomDirectiveRegions( 679 OpAsmParser &parser, Region ®ion, 680 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 681 if (parser.parseRegion(region)) 682 return failure(); 683 if (failed(parser.parseOptionalComma())) 684 return success(); 685 std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 686 if (parser.parseRegion(*varRegion)) 687 return failure(); 688 varRegions.emplace_back(std::move(varRegion)); 689 return success(); 690 } 691 static ParseResult 692 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 693 SmallVectorImpl<Block *> &varSuccessors) { 694 if (parser.parseSuccessor(successor)) 695 return failure(); 696 if (failed(parser.parseOptionalComma())) 697 return success(); 698 Block *varSuccessor; 699 if (parser.parseSuccessor(varSuccessor)) 700 return failure(); 701 varSuccessors.append(2, varSuccessor); 702 return success(); 703 } 704 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 705 IntegerAttr &attr, 706 IntegerAttr &optAttr) { 707 if (parser.parseAttribute(attr)) 708 return failure(); 709 if (succeeded(parser.parseOptionalComma())) { 710 if (parser.parseAttribute(optAttr)) 711 return failure(); 712 } 713 return success(); 714 } 715 716 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 717 NamedAttrList &attrs) { 718 return parser.parseOptionalAttrDict(attrs); 719 } 720 static ParseResult parseCustomDirectiveOptionalOperandRef( 721 OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) { 722 int64_t operandCount = 0; 723 if (parser.parseInteger(operandCount)) 724 return failure(); 725 bool expectedOptionalOperand = operandCount == 0; 726 return success(expectedOptionalOperand != optOperand.has_value()); 727 } 728 729 //===----------------------------------------------------------------------===// 730 // Printing 731 732 static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *, 733 Value optOperand) { 734 if (optOperand) 735 printer << "(" << optOperand << ") "; 736 } 737 738 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 739 Value operand, Value optOperand, 740 OperandRange varOperands) { 741 printer << operand; 742 if (optOperand) 743 printer << ", " << optOperand; 744 printer << " -> (" << varOperands << ")"; 745 } 746 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 747 Type operandType, Type optOperandType, 748 TypeRange varOperandTypes) { 749 printer << " : " << operandType; 750 if (optOperandType) 751 printer << ", " << optOperandType; 752 printer << " -> (" << varOperandTypes << ")"; 753 } 754 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 755 Operation *op, Type operandType, 756 Type optOperandType, 757 TypeRange varOperandTypes) { 758 printer << " type_refs_capture "; 759 printCustomDirectiveResults(printer, op, operandType, optOperandType, 760 varOperandTypes); 761 } 762 static void printCustomDirectiveOperandsAndTypes( 763 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 764 OperandRange varOperands, Type operandType, Type optOperandType, 765 TypeRange varOperandTypes) { 766 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 767 printCustomDirectiveResults(printer, op, operandType, optOperandType, 768 varOperandTypes); 769 } 770 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 771 Region ®ion, 772 MutableArrayRef<Region> varRegions) { 773 printer.printRegion(region); 774 if (!varRegions.empty()) { 775 printer << ", "; 776 for (Region ®ion : varRegions) 777 printer.printRegion(region); 778 } 779 } 780 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 781 Block *successor, 782 SuccessorRange varSuccessors) { 783 printer << successor; 784 if (!varSuccessors.empty()) 785 printer << ", " << varSuccessors.front(); 786 } 787 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 788 Attribute attribute, 789 Attribute optAttribute) { 790 printer << attribute; 791 if (optAttribute) 792 printer << ", " << optAttribute; 793 } 794 795 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 796 DictionaryAttr attrs) { 797 printer.printOptionalAttrDict(attrs.getValue()); 798 } 799 800 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, 801 Operation *op, 802 Value optOperand) { 803 printer << (optOperand ? "1" : "0"); 804 } 805 806 //===----------------------------------------------------------------------===// 807 // Test IsolatedRegionOp - parse passthrough region arguments. 808 //===----------------------------------------------------------------------===// 809 810 ParseResult IsolatedRegionOp::parse(OpAsmParser &parser, 811 OperationState &result) { 812 // Parse the input operand. 813 OpAsmParser::Argument argInfo; 814 argInfo.type = parser.getBuilder().getIndexType(); 815 if (parser.parseOperand(argInfo.ssaName) || 816 parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands)) 817 return failure(); 818 819 // Parse the body region, and reuse the operand info as the argument info. 820 Region *body = result.addRegion(); 821 return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true); 822 } 823 824 void IsolatedRegionOp::print(OpAsmPrinter &p) { 825 p << "test.isolated_region "; 826 p.printOperand(getOperand()); 827 p.shadowRegionArgs(getRegion(), getOperand()); 828 p << ' '; 829 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 830 } 831 832 //===----------------------------------------------------------------------===// 833 // Test SSACFGRegionOp 834 //===----------------------------------------------------------------------===// 835 836 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 837 return RegionKind::SSACFG; 838 } 839 840 //===----------------------------------------------------------------------===// 841 // Test GraphRegionOp 842 //===----------------------------------------------------------------------===// 843 844 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 845 return RegionKind::Graph; 846 } 847 848 //===----------------------------------------------------------------------===// 849 // Test AffineScopeOp 850 //===----------------------------------------------------------------------===// 851 852 ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) { 853 // Parse the body region, and reuse the operand info as the argument info. 854 Region *body = result.addRegion(); 855 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 856 } 857 858 void AffineScopeOp::print(OpAsmPrinter &p) { 859 p << "test.affine_scope "; 860 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 861 } 862 863 //===----------------------------------------------------------------------===// 864 // Test parser. 865 //===----------------------------------------------------------------------===// 866 867 ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser, 868 OperationState &result) { 869 if (parser.parseOptionalColon()) 870 return success(); 871 uint64_t numResults; 872 if (parser.parseInteger(numResults)) 873 return failure(); 874 875 IndexType type = parser.getBuilder().getIndexType(); 876 for (unsigned i = 0; i < numResults; ++i) 877 result.addTypes(type); 878 return success(); 879 } 880 881 void ParseIntegerLiteralOp::print(OpAsmPrinter &p) { 882 if (unsigned numResults = getNumResults()) 883 p << " : " << numResults; 884 } 885 886 ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser, 887 OperationState &result) { 888 StringRef keyword; 889 if (parser.parseKeyword(&keyword)) 890 return failure(); 891 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 892 return success(); 893 } 894 895 void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); } 896 897 //===----------------------------------------------------------------------===// 898 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 899 900 ParseResult WrappingRegionOp::parse(OpAsmParser &parser, 901 OperationState &result) { 902 if (parser.parseKeyword("wraps")) 903 return failure(); 904 905 // Parse the wrapped op in a region 906 Region &body = *result.addRegion(); 907 body.push_back(new Block); 908 Block &block = body.back(); 909 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin()); 910 if (!wrappedOp) 911 return failure(); 912 913 // Create a return terminator in the inner region, pass as operand to the 914 // terminator the returned values from the wrapped operation. 915 SmallVector<Value, 8> returnOperands(wrappedOp->getResults()); 916 OpBuilder builder(parser.getContext()); 917 builder.setInsertionPointToEnd(&block); 918 builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands); 919 920 // Get the results type for the wrapping op from the terminator operands. 921 Operation &returnOp = body.back().back(); 922 result.types.append(returnOp.operand_type_begin(), 923 returnOp.operand_type_end()); 924 925 // Use the location of the wrapped op for the "test.wrapping_region" op. 926 result.location = wrappedOp->getLoc(); 927 928 return success(); 929 } 930 931 void WrappingRegionOp::print(OpAsmPrinter &p) { 932 p << " wraps "; 933 p.printGenericOp(&getRegion().front().front()); 934 } 935 936 //===----------------------------------------------------------------------===// 937 // Test PrettyPrintedRegionOp - exercising the following parser APIs 938 // parseGenericOperationAfterOpName 939 // parseCustomOperationName 940 //===----------------------------------------------------------------------===// 941 942 ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser, 943 OperationState &result) { 944 945 SMLoc loc = parser.getCurrentLocation(); 946 Location currLocation = parser.getEncodedSourceLoc(loc); 947 948 // Parse the operands. 949 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands; 950 if (parser.parseOperandList(operands)) 951 return failure(); 952 953 // Check if we are parsing the pretty-printed version 954 // test.pretty_printed_region start <inner-op> end : <functional-type> 955 // Else fallback to parsing the "non pretty-printed" version. 956 if (!succeeded(parser.parseOptionalKeyword("start"))) 957 return parser.parseGenericOperationAfterOpName( 958 result, llvm::makeArrayRef(operands)); 959 960 FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName(); 961 if (failed(parseOpNameInfo)) 962 return failure(); 963 964 StringAttr innerOpName = parseOpNameInfo->getIdentifier(); 965 966 FunctionType opFntype; 967 Optional<Location> explicitLoc; 968 if (parser.parseKeyword("end") || parser.parseColon() || 969 parser.parseType(opFntype) || 970 parser.parseOptionalLocationSpecifier(explicitLoc)) 971 return failure(); 972 973 // If location of the op is explicitly provided, then use it; Else use 974 // the parser's current location. 975 Location opLoc = explicitLoc.value_or(currLocation); 976 977 // Derive the SSA-values for op's operands. 978 if (parser.resolveOperands(operands, opFntype.getInputs(), loc, 979 result.operands)) 980 return failure(); 981 982 // Add a region for op. 983 Region ®ion = *result.addRegion(); 984 985 // Create a basic-block inside op's region. 986 Block &block = region.emplaceBlock(); 987 988 // Create and insert an "inner-op" operation in the block. 989 // Just for testing purposes, we can assume that inner op is a binary op with 990 // result and operand types all same as the test-op's first operand. 991 Type innerOpType = opFntype.getInput(0); 992 Value lhs = block.addArgument(innerOpType, opLoc); 993 Value rhs = block.addArgument(innerOpType, opLoc); 994 995 OpBuilder builder(parser.getBuilder().getContext()); 996 builder.setInsertionPointToStart(&block); 997 998 Operation *innerOp = 999 builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType); 1000 1001 // Insert a return statement in the block returning the inner-op's result. 1002 builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults()); 1003 1004 // Populate the op operation-state with result-type and location. 1005 result.addTypes(opFntype.getResults()); 1006 result.location = innerOp->getLoc(); 1007 1008 return success(); 1009 } 1010 1011 void PrettyPrintedRegionOp::print(OpAsmPrinter &p) { 1012 p << ' '; 1013 p.printOperands(getOperands()); 1014 1015 Operation &innerOp = getRegion().front().front(); 1016 // Assuming that region has a single non-terminator inner-op, if the inner-op 1017 // meets some criteria (which in this case is a simple one based on the name 1018 // of inner-op), then we can print the entire region in a succinct way. 1019 // Here we assume that the prototype of "special.op" can be trivially derived 1020 // while parsing it back. 1021 if (innerOp.getName().getStringRef().equals("special.op")) { 1022 p << " start special.op end"; 1023 } else { 1024 p << " ("; 1025 p.printRegion(getRegion()); 1026 p << ")"; 1027 } 1028 1029 p << " : "; 1030 p.printFunctionalType(*this); 1031 } 1032 1033 //===----------------------------------------------------------------------===// 1034 // Test PolyForOp - parse list of region arguments. 1035 //===----------------------------------------------------------------------===// 1036 1037 ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) { 1038 SmallVector<OpAsmParser::Argument, 4> ivsInfo; 1039 // Parse list of region arguments without a delimiter. 1040 if (parser.parseArgumentList(ivsInfo, OpAsmParser::Delimiter::None)) 1041 return failure(); 1042 1043 // Parse the body region. 1044 Region *body = result.addRegion(); 1045 for (auto &iv : ivsInfo) 1046 iv.type = parser.getBuilder().getIndexType(); 1047 return parser.parseRegion(*body, ivsInfo); 1048 } 1049 1050 void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); } 1051 1052 void PolyForOp::getAsmBlockArgumentNames(Region ®ion, 1053 OpAsmSetValueNameFn setNameFn) { 1054 auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names"); 1055 if (!arrayAttr) 1056 return; 1057 auto args = getRegion().front().getArguments(); 1058 auto e = std::min(arrayAttr.size(), args.size()); 1059 for (unsigned i = 0; i < e; ++i) { 1060 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 1061 setNameFn(args[i], strAttr.getValue()); 1062 } 1063 } 1064 1065 //===----------------------------------------------------------------------===// 1066 // Test removing op with inner ops. 1067 //===----------------------------------------------------------------------===// 1068 1069 namespace { 1070 struct TestRemoveOpWithInnerOps 1071 : public OpRewritePattern<TestOpWithRegionPattern> { 1072 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 1073 1074 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } 1075 1076 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 1077 PatternRewriter &rewriter) const override { 1078 rewriter.eraseOp(op); 1079 return success(); 1080 } 1081 }; 1082 } // namespace 1083 1084 void TestOpWithRegionPattern::getCanonicalizationPatterns( 1085 RewritePatternSet &results, MLIRContext *context) { 1086 results.add<TestRemoveOpWithInnerOps>(context); 1087 } 1088 1089 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 1090 return getOperand(); 1091 } 1092 1093 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 1094 return getValue(); 1095 } 1096 1097 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 1098 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 1099 for (Value input : this->getOperands()) { 1100 results.push_back(input); 1101 } 1102 return success(); 1103 } 1104 1105 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 1106 assert(operands.size() == 1); 1107 if (operands.front()) { 1108 (*this)->setAttr("attr", operands.front()); 1109 return getResult(); 1110 } 1111 return {}; 1112 } 1113 1114 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) { 1115 return getOperand(); 1116 } 1117 1118 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 1119 MLIRContext *, Optional<Location> location, ValueRange operands, 1120 DictionaryAttr attributes, RegionRange regions, 1121 SmallVectorImpl<Type> &inferredReturnTypes) { 1122 if (operands[0].getType() != operands[1].getType()) { 1123 return emitOptionalError(location, "operand type mismatch ", 1124 operands[0].getType(), " vs ", 1125 operands[1].getType()); 1126 } 1127 inferredReturnTypes.assign({operands[0].getType()}); 1128 return success(); 1129 } 1130 1131 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 1132 MLIRContext *context, Optional<Location> location, ValueShapeRange operands, 1133 DictionaryAttr attributes, RegionRange regions, 1134 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1135 // Create return type consisting of the last element of the first operand. 1136 auto operandType = operands.front().getType(); 1137 auto sval = operandType.dyn_cast<ShapedType>(); 1138 if (!sval) { 1139 return emitOptionalError(location, "only shaped type operands allowed"); 1140 } 1141 int64_t dim = 1142 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 1143 auto type = IntegerType::get(context, 17); 1144 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 1145 return success(); 1146 } 1147 1148 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 1149 OpBuilder &builder, ValueRange operands, 1150 llvm::SmallVectorImpl<Value> &shapes) { 1151 shapes = SmallVector<Value, 1>{ 1152 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; 1153 return success(); 1154 } 1155 1156 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( 1157 OpBuilder &builder, ValueRange operands, 1158 llvm::SmallVectorImpl<Value> &shapes) { 1159 Location loc = getLoc(); 1160 shapes.reserve(operands.size()); 1161 for (Value operand : llvm::reverse(operands)) { 1162 auto rank = operand.getType().cast<RankedTensorType>().getRank(); 1163 auto currShape = llvm::to_vector<4>( 1164 llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value { 1165 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 1166 })); 1167 shapes.push_back(builder.create<tensor::FromElementsOp>( 1168 getLoc(), RankedTensorType::get({rank}, builder.getIndexType()), 1169 currShape)); 1170 } 1171 return success(); 1172 } 1173 1174 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( 1175 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { 1176 Location loc = getLoc(); 1177 shapes.reserve(getNumOperands()); 1178 for (Value operand : llvm::reverse(getOperands())) { 1179 auto currShape = llvm::to_vector<4>(llvm::map_range( 1180 llvm::seq<int64_t>( 1181 0, operand.getType().cast<RankedTensorType>().getRank()), 1182 [&](int64_t dim) -> Value { 1183 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 1184 })); 1185 shapes.emplace_back(std::move(currShape)); 1186 } 1187 return success(); 1188 } 1189 1190 //===----------------------------------------------------------------------===// 1191 // Test SideEffect interfaces 1192 //===----------------------------------------------------------------------===// 1193 1194 namespace { 1195 /// A test resource for side effects. 1196 struct TestResource : public SideEffects::Resource::Base<TestResource> { 1197 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource) 1198 1199 StringRef getName() final { return "<Test>"; } 1200 }; 1201 } // namespace 1202 1203 static void testSideEffectOpGetEffect( 1204 Operation *op, 1205 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> 1206 &effects) { 1207 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); 1208 if (!effectsAttr) 1209 return; 1210 1211 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 1212 } 1213 1214 void SideEffectOp::getEffects( 1215 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1216 // Check for an effects attribute on the op instance. 1217 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 1218 if (!effectsAttr) 1219 return; 1220 1221 // If there is one, it is an array of dictionary attributes that hold 1222 // information on the effects of this operation. 1223 for (Attribute element : effectsAttr) { 1224 DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 1225 1226 // Get the specific memory effect. 1227 MemoryEffects::Effect *effect = 1228 StringSwitch<MemoryEffects::Effect *>( 1229 effectElement.get("effect").cast<StringAttr>().getValue()) 1230 .Case("allocate", MemoryEffects::Allocate::get()) 1231 .Case("free", MemoryEffects::Free::get()) 1232 .Case("read", MemoryEffects::Read::get()) 1233 .Case("write", MemoryEffects::Write::get()); 1234 1235 // Check for a non-default resource to use. 1236 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 1237 if (effectElement.get("test_resource")) 1238 resource = TestResource::get(); 1239 1240 // Check for a result to affect. 1241 if (effectElement.get("on_result")) 1242 effects.emplace_back(effect, getResult(), resource); 1243 else if (Attribute ref = effectElement.get("on_reference")) 1244 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource); 1245 else 1246 effects.emplace_back(effect, resource); 1247 } 1248 } 1249 1250 void SideEffectOp::getEffects( 1251 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 1252 testSideEffectOpGetEffect(getOperation(), effects); 1253 } 1254 1255 //===----------------------------------------------------------------------===// 1256 // StringAttrPrettyNameOp 1257 //===----------------------------------------------------------------------===// 1258 1259 // This op has fancy handling of its SSA result name. 1260 ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser, 1261 OperationState &result) { 1262 // Add the result types. 1263 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 1264 result.addTypes(parser.getBuilder().getIntegerType(32)); 1265 1266 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1267 return failure(); 1268 1269 // If the attribute dictionary contains no 'names' attribute, infer it from 1270 // the SSA name (if specified). 1271 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 1272 return attr.getName() == "names"; 1273 }); 1274 1275 // If there was no name specified, check to see if there was a useful name 1276 // specified in the asm file. 1277 if (hadNames || parser.getNumResults() == 0) 1278 return success(); 1279 1280 SmallVector<StringRef, 4> names; 1281 auto *context = result.getContext(); 1282 1283 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 1284 auto resultName = parser.getResultName(i); 1285 StringRef nameStr; 1286 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 1287 nameStr = resultName.first; 1288 1289 names.push_back(nameStr); 1290 } 1291 1292 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 1293 result.attributes.push_back({StringAttr::get(context, "names"), namesAttr}); 1294 return success(); 1295 } 1296 1297 void StringAttrPrettyNameOp::print(OpAsmPrinter &p) { 1298 // Note that we only need to print the "name" attribute if the asmprinter 1299 // result name disagrees with it. This can happen in strange cases, e.g. 1300 // when there are conflicts. 1301 bool namesDisagree = getNames().size() != getNumResults(); 1302 1303 SmallString<32> resultNameStr; 1304 for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) { 1305 resultNameStr.clear(); 1306 llvm::raw_svector_ostream tmpStream(resultNameStr); 1307 p.printOperand(getResult(i), tmpStream); 1308 1309 auto expectedName = getNames()[i].dyn_cast<StringAttr>(); 1310 if (!expectedName || 1311 tmpStream.str().drop_front() != expectedName.getValue()) { 1312 namesDisagree = true; 1313 } 1314 } 1315 1316 if (namesDisagree) 1317 p.printOptionalAttrDictWithKeyword((*this)->getAttrs()); 1318 else 1319 p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"}); 1320 } 1321 1322 // We set the SSA name in the asm syntax to the contents of the name 1323 // attribute. 1324 void StringAttrPrettyNameOp::getAsmResultNames( 1325 function_ref<void(Value, StringRef)> setNameFn) { 1326 1327 auto value = getNames(); 1328 for (size_t i = 0, e = value.size(); i != e; ++i) 1329 if (auto str = value[i].dyn_cast<StringAttr>()) 1330 if (!str.getValue().empty()) 1331 setNameFn(getResult(i), str.getValue()); 1332 } 1333 1334 void CustomResultsNameOp::getAsmResultNames( 1335 function_ref<void(Value, StringRef)> setNameFn) { 1336 ArrayAttr value = getNames(); 1337 for (size_t i = 0, e = value.size(); i != e; ++i) 1338 if (auto str = value[i].dyn_cast<StringAttr>()) 1339 if (!str.getValue().empty()) 1340 setNameFn(getResult(i), str.getValue()); 1341 } 1342 1343 //===----------------------------------------------------------------------===// 1344 // ResultTypeWithTraitOp 1345 //===----------------------------------------------------------------------===// 1346 1347 LogicalResult ResultTypeWithTraitOp::verify() { 1348 if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>()) 1349 return success(); 1350 return emitError("result type should have trait 'TestTypeTrait'"); 1351 } 1352 1353 //===----------------------------------------------------------------------===// 1354 // AttrWithTraitOp 1355 //===----------------------------------------------------------------------===// 1356 1357 LogicalResult AttrWithTraitOp::verify() { 1358 if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>()) 1359 return success(); 1360 return emitError("'attr' attribute should have trait 'TestAttrTrait'"); 1361 } 1362 1363 //===----------------------------------------------------------------------===// 1364 // RegionIfOp 1365 //===----------------------------------------------------------------------===// 1366 1367 void RegionIfOp::print(OpAsmPrinter &p) { 1368 p << " "; 1369 p.printOperands(getOperands()); 1370 p << ": " << getOperandTypes(); 1371 p.printArrowTypeList(getResultTypes()); 1372 p << " then "; 1373 p.printRegion(getThenRegion(), 1374 /*printEntryBlockArgs=*/true, 1375 /*printBlockTerminators=*/true); 1376 p << " else "; 1377 p.printRegion(getElseRegion(), 1378 /*printEntryBlockArgs=*/true, 1379 /*printBlockTerminators=*/true); 1380 p << " join "; 1381 p.printRegion(getJoinRegion(), 1382 /*printEntryBlockArgs=*/true, 1383 /*printBlockTerminators=*/true); 1384 } 1385 1386 ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) { 1387 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos; 1388 SmallVector<Type, 2> operandTypes; 1389 1390 result.regions.reserve(3); 1391 Region *thenRegion = result.addRegion(); 1392 Region *elseRegion = result.addRegion(); 1393 Region *joinRegion = result.addRegion(); 1394 1395 // Parse operand, type and arrow type lists. 1396 if (parser.parseOperandList(operandInfos) || 1397 parser.parseColonTypeList(operandTypes) || 1398 parser.parseArrowTypeList(result.types)) 1399 return failure(); 1400 1401 // Parse all attached regions. 1402 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 1403 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 1404 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 1405 return failure(); 1406 1407 return parser.resolveOperands(operandInfos, operandTypes, 1408 parser.getCurrentLocation(), result.operands); 1409 } 1410 1411 OperandRange RegionIfOp::getSuccessorEntryOperands(Optional<unsigned> index) { 1412 assert(index && *index < 2 && "invalid region index"); 1413 return getOperands(); 1414 } 1415 1416 void RegionIfOp::getSuccessorRegions( 1417 Optional<unsigned> index, ArrayRef<Attribute> operands, 1418 SmallVectorImpl<RegionSuccessor> ®ions) { 1419 // We always branch to the join region. 1420 if (index.has_value()) { 1421 if (index.value() < 2) 1422 regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); 1423 else 1424 regions.push_back(RegionSuccessor(getResults())); 1425 return; 1426 } 1427 1428 // The then and else regions are the entry regions of this op. 1429 regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs())); 1430 regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs())); 1431 } 1432 1433 void RegionIfOp::getRegionInvocationBounds( 1434 ArrayRef<Attribute> operands, 1435 SmallVectorImpl<InvocationBounds> &invocationBounds) { 1436 // Each region is invoked at most once. 1437 invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1}); 1438 } 1439 1440 //===----------------------------------------------------------------------===// 1441 // AnyCondOp 1442 //===----------------------------------------------------------------------===// 1443 1444 void AnyCondOp::getSuccessorRegions(Optional<unsigned> index, 1445 ArrayRef<Attribute> operands, 1446 SmallVectorImpl<RegionSuccessor> ®ions) { 1447 // The parent op branches into the only region, and the region branches back 1448 // to the parent op. 1449 if (!index) 1450 regions.emplace_back(&getRegion()); 1451 else 1452 regions.emplace_back(getResults()); 1453 } 1454 1455 void AnyCondOp::getRegionInvocationBounds( 1456 ArrayRef<Attribute> operands, 1457 SmallVectorImpl<InvocationBounds> &invocationBounds) { 1458 invocationBounds.emplace_back(1, 1); 1459 } 1460 1461 //===----------------------------------------------------------------------===// 1462 // SingleNoTerminatorCustomAsmOp 1463 //===----------------------------------------------------------------------===// 1464 1465 ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser, 1466 OperationState &state) { 1467 Region *body = state.addRegion(); 1468 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) 1469 return failure(); 1470 return success(); 1471 } 1472 1473 void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) { 1474 printer.printRegion( 1475 getRegion(), /*printEntryBlockArgs=*/false, 1476 // This op has a single block without terminators. But explicitly mark 1477 // as not printing block terminators for testing. 1478 /*printBlockTerminators=*/false); 1479 } 1480 1481 //===----------------------------------------------------------------------===// 1482 // TestVerifiersOp 1483 //===----------------------------------------------------------------------===// 1484 1485 LogicalResult TestVerifiersOp::verify() { 1486 if (!getRegion().hasOneBlock()) 1487 return emitOpError("`hasOneBlock` trait hasn't been verified"); 1488 1489 Operation *definingOp = getInput().getDefiningOp(); 1490 if (definingOp && failed(mlir::verify(definingOp))) 1491 return emitOpError("operand hasn't been verified"); 1492 1493 emitRemark("success run of verifier"); 1494 1495 return success(); 1496 } 1497 1498 LogicalResult TestVerifiersOp::verifyRegions() { 1499 if (!getRegion().hasOneBlock()) 1500 return emitOpError("`hasOneBlock` trait hasn't been verified"); 1501 1502 for (Block &block : getRegion()) 1503 for (Operation &op : block) 1504 if (failed(mlir::verify(&op))) 1505 return emitOpError("nested op hasn't been verified"); 1506 1507 emitRemark("success run of region verifier"); 1508 1509 return success(); 1510 } 1511 1512 //===----------------------------------------------------------------------===// 1513 // Test InferIntRangeInterface 1514 //===----------------------------------------------------------------------===// 1515 1516 void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 1517 SetIntRangeFn setResultRanges) { 1518 setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()}); 1519 } 1520 1521 ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser, 1522 OperationState &result) { 1523 if (parser.parseOptionalAttrDict(result.attributes)) 1524 return failure(); 1525 1526 // Parse the input argument 1527 OpAsmParser::Argument argInfo; 1528 argInfo.type = parser.getBuilder().getIndexType(); 1529 if (failed(parser.parseArgument(argInfo))) 1530 return failure(); 1531 1532 // Parse the body region, and reuse the operand info as the argument info. 1533 Region *body = result.addRegion(); 1534 return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false); 1535 } 1536 1537 void TestWithBoundsRegionOp::print(OpAsmPrinter &p) { 1538 p.printOptionalAttrDict((*this)->getAttrs()); 1539 p << ' '; 1540 p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{}, 1541 /*omitType=*/true); 1542 p << ' '; 1543 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 1544 } 1545 1546 void TestWithBoundsRegionOp::inferResultRanges( 1547 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) { 1548 Value arg = getRegion().getArgument(0); 1549 setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()}); 1550 } 1551 1552 void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 1553 SetIntRangeFn setResultRanges) { 1554 const ConstantIntRanges &range = argRanges[0]; 1555 APInt one(range.umin().getBitWidth(), 1); 1556 setResultRanges(getResult(), 1557 {range.umin().uadd_sat(one), range.umax().uadd_sat(one), 1558 range.smin().sadd_sat(one), range.smax().sadd_sat(one)}); 1559 } 1560 1561 void TestReflectBoundsOp::inferResultRanges( 1562 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) { 1563 const ConstantIntRanges &range = argRanges[0]; 1564 MLIRContext *ctx = getContext(); 1565 Builder b(ctx); 1566 setUminAttr(b.getIndexAttr(range.umin().getZExtValue())); 1567 setUmaxAttr(b.getIndexAttr(range.umax().getZExtValue())); 1568 setSminAttr(b.getIndexAttr(range.smin().getSExtValue())); 1569 setSmaxAttr(b.getIndexAttr(range.smax().getSExtValue())); 1570 setResultRanges(getResult(), range); 1571 } 1572 1573 #include "TestOpEnums.cpp.inc" 1574 #include "TestOpInterfaces.cpp.inc" 1575 #include "TestTypeInterfaces.cpp.inc" 1576 1577 #define GET_OP_CLASSES 1578 #include "TestOps.cpp.inc" 1579