1 //===- PassRegistry.cpp - Pass Registration Utilities ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Pass/PassRegistry.h" 10 #include "mlir/Pass/Pass.h" 11 #include "mlir/Pass/PassManager.h" 12 #include "llvm/ADT/DenseMap.h" 13 #include "llvm/Support/ManagedStatic.h" 14 #include "llvm/Support/MemoryBuffer.h" 15 #include "llvm/Support/SourceMgr.h" 16 17 using namespace mlir; 18 using namespace detail; 19 20 /// Static mapping of all of the registered passes. 21 static llvm::ManagedStatic<DenseMap<TypeID, PassInfo>> passRegistry; 22 23 /// Static mapping of all of the registered pass pipelines. 24 static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>> 25 passPipelineRegistry; 26 27 /// Utility to create a default registry function from a pass instance. 28 static PassRegistryFunction 29 buildDefaultRegistryFn(const PassAllocatorFunction &allocator) { 30 return [=](OpPassManager &pm, StringRef options, 31 function_ref<LogicalResult(const Twine &)> errorHandler) { 32 std::unique_ptr<Pass> pass = allocator(); 33 LogicalResult result = pass->initializeOptions(options); 34 if ((pm.getNesting() == OpPassManager::Nesting::Explicit) && 35 pass->getOpName() && *pass->getOpName() != pm.getOpName()) 36 return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() + 37 "' restricted to '" + *pass->getOpName() + 38 "' on a PassManager intended to run on '" + 39 pm.getOpName() + "', did you intend to nest?"); 40 pm.addPass(std::move(pass)); 41 return result; 42 }; 43 } 44 45 /// Utility to print the help string for a specific option. 
46 static void printOptionHelp(StringRef arg, StringRef desc, size_t indent, 47 size_t descIndent, bool isTopLevel) { 48 size_t numSpaces = descIndent - indent - 4; 49 llvm::outs().indent(indent) 50 << "--" << llvm::left_justify(arg, numSpaces) << "- " << desc << '\n'; 51 } 52 53 //===----------------------------------------------------------------------===// 54 // PassRegistry 55 //===----------------------------------------------------------------------===// 56 57 /// Print the help information for this pass. This includes the argument, 58 /// description, and any pass options. `descIndent` is the indent that the 59 /// descriptions should be aligned. 60 void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const { 61 printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent, 62 /*isTopLevel=*/true); 63 // If this entry has options, print the help for those as well. 64 optHandler([=](const PassOptions &options) { 65 options.printHelp(indent, descIndent); 66 }); 67 } 68 69 /// Return the maximum width required when printing the options of this 70 /// entry. 
71 size_t PassRegistryEntry::getOptionWidth() const { 72 size_t maxLen = 0; 73 optHandler([&](const PassOptions &options) mutable { 74 maxLen = options.getOptionWidth() + 2; 75 }); 76 return maxLen; 77 } 78 79 //===----------------------------------------------------------------------===// 80 // PassPipelineInfo 81 //===----------------------------------------------------------------------===// 82 83 void mlir::registerPassPipeline( 84 StringRef arg, StringRef description, const PassRegistryFunction &function, 85 std::function<void(function_ref<void(const PassOptions &)>)> optHandler) { 86 PassPipelineInfo pipelineInfo(arg, description, function, optHandler); 87 bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second; 88 assert(inserted && "Pass pipeline registered multiple times"); 89 (void)inserted; 90 } 91 92 //===----------------------------------------------------------------------===// 93 // PassInfo 94 //===----------------------------------------------------------------------===// 95 96 PassInfo::PassInfo(StringRef arg, StringRef description, TypeID passID, 97 const PassAllocatorFunction &allocator) 98 : PassRegistryEntry( 99 arg, description, buildDefaultRegistryFn(allocator), 100 // Use a temporary pass to provide an options instance. 101 [=](function_ref<void(const PassOptions &)> optHandler) { 102 optHandler(allocator()->passOptions); 103 }) {} 104 105 void mlir::registerPass(StringRef arg, StringRef description, 106 const PassAllocatorFunction &function) { 107 // TODO: We should use the 'arg' as the lookup key instead of the pass id. 108 TypeID passID = function()->getTypeID(); 109 PassInfo passInfo(arg, description, passID, function); 110 passRegistry->try_emplace(passID, passInfo); 111 } 112 113 /// Returns the pass info for the specified pass class or null if unknown. 
/// Looks `passID` up in the global pass registry.
const PassInfo *mlir::Pass::lookupPassInfo(TypeID passID) {
  auto it = passRegistry->find(passID);
  if (it == passRegistry->end())
    return nullptr;
  return &it->getSecond();
}

//===----------------------------------------------------------------------===//
// PassOptions
//===----------------------------------------------------------------------===//

/// Out of line virtual function to provide home for the class.
void detail::PassOptions::OptionBase::anchor() {}

/// Copy the option values from 'other'.
/// Options are paired positionally. This assumes `other` declares the same
/// options in the same order; only the counts are asserted to match.
void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) {
  assert(options.size() == other.options.size());
  if (options.empty())
    return;
  for (auto optionsIt : llvm::zip(options, other.options))
    std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
}

/// Parse a space-separated list of `key` or `key=value` tokens from `options`.
/// Each value is handed to the option registered under `key`. Consecutive
/// spaces are skipped. A value cannot itself contain a space (see TODO).
///
/// Returns failure in two cases:
///  - an unknown key (after printing a diagnostic to `llvm::errs()`);
///  - `llvm::cl::ProvidePositionalOption` reports an error for the value.
LogicalResult detail::PassOptions::parseFromString(StringRef options) {
  // TODO: Handle escaping strings.
  // NOTE: `options` is modified in place to always refer to the unprocessed
  // part of the string.
  while (!options.empty()) {
    size_t spacePos = options.find(' ');
    StringRef arg = options;
    if (spacePos != StringRef::npos) {
      arg = options.substr(0, spacePos);
      options = options.substr(spacePos + 1);
    } else {
      options = StringRef();
    }
    // Empty tokens come from runs of spaces; ignore them.
    if (arg.empty())
      continue;

    // At this point, arg refers to everything that is non-space in options
    // up to the next space, and options refers to the rest of the string after
    // that point.

    // Split the individual option on '=' to form key and value. If there is no
    // '=', then value is `StringRef()`.
    size_t equalPos = arg.find('=');
    StringRef key = arg;
    StringRef value;
    if (equalPos != StringRef::npos) {
      key = arg.substr(0, equalPos);
      value = arg.substr(equalPos + 1);
    }
    auto it = OptionsMap.find(key);
    if (it == OptionsMap.end()) {
      llvm::errs() << "<Pass-Options-Parser>: no such option " << key << "\n";
      return failure();
    }
    if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
      return failure();
  }

  return success();
}

/// Print the options held by this struct in a form that can be parsed via
/// 'parseFromString'. The output wraps the options, sorted by argument name
/// and separated by single spaces, in '{' and '}'. Nothing at all is printed
/// when there are no options.
void detail::PassOptions::print(raw_ostream &os) {
  // If there are no options, there is nothing left to do.
  if (OptionsMap.empty())
    return;

  // Sort the options to make the ordering deterministic.
  SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
  auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
    return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
  };
  llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);

  // Interleave the options with ' '.
  os << '{';
  llvm::interleave(
      orderedOps, os, [&](OptionBase *option) { option->print(os); }, " ");
  os << '}';
}

/// Print the help string for the options held by this struct. `descIndent` is
/// the indent within the stream that the descriptions should be aligned.
/// NOTE(review): `descIndent - indent` is unsigned; this assumes callers pass
/// `descIndent >= indent`.
void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const {
  // Sort the options to make the ordering deterministic.
  SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
  auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
    return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
  };
  llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
  for (OptionBase *option : orderedOps) {
    // TODO: printOptionInfo assumes a specific indent and will
    // print options with values with incorrect indentation. We should add
    // support to llvm::cl::Option for passing in a base indent to use when
    // printing.
    llvm::outs().indent(indent);
    option->getOption()->printOptionInfo(descIndent - indent);
  }
}

/// Return the maximum width required when printing the help string.
/// This is the largest `getOptionWidth()` of the underlying cl options, or 0
/// when there are none.
size_t detail::PassOptions::getOptionWidth() const {
  size_t max = 0;
  for (auto *option : options)
    max = std::max(max, option->getOption()->getOptionWidth());
  return max;
}

//===----------------------------------------------------------------------===//
// TextualPassPipeline Parser
//===----------------------------------------------------------------------===//

namespace {
/// This class represents a textual description of a pass pipeline.
/// Usage is two-phase: `initialize` parses and resolves the text, then
/// `addToPipeline` populates a pass manager from the resolved elements.
class TextualPipeline {
public:
  /// Try to initialize this pipeline with the given pipeline text.
  /// `errorStream` is the output stream to emit errors to.
  LogicalResult initialize(StringRef text, raw_ostream &errorStream);

  /// Add the internal pipeline elements to the provided pass manager.
  LogicalResult
  addToPipeline(OpPassManager &pm,
                function_ref<LogicalResult(const Twine &)> errorHandler) const;

private:
  /// A functor used to emit errors found during pipeline handling. The first
  /// parameter corresponds to the raw location within the pipeline string. This
  /// should always return failure.
  using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>;

  /// A struct to capture parsed pass pipeline names.
  ///
  /// A pipeline is defined as a series of names, each of which may in itself
  /// recursively contain a nested pipeline. A name is either the name of a pass
  /// (e.g. "cse") or the name of an operation type (e.g. "func"). If the name
  /// is the name of a pass, the InnerPipeline is empty, since passes cannot
  /// contain inner pipelines.
  struct PipelineElement {
    PipelineElement(StringRef name) : name(name), registryEntry(nullptr) {}

    /// The element name, referencing the pipeline text (not owned).
    StringRef name;
    /// The text between the braces of `name{...}`, without the braces; empty
    /// if no options were given.
    StringRef options;
    /// The registered pass or pass pipeline this name resolved to. It stays
    /// null for operation-nesting elements, which use `innerPipeline` instead.
    const PassRegistryEntry *registryEntry;
    /// The elements of a nested `name(...)` pipeline; empty for passes.
    std::vector<PipelineElement> innerPipeline;
  };

  /// Parse the given pipeline text into the internal pipeline vector. This
  /// function only parses the structure of the pipeline, and does not resolve
  /// its elements.
  LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);

  /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
  /// the corresponding registry entry.
  LogicalResult
  resolvePipelineElements(MutableArrayRef<PipelineElement> elements,
                          ErrorHandlerT errorHandler);

  /// Resolve a single element of the pipeline.
  LogicalResult resolvePipelineElement(PipelineElement &element,
                                       ErrorHandlerT errorHandler);

  /// Add the given pipeline elements to the provided pass manager.
  LogicalResult
  addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm,
                function_ref<LogicalResult(const Twine &)> errorHandler) const;

  /// The top-level elements of the parsed pipeline.
  std::vector<PipelineElement> pipeline;
};

} // end anonymous namespace

/// Try to initialize this pipeline with the given pipeline text. An option is
/// given to enable accurate error reporting.
292 LogicalResult TextualPipeline::initialize(StringRef text, 293 raw_ostream &errorStream) { 294 if (text.empty()) 295 return success(); 296 297 // Build a source manager to use for error reporting. 298 llvm::SourceMgr pipelineMgr; 299 pipelineMgr.AddNewSourceBuffer( 300 llvm::MemoryBuffer::getMemBuffer(text, "MLIR Textual PassPipeline Parser", 301 /*RequiresNullTerminator=*/false), 302 llvm::SMLoc()); 303 auto errorHandler = [&](const char *rawLoc, Twine msg) { 304 pipelineMgr.PrintMessage(errorStream, llvm::SMLoc::getFromPointer(rawLoc), 305 llvm::SourceMgr::DK_Error, msg); 306 return failure(); 307 }; 308 309 // Parse the provided pipeline string. 310 if (failed(parsePipelineText(text, errorHandler))) 311 return failure(); 312 return resolvePipelineElements(pipeline, errorHandler); 313 } 314 315 /// Add the internal pipeline elements to the provided pass manager. 316 LogicalResult TextualPipeline::addToPipeline( 317 OpPassManager &pm, 318 function_ref<LogicalResult(const Twine &)> errorHandler) const { 319 return addToPipeline(pipeline, pm, errorHandler); 320 } 321 322 /// Parse the given pipeline text into the internal pipeline vector. This 323 /// function only parses the structure of the pipeline, and does not resolve 324 /// its elements. 325 LogicalResult TextualPipeline::parsePipelineText(StringRef text, 326 ErrorHandlerT errorHandler) { 327 SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline}; 328 for (;;) { 329 std::vector<PipelineElement> &pipeline = *pipelineStack.back(); 330 size_t pos = text.find_first_of(",(){"); 331 pipeline.emplace_back(/*name=*/text.substr(0, pos).trim()); 332 333 // If we have a single terminating name, we're done. 334 if (pos == StringRef::npos) 335 break; 336 337 text = text.substr(pos); 338 char sep = text[0]; 339 340 // Handle pulling ... from 'pass{...}' out as PipelineElement.options. 
341 if (sep == '{') { 342 text = text.substr(1); 343 344 // Skip over everything until the closing '}' and store as options. 345 size_t close = StringRef::npos; 346 for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) { 347 if (text[i] == '{') { 348 ++braceCount; 349 continue; 350 } 351 if (text[i] == '}' && --braceCount == 0) { 352 close = i; 353 break; 354 } 355 } 356 357 // Check to see if a closing options brace was found. 358 if (close == StringRef::npos) { 359 return errorHandler( 360 /*rawLoc=*/text.data() - 1, 361 "missing closing '}' while processing pass options"); 362 } 363 pipeline.back().options = text.substr(0, close); 364 text = text.substr(close + 1); 365 366 // Skip checking for '(' because nested pipelines cannot have options. 367 } else if (sep == '(') { 368 text = text.substr(1); 369 370 // Push the inner pipeline onto the stack to continue processing. 371 pipelineStack.push_back(&pipeline.back().innerPipeline); 372 continue; 373 } 374 375 // When handling the close parenthesis, we greedily consume them to avoid 376 // empty strings in the pipeline. 377 while (text.consume_front(")")) { 378 // If we try to pop the outer pipeline we have unbalanced parentheses. 379 if (pipelineStack.size() == 1) 380 return errorHandler(/*rawLoc=*/text.data() - 1, 381 "encountered extra closing ')' creating unbalanced " 382 "parentheses while parsing pipeline"); 383 384 pipelineStack.pop_back(); 385 } 386 387 // Check if we've finished parsing. 388 if (text.empty()) 389 break; 390 391 // Otherwise, the end of an inner pipeline always has to be followed by 392 // a comma, and then we can continue. 393 if (!text.consume_front(",")) 394 return errorHandler(text.data(), "expected ',' after parsing pipeline"); 395 } 396 397 // Check for unbalanced parentheses. 
398 if (pipelineStack.size() > 1) 399 return errorHandler( 400 text.data(), 401 "encountered unbalanced parentheses while parsing pipeline"); 402 403 assert(pipelineStack.back() == &pipeline && 404 "wrong pipeline at the bottom of the stack"); 405 return success(); 406 } 407 408 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to 409 /// the corresponding registry entry. 410 LogicalResult TextualPipeline::resolvePipelineElements( 411 MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) { 412 for (auto &elt : elements) 413 if (failed(resolvePipelineElement(elt, errorHandler))) 414 return failure(); 415 return success(); 416 } 417 418 /// Resolve a single element of the pipeline. 419 LogicalResult 420 TextualPipeline::resolvePipelineElement(PipelineElement &element, 421 ErrorHandlerT errorHandler) { 422 // If the inner pipeline of this element is not empty, this is an operation 423 // pipeline. 424 if (!element.innerPipeline.empty()) 425 return resolvePipelineElements(element.innerPipeline, errorHandler); 426 // Otherwise, this must be a pass or pass pipeline. 427 // Check to see if a pipeline was registered with this name. 428 auto pipelineRegistryIt = passPipelineRegistry->find(element.name); 429 if (pipelineRegistryIt != passPipelineRegistry->end()) { 430 element.registryEntry = &pipelineRegistryIt->second; 431 return success(); 432 } 433 434 // If not, then this must be a specific pass name. 435 for (auto &passIt : *passRegistry) { 436 if (passIt.second.getPassArgument() == element.name) { 437 element.registryEntry = &passIt.second; 438 return success(); 439 } 440 } 441 442 // Emit an error for the unknown pass. 443 auto *rawLoc = element.name.data(); 444 return errorHandler(rawLoc, "'" + element.name + 445 "' does not refer to a " 446 "registered pass or pass pipeline"); 447 } 448 449 /// Add the given pipeline elements to the provided pass manager. 
/// Elements with a registry entry are added directly, together with their
/// options. Other elements name an operation: their inner pipeline is added
/// to the pass manager returned by `pm.nest(elt.name)`. On the first failure,
/// the result of `errorHandler` is returned; elements already added stay in
/// `pm`.
LogicalResult TextualPipeline::addToPipeline(
    ArrayRef<PipelineElement> elements, OpPassManager &pm,
    function_ref<LogicalResult(const Twine &)> errorHandler) const {
  for (auto &elt : elements) {
    if (elt.registryEntry) {
      if (failed(elt.registryEntry->addToPipeline(pm, elt.options,
                                                  errorHandler))) {
        return errorHandler("failed to add `" + elt.name + "` with options `" +
                            elt.options + "`");
      }
    } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name),
                                    errorHandler))) {
      return errorHandler("failed to add `" + elt.name + "` with options `" +
                          elt.options + "` to inner pipeline");
    }
  }
  return success();
}

/// This function parses the textual representation of a pass pipeline, and adds
/// the result to 'pm' on success. This function returns failure if the given
/// pipeline was invalid. 'errorStream' receives any errors found while parsing
/// the pipeline or while adding its elements to 'pm'. On failure, 'pm' may
/// already contain some of the pipeline's elements.
LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm,
                                      raw_ostream &errorStream) {
  TextualPipeline pipelineParser;
  if (failed(pipelineParser.initialize(pipeline, errorStream)))
    return failure();
  auto errorHandler = [&](Twine msg) {
    errorStream << msg << "\n";
    return failure();
  };
  if (failed(pipelineParser.addToPipeline(pm, errorHandler)))
    return failure();
  return success();
}

//===----------------------------------------------------------------------===//
// PassNameParser
//===----------------------------------------------------------------------===//

namespace {
/// This struct represents the possible data entries in a parsed pass pipeline
/// list.
struct PassArgData {
  PassArgData() : registryEntry(nullptr) {}
  PassArgData(const PassRegistryEntry *registryEntry)
      : registryEntry(registryEntry) {}

  /// This field is used when the parsed option corresponds to a registered pass
  /// or pass pipeline.
  const PassRegistryEntry *registryEntry;

  /// This field is set when instance specific pass options have been provided
  /// on the command line.
  StringRef options;

  /// This field is used when the parsed option corresponds to an explicit
  /// pipeline.
  TextualPipeline pipeline;
};
} // end anonymous namespace

namespace llvm {
namespace cl {
/// Define a valid OptionValue for the command line pass argument.
/// `hasValue` always reports true, and the stored value is a plain copy.
template <>
struct OptionValue<PassArgData> final
    : OptionValueBase<PassArgData, /*isClass=*/true> {
  OptionValue(const PassArgData &value) { this->setValue(value); }
  OptionValue() = default;
  void anchor() override {}

  bool hasValue() const { return true; }
  const PassArgData &getValue() const { return value; }
  void setValue(const PassArgData &value) { this->value = value; }

  PassArgData value;
};
} // end namespace cl
} // end namespace llvm

namespace {

/// The name for the command line option used for parsing the textual pass
/// pipeline.
static constexpr StringLiteral passPipelineArg = "pass-pipeline";

/// Adds command line option for each registered pass or pass pipeline, as well
/// as textual pass pipelines.
struct PassNameParser : public llvm::cl::parser<PassArgData> {
  PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {}

  void initialize();
  void printOptionInfo(const llvm::cl::Option &opt,
                       size_t globalWidth) const override;
  size_t getOptionWidth(const llvm::cl::Option &opt) const override;
  bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
             PassArgData &value);

  /// If true, this parser only parses entries that correspond to a concrete
  /// pass registry entry, and does not add a `pass-pipeline` argument, does not
  /// include the options for pass entries, and does not include pass pipelines
  /// entries.
  bool passNamesOnly = false;
};
} // namespace

/// Register one literal option per registered pass (plus, unless
/// `passNamesOnly`, the `pass-pipeline` option and every registered pipeline).
/// Each literal's value points at the corresponding registry entry.
void PassNameParser::initialize() {
  llvm::cl::parser<PassArgData>::initialize();

  /// Add an entry for the textual pass pipeline option.
  if (!passNamesOnly) {
    addLiteralOption(passPipelineArg, PassArgData(),
                     "A textual description of a pass pipeline to run");
  }

  /// Add the pass entries.
  for (const auto &kv : *passRegistry) {
    addLiteralOption(kv.second.getPassArgument(), &kv.second,
                     kv.second.getPassDescription());
  }
  /// Add the pass pipeline entries.
  if (!passNamesOnly) {
    for (const auto &kv : *passPipelineRegistry) {
      addLiteralOption(kv.second.getPassArgument(), &kv.second,
                       kv.second.getPassDescription());
    }
  }
}

/// Print the `--help` text. The third argument of `printHelpStr` tells it how
/// many columns the current line already occupies, presumably the prefix
/// printed here plus cl's 3-character " - " help separator.
/// NOTE(review): confirm that the `+ 18` / `+ 7` constants match the prefix
/// string literals below.
void PassNameParser::printOptionInfo(const llvm::cl::Option &opt,
                                     size_t globalWidth) const {
  // If this parser is just parsing pass names, print a simplified option
  // string.
  if (passNamesOnly) {
    llvm::outs() << " --" << opt.ArgStr << "=<pass-arg>";
    opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18);
    return;
  }

  // Print the information for the top-level option.
  if (opt.hasArgStr()) {
    llvm::outs() << " --" << opt.ArgStr;
    opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
  } else {
    llvm::outs() << " " << opt.HelpStr << '\n';
  }

  // Print the top-level pipeline argument.
  printOptionHelp(passPipelineArg,
                  "A textual description of a pass pipeline to run",
                  /*indent=*/4, globalWidth, /*isTopLevel=*/!opt.hasArgStr());

  // Functor used to print the ordered entries of a registration map.
  // `map` is either the pass registry or the pipeline registry; entries are
  // sorted by argument name so the output is deterministic.
  auto printOrderedEntries = [&](StringRef header, auto &map) {
    llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries;
    for (auto &kv : map)
      orderedEntries.push_back(&kv.second);
    llvm::array_pod_sort(
        orderedEntries.begin(), orderedEntries.end(),
        [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
          return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
        });

    llvm::outs().indent(4) << header << ":\n";
    for (PassRegistryEntry *entry : orderedEntries)
      entry->printHelpStr(/*indent=*/6, globalWidth);
  };

  // Print the available passes.
  printOrderedEntries("Passes", *passRegistry);

  // Print the available pass pipelines.
  if (!passPipelineRegistry->empty())
    printOrderedEntries("Pass Pipelines", *passPipelineRegistry);
}

/// Widen the base parser's width so that the options of any registered pass
/// or pipeline also fit.
size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const {
  size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(opt) + 2;

  // Check for any wider pass or pipeline options.
  for (auto &entry : *passRegistry)
    maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
  for (auto &entry : *passPipelineRegistry)
    maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
  return maxWidth;
}

/// Follows the cl::parser convention: returns true on error.
/// - For `pass-pipeline`, `arg` is parsed as a textual pipeline, with errors
///   printed to `llvm::errs()`.
/// - Otherwise, the base parser resolves `argName` to a registry entry and
///   `arg` is stored verbatim as that entry's options.
bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,
                           StringRef arg, PassArgData &value) {
  // Handle the pipeline option explicitly.
  if (argName == passPipelineArg)
    return failed(value.pipeline.initialize(arg, llvm::errs()));

  // Otherwise, default to the base for handling.
  if (llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value))
    return true;
  value.options = arg;
  return false;
}

//===----------------------------------------------------------------------===//
// PassPipelineCLParser
//===----------------------------------------------------------------------===//

namespace mlir {
namespace detail {
/// Shared implementation behind PassPipelineCLParser and PassNameCLParser:
/// a cl::list of PassArgData backed by PassNameParser.
struct PassPipelineCLParserImpl {
  PassPipelineCLParserImpl(StringRef arg, StringRef description,
                           bool passNamesOnly)
      : passList(arg, llvm::cl::desc(description)) {
    passList.getParser().passNamesOnly = passNamesOnly;
    passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
  }

  /// Returns true if the given pass registry entry was registered at the
  /// top-level of the parser, i.e. not within an explicit textual pipeline.
  /// NOTE(review): textual-pipeline entries carry a null `registryEntry`, so
  /// `contains(nullptr)` returns true whenever `pass-pipeline` was given.
  bool contains(const PassRegistryEntry *entry) const {
    return llvm::any_of(passList, [&](const PassArgData &data) {
      return data.registryEntry == entry;
    });
  }

  /// The set of passes and pass pipelines to run.
  llvm::cl::list<PassArgData, bool, PassNameParser> passList;
};
} // end namespace detail
} // end namespace mlir

/// Construct a pass pipeline parser with the given command line description.
PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description)
    : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
          arg, description, /*passNamesOnly=*/false)) {}
PassPipelineCLParser::~PassPipelineCLParser() {}

/// Returns true if this parser contains any valid options to add.
687 bool PassPipelineCLParser::hasAnyOccurrences() const { 688 return impl->passList.getNumOccurrences() != 0; 689 } 690 691 /// Returns true if the given pass registry entry was registered at the 692 /// top-level of the parser, i.e. not within an explicit textual pipeline. 693 bool PassPipelineCLParser::contains(const PassRegistryEntry *entry) const { 694 return impl->contains(entry); 695 } 696 697 /// Adds the passes defined by this parser entry to the given pass manager. 698 LogicalResult PassPipelineCLParser::addToPipeline( 699 OpPassManager &pm, 700 function_ref<LogicalResult(const Twine &)> errorHandler) const { 701 for (auto &passIt : impl->passList) { 702 if (passIt.registryEntry) { 703 if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options, 704 errorHandler))) 705 return failure(); 706 } else { 707 OpPassManager::Nesting nesting = pm.getNesting(); 708 pm.setNesting(OpPassManager::Nesting::Explicit); 709 LogicalResult status = passIt.pipeline.addToPipeline(pm, errorHandler); 710 pm.setNesting(nesting); 711 if (failed(status)) 712 return failure(); 713 } 714 } 715 return success(); 716 } 717 718 //===----------------------------------------------------------------------===// 719 // PassNameCLParser 720 721 /// Construct a pass pipeline parser with the given command line description. 722 PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description) 723 : impl(std::make_unique<detail::PassPipelineCLParserImpl>( 724 arg, description, /*passNamesOnly=*/true)) { 725 impl->passList.setMiscFlag(llvm::cl::CommaSeparated); 726 } 727 PassNameCLParser::~PassNameCLParser() {} 728 729 /// Returns true if this parser contains any valid options to add. 730 bool PassNameCLParser::hasAnyOccurrences() const { 731 return impl->passList.getNumOccurrences() != 0; 732 } 733 734 /// Returns true if the given pass registry entry was registered at the 735 /// top-level of the parser, i.e. not within an explicit textual pipeline. 
736 bool PassNameCLParser::contains(const PassRegistryEntry *entry) const { 737 return impl->contains(entry); 738 } 739