1 //===- PassRegistry.cpp - Pass Registration Utilities ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/Pass/PassRegistry.h" 12 #include "mlir/Pass/Pass.h" 13 #include "mlir/Pass/PassManager.h" 14 #include "llvm/ADT/DenseMap.h" 15 #include "llvm/Support/Format.h" 16 #include "llvm/Support/ManagedStatic.h" 17 #include "llvm/Support/MemoryBuffer.h" 18 #include "llvm/Support/SourceMgr.h" 19 20 using namespace mlir; 21 using namespace detail; 22 23 /// Static mapping of all of the registered passes. 24 static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry; 25 26 /// A mapping of the above pass registry entries to the corresponding TypeID 27 /// of the pass that they generate. 28 static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs; 29 30 /// Static mapping of all of the registered pass pipelines. 31 static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>> 32 passPipelineRegistry; 33 34 /// Utility to create a default registry function from a pass instance. 35 static PassRegistryFunction 36 buildDefaultRegistryFn(const PassAllocatorFunction &allocator) { 37 return [=](OpPassManager &pm, StringRef options, 38 function_ref<LogicalResult(const Twine &)> errorHandler) { 39 std::unique_ptr<Pass> pass = allocator(); 40 LogicalResult result = pass->initializeOptions(options); 41 if ((pm.getNesting() == OpPassManager::Nesting::Explicit) && 42 pass->getOpName() && *pass->getOpName() != pm.getOpName()) 43 return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() + 44 "' restricted to '" + *pass->getOpName() + 45 "' on a PassManager intended to run on '" + 46 pm.getOpName() + "', did you intend to nest?"); 47 pm.addPass(std::move(pass)); 48 return result; 49 }; 50 } 51 52 /// Utility to print the help string for a specific option. 53 static void printOptionHelp(StringRef arg, StringRef desc, size_t indent, 54 size_t descIndent, bool isTopLevel) { 55 size_t numSpaces = descIndent - indent - 4; 56 llvm::outs().indent(indent) 57 << "--" << llvm::left_justify(arg, numSpaces) << "- " << desc << '\n'; 58 } 59 60 //===----------------------------------------------------------------------===// 61 // PassRegistry 62 //===----------------------------------------------------------------------===// 63 64 /// Print the help information for this pass. This includes the argument, 65 /// description, and any pass options. `descIndent` is the indent that the 66 /// descriptions should be aligned. 67 void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const { 68 printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent, 69 /*isTopLevel=*/true); 70 // If this entry has options, print the help for those as well. 71 optHandler([=](const PassOptions &options) { 72 options.printHelp(indent, descIndent); 73 }); 74 } 75 76 /// Return the maximum width required when printing the options of this 77 /// entry. 78 size_t PassRegistryEntry::getOptionWidth() const { 79 size_t maxLen = 0; 80 optHandler([&](const PassOptions &options) mutable { 81 maxLen = options.getOptionWidth() + 2; 82 }); 83 return maxLen; 84 } 85 86 //===----------------------------------------------------------------------===// 87 // PassPipelineInfo 88 //===----------------------------------------------------------------------===// 89 90 void mlir::registerPassPipeline( 91 StringRef arg, StringRef description, const PassRegistryFunction &function, 92 std::function<void(function_ref<void(const PassOptions &)>)> optHandler) { 93 PassPipelineInfo pipelineInfo(arg, description, function, 94 std::move(optHandler)); 95 bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second; 96 assert(inserted && "Pass pipeline registered multiple times"); 97 (void)inserted; 98 } 99 100 //===----------------------------------------------------------------------===// 101 // PassInfo 102 //===----------------------------------------------------------------------===// 103 104 PassInfo::PassInfo(StringRef arg, StringRef description, 105 const PassAllocatorFunction &allocator) 106 : PassRegistryEntry( 107 arg, description, buildDefaultRegistryFn(allocator), 108 // Use a temporary pass to provide an options instance. 109 [=](function_ref<void(const PassOptions &)> optHandler) { 110 optHandler(allocator()->passOptions); 111 }) {} 112 113 void mlir::registerPass(const PassAllocatorFunction &function) { 114 std::unique_ptr<Pass> pass = function(); 115 StringRef arg = pass->getArgument(); 116 if (arg.empty()) 117 llvm::report_fatal_error(llvm::Twine("Trying to register '") + 118 pass->getName() + 119 "' pass that does not override `getArgument()`"); 120 StringRef description = pass->getDescription(); 121 PassInfo passInfo(arg, description, function); 122 passRegistry->try_emplace(arg, passInfo); 123 124 // Verify that the registered pass has the same ID as any registered to this 125 // arg before it. 126 TypeID entryTypeID = pass->getTypeID(); 127 auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first; 128 if (it->second != entryTypeID) 129 llvm::report_fatal_error( 130 "pass allocator creates a different pass than previously " 131 "registered for pass " + 132 arg); 133 } 134 135 /// Returns the pass info for the specified pass argument or null if unknown. 136 const PassInfo *mlir::Pass::lookupPassInfo(StringRef passArg) { 137 auto it = passRegistry->find(passArg); 138 return it == passRegistry->end() ? nullptr : &it->second; 139 } 140 141 //===----------------------------------------------------------------------===// 142 // PassOptions 143 //===----------------------------------------------------------------------===// 144 145 /// Out of line virtual function to provide home for the class. 146 void detail::PassOptions::OptionBase::anchor() {} 147 148 /// Copy the option values from 'other'. 149 void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) { 150 assert(options.size() == other.options.size()); 151 if (options.empty()) 152 return; 153 for (auto optionsIt : llvm::zip(options, other.options)) 154 std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt)); 155 } 156 157 LogicalResult detail::PassOptions::parseFromString(StringRef options) { 158 // TODO: Handle escaping strings. 159 // NOTE: `options` is modified in place to always refer to the unprocessed 160 // part of the string. 161 while (!options.empty()) { 162 size_t spacePos = options.find(' '); 163 StringRef arg = options; 164 if (spacePos != StringRef::npos) { 165 arg = options.substr(0, spacePos); 166 options = options.substr(spacePos + 1); 167 } else { 168 options = StringRef(); 169 } 170 if (arg.empty()) 171 continue; 172 173 // At this point, arg refers to everything that is non-space in options 174 // upto the next space, and options refers to the rest of the string after 175 // that point. 176 177 // Split the individual option on '=' to form key and value. If there is no 178 // '=', then value is `StringRef()`. 179 size_t equalPos = arg.find('='); 180 StringRef key = arg; 181 StringRef value; 182 if (equalPos != StringRef::npos) { 183 key = arg.substr(0, equalPos); 184 value = arg.substr(equalPos + 1); 185 } 186 auto it = OptionsMap.find(key); 187 if (it == OptionsMap.end()) { 188 llvm::errs() << "<Pass-Options-Parser>: no such option " << key << "\n"; 189 return failure(); 190 } 191 if (llvm::cl::ProvidePositionalOption(it->second, value, 0)) 192 return failure(); 193 } 194 195 return success(); 196 } 197 198 /// Print the options held by this struct in a form that can be parsed via 199 /// 'parseFromString'. 200 void detail::PassOptions::print(raw_ostream &os) { 201 // If there are no options, there is nothing left to do. 202 if (OptionsMap.empty()) 203 return; 204 205 // Sort the options to make the ordering deterministic. 206 SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end()); 207 auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) { 208 return (*lhs)->getArgStr().compare((*rhs)->getArgStr()); 209 }; 210 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs); 211 212 // Interleave the options with ' '. 213 os << '{'; 214 llvm::interleave( 215 orderedOps, os, [&](OptionBase *option) { option->print(os); }, " "); 216 os << '}'; 217 } 218 219 /// Print the help string for the options held by this struct. `descIndent` is 220 /// the indent within the stream that the descriptions should be aligned. 221 void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const { 222 // Sort the options to make the ordering deterministic. 223 SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end()); 224 auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) { 225 return (*lhs)->getArgStr().compare((*rhs)->getArgStr()); 226 }; 227 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs); 228 for (OptionBase *option : orderedOps) { 229 // TODO: printOptionInfo assumes a specific indent and will 230 // print options with values with incorrect indentation. We should add 231 // support to llvm::cl::Option for passing in a base indent to use when 232 // printing. 233 llvm::outs().indent(indent); 234 option->getOption()->printOptionInfo(descIndent - indent); 235 } 236 } 237 238 /// Return the maximum width required when printing the help string. 239 size_t detail::PassOptions::getOptionWidth() const { 240 size_t max = 0; 241 for (auto *option : options) 242 max = std::max(max, option->getOption()->getOptionWidth()); 243 return max; 244 } 245 246 //===----------------------------------------------------------------------===// 247 // TextualPassPipeline Parser 248 //===----------------------------------------------------------------------===// 249 250 namespace { 251 /// This class represents a textual description of a pass pipeline. 252 class TextualPipeline { 253 public: 254 /// Try to initialize this pipeline with the given pipeline text. 255 /// `errorStream` is the output stream to emit errors to. 256 LogicalResult initialize(StringRef text, raw_ostream &errorStream); 257 258 /// Add the internal pipeline elements to the provided pass manager. 259 LogicalResult 260 addToPipeline(OpPassManager &pm, 261 function_ref<LogicalResult(const Twine &)> errorHandler) const; 262 263 private: 264 /// A functor used to emit errors found during pipeline handling. The first 265 /// parameter corresponds to the raw location within the pipeline string. This 266 /// should always return failure. 267 using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>; 268 269 /// A struct to capture parsed pass pipeline names. 270 /// 271 /// A pipeline is defined as a series of names, each of which may in itself 272 /// recursively contain a nested pipeline. A name is either the name of a pass 273 /// (e.g. "cse") or the name of an operation type (e.g. "builtin.func"). If 274 /// the name is the name of a pass, the InnerPipeline is empty, since passes 275 /// cannot contain inner pipelines. 276 struct PipelineElement { 277 PipelineElement(StringRef name) : name(name), registryEntry(nullptr) {} 278 279 StringRef name; 280 StringRef options; 281 const PassRegistryEntry *registryEntry; 282 std::vector<PipelineElement> innerPipeline; 283 }; 284 285 /// Parse the given pipeline text into the internal pipeline vector. This 286 /// function only parses the structure of the pipeline, and does not resolve 287 /// its elements. 288 LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler); 289 290 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to 291 /// the corresponding registry entry. 292 LogicalResult 293 resolvePipelineElements(MutableArrayRef<PipelineElement> elements, 294 ErrorHandlerT errorHandler); 295 296 /// Resolve a single element of the pipeline. 297 LogicalResult resolvePipelineElement(PipelineElement &element, 298 ErrorHandlerT errorHandler); 299 300 /// Add the given pipeline elements to the provided pass manager. 301 LogicalResult 302 addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm, 303 function_ref<LogicalResult(const Twine &)> errorHandler) const; 304 305 std::vector<PipelineElement> pipeline; 306 }; 307 308 } // namespace 309 310 /// Try to initialize this pipeline with the given pipeline text. An option is 311 /// given to enable accurate error reporting. 312 LogicalResult TextualPipeline::initialize(StringRef text, 313 raw_ostream &errorStream) { 314 if (text.empty()) 315 return success(); 316 317 // Build a source manager to use for error reporting. 318 llvm::SourceMgr pipelineMgr; 319 pipelineMgr.AddNewSourceBuffer( 320 llvm::MemoryBuffer::getMemBuffer(text, "MLIR Textual PassPipeline Parser", 321 /*RequiresNullTerminator=*/false), 322 llvm::SMLoc()); 323 auto errorHandler = [&](const char *rawLoc, Twine msg) { 324 pipelineMgr.PrintMessage(errorStream, llvm::SMLoc::getFromPointer(rawLoc), 325 llvm::SourceMgr::DK_Error, msg); 326 return failure(); 327 }; 328 329 // Parse the provided pipeline string. 330 if (failed(parsePipelineText(text, errorHandler))) 331 return failure(); 332 return resolvePipelineElements(pipeline, errorHandler); 333 } 334 335 /// Add the internal pipeline elements to the provided pass manager. 336 LogicalResult TextualPipeline::addToPipeline( 337 OpPassManager &pm, 338 function_ref<LogicalResult(const Twine &)> errorHandler) const { 339 return addToPipeline(pipeline, pm, errorHandler); 340 } 341 342 /// Parse the given pipeline text into the internal pipeline vector. This 343 /// function only parses the structure of the pipeline, and does not resolve 344 /// its elements. 345 LogicalResult TextualPipeline::parsePipelineText(StringRef text, 346 ErrorHandlerT errorHandler) { 347 SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline}; 348 for (;;) { 349 std::vector<PipelineElement> &pipeline = *pipelineStack.back(); 350 size_t pos = text.find_first_of(",(){"); 351 pipeline.emplace_back(/*name=*/text.substr(0, pos).trim()); 352 353 // If we have a single terminating name, we're done. 354 if (pos == StringRef::npos) 355 break; 356 357 text = text.substr(pos); 358 char sep = text[0]; 359 360 // Handle pulling ... from 'pass{...}' out as PipelineElement.options. 361 if (sep == '{') { 362 text = text.substr(1); 363 364 // Skip over everything until the closing '}' and store as options. 365 size_t close = StringRef::npos; 366 for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) { 367 if (text[i] == '{') { 368 ++braceCount; 369 continue; 370 } 371 if (text[i] == '}' && --braceCount == 0) { 372 close = i; 373 break; 374 } 375 } 376 377 // Check to see if a closing options brace was found. 378 if (close == StringRef::npos) { 379 return errorHandler( 380 /*rawLoc=*/text.data() - 1, 381 "missing closing '}' while processing pass options"); 382 } 383 pipeline.back().options = text.substr(0, close); 384 text = text.substr(close + 1); 385 386 // Skip checking for '(' because nested pipelines cannot have options. 387 } else if (sep == '(') { 388 text = text.substr(1); 389 390 // Push the inner pipeline onto the stack to continue processing. 391 pipelineStack.push_back(&pipeline.back().innerPipeline); 392 continue; 393 } 394 395 // When handling the close parenthesis, we greedily consume them to avoid 396 // empty strings in the pipeline. 397 while (text.consume_front(")")) { 398 // If we try to pop the outer pipeline we have unbalanced parentheses. 399 if (pipelineStack.size() == 1) 400 return errorHandler(/*rawLoc=*/text.data() - 1, 401 "encountered extra closing ')' creating unbalanced " 402 "parentheses while parsing pipeline"); 403 404 pipelineStack.pop_back(); 405 } 406 407 // Check if we've finished parsing. 408 if (text.empty()) 409 break; 410 411 // Otherwise, the end of an inner pipeline always has to be followed by 412 // a comma, and then we can continue. 413 if (!text.consume_front(",")) 414 return errorHandler(text.data(), "expected ',' after parsing pipeline"); 415 } 416 417 // Check for unbalanced parentheses. 418 if (pipelineStack.size() > 1) 419 return errorHandler( 420 text.data(), 421 "encountered unbalanced parentheses while parsing pipeline"); 422 423 assert(pipelineStack.back() == &pipeline && 424 "wrong pipeline at the bottom of the stack"); 425 return success(); 426 } 427 428 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to 429 /// the corresponding registry entry. 430 LogicalResult TextualPipeline::resolvePipelineElements( 431 MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) { 432 for (auto &elt : elements) 433 if (failed(resolvePipelineElement(elt, errorHandler))) 434 return failure(); 435 return success(); 436 } 437 438 /// Resolve a single element of the pipeline. 439 LogicalResult 440 TextualPipeline::resolvePipelineElement(PipelineElement &element, 441 ErrorHandlerT errorHandler) { 442 // If the inner pipeline of this element is not empty, this is an operation 443 // pipeline. 444 if (!element.innerPipeline.empty()) 445 return resolvePipelineElements(element.innerPipeline, errorHandler); 446 // Otherwise, this must be a pass or pass pipeline. 447 // Check to see if a pipeline was registered with this name. 448 auto pipelineRegistryIt = passPipelineRegistry->find(element.name); 449 if (pipelineRegistryIt != passPipelineRegistry->end()) { 450 element.registryEntry = &pipelineRegistryIt->second; 451 return success(); 452 } 453 454 // If not, then this must be a specific pass name. 455 if ((element.registryEntry = Pass::lookupPassInfo(element.name))) 456 return success(); 457 458 // Emit an error for the unknown pass. 459 auto *rawLoc = element.name.data(); 460 return errorHandler(rawLoc, "'" + element.name + 461 "' does not refer to a " 462 "registered pass or pass pipeline"); 463 } 464 465 /// Add the given pipeline elements to the provided pass manager. 466 LogicalResult TextualPipeline::addToPipeline( 467 ArrayRef<PipelineElement> elements, OpPassManager &pm, 468 function_ref<LogicalResult(const Twine &)> errorHandler) const { 469 for (auto &elt : elements) { 470 if (elt.registryEntry) { 471 if (failed(elt.registryEntry->addToPipeline(pm, elt.options, 472 errorHandler))) { 473 return errorHandler("failed to add `" + elt.name + "` with options `" + 474 elt.options + "`"); 475 } 476 } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name), 477 errorHandler))) { 478 return errorHandler("failed to add `" + elt.name + "` with options `" + 479 elt.options + "` to inner pipeline"); 480 } 481 } 482 return success(); 483 } 484 485 LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm, 486 raw_ostream &errorStream) { 487 TextualPipeline pipelineParser; 488 if (failed(pipelineParser.initialize(pipeline, errorStream))) 489 return failure(); 490 auto errorHandler = [&](Twine msg) { 491 errorStream << msg << "\n"; 492 return failure(); 493 }; 494 if (failed(pipelineParser.addToPipeline(pm, errorHandler))) 495 return failure(); 496 return success(); 497 } 498 499 FailureOr<OpPassManager> mlir::parsePassPipeline(StringRef pipeline, 500 raw_ostream &errorStream) { 501 // Pipelines are expected to be of the form `<op-name>(<pipeline>)`. 502 size_t pipelineStart = pipeline.find_first_of('('); 503 if (pipelineStart == 0 || pipelineStart == StringRef::npos || 504 !pipeline.consume_back(")")) { 505 errorStream << "expected pass pipeline to be wrapped with the anchor " 506 "operation type, e.g. `builtin.module(...)"; 507 return failure(); 508 } 509 510 StringRef opName = pipeline.take_front(pipelineStart); 511 OpPassManager pm(opName); 512 if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm))) 513 return failure(); 514 return pm; 515 } 516 517 //===----------------------------------------------------------------------===// 518 // PassNameParser 519 //===----------------------------------------------------------------------===// 520 521 namespace { 522 /// This struct represents the possible data entries in a parsed pass pipeline 523 /// list. 524 struct PassArgData { 525 PassArgData() = default; 526 PassArgData(const PassRegistryEntry *registryEntry) 527 : registryEntry(registryEntry) {} 528 529 /// This field is used when the parsed option corresponds to a registered pass 530 /// or pass pipeline. 531 const PassRegistryEntry *registryEntry{nullptr}; 532 533 /// This field is set when instance specific pass options have been provided 534 /// on the command line. 535 StringRef options; 536 537 /// This field is used when the parsed option corresponds to an explicit 538 /// pipeline. 539 TextualPipeline pipeline; 540 }; 541 } // namespace 542 543 namespace llvm { 544 namespace cl { 545 /// Define a valid OptionValue for the command line pass argument. 546 template <> 547 struct OptionValue<PassArgData> final 548 : OptionValueBase<PassArgData, /*isClass=*/true> { 549 OptionValue(const PassArgData &value) { this->setValue(value); } 550 OptionValue() = default; 551 void anchor() override {} 552 553 bool hasValue() const { return true; } 554 const PassArgData &getValue() const { return value; } 555 void setValue(const PassArgData &value) { this->value = value; } 556 557 PassArgData value; 558 }; 559 } // namespace cl 560 } // namespace llvm 561 562 namespace { 563 564 /// The name for the command line option used for parsing the textual pass 565 /// pipeline. 566 static constexpr StringLiteral passPipelineArg = "pass-pipeline"; 567 568 /// Adds command line option for each registered pass or pass pipeline, as well 569 /// as textual pass pipelines. 570 struct PassNameParser : public llvm::cl::parser<PassArgData> { 571 PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {} 572 573 void initialize(); 574 void printOptionInfo(const llvm::cl::Option &opt, 575 size_t globalWidth) const override; 576 size_t getOptionWidth(const llvm::cl::Option &opt) const override; 577 bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg, 578 PassArgData &value); 579 580 /// If true, this parser only parses entries that correspond to a concrete 581 /// pass registry entry, and does not add a `pass-pipeline` argument, does not 582 /// include the options for pass entries, and does not include pass pipelines 583 /// entries. 584 bool passNamesOnly = false; 585 }; 586 } // namespace 587 588 void PassNameParser::initialize() { 589 llvm::cl::parser<PassArgData>::initialize(); 590 591 /// Add an entry for the textual pass pipeline option. 592 if (!passNamesOnly) { 593 addLiteralOption(passPipelineArg, PassArgData(), 594 "A textual description of a pass pipeline to run"); 595 } 596 597 /// Add the pass entries. 598 for (const auto &kv : *passRegistry) { 599 addLiteralOption(kv.second.getPassArgument(), &kv.second, 600 kv.second.getPassDescription()); 601 } 602 /// Add the pass pipeline entries. 603 if (!passNamesOnly) { 604 for (const auto &kv : *passPipelineRegistry) { 605 addLiteralOption(kv.second.getPassArgument(), &kv.second, 606 kv.second.getPassDescription()); 607 } 608 } 609 } 610 611 void PassNameParser::printOptionInfo(const llvm::cl::Option &opt, 612 size_t globalWidth) const { 613 // If this parser is just parsing pass names, print a simplified option 614 // string. 615 if (passNamesOnly) { 616 llvm::outs() << " --" << opt.ArgStr << "=<pass-arg>"; 617 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18); 618 return; 619 } 620 621 // Print the information for the top-level option. 622 if (opt.hasArgStr()) { 623 llvm::outs() << " --" << opt.ArgStr; 624 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7); 625 } else { 626 llvm::outs() << " " << opt.HelpStr << '\n'; 627 } 628 629 // Print the top-level pipeline argument. 630 printOptionHelp(passPipelineArg, 631 "A textual description of a pass pipeline to run", 632 /*indent=*/4, globalWidth, /*isTopLevel=*/!opt.hasArgStr()); 633 634 // Functor used to print the ordered entries of a registration map. 635 auto printOrderedEntries = [&](StringRef header, auto &map) { 636 llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries; 637 for (auto &kv : map) 638 orderedEntries.push_back(&kv.second); 639 llvm::array_pod_sort( 640 orderedEntries.begin(), orderedEntries.end(), 641 [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) { 642 return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument()); 643 }); 644 645 llvm::outs().indent(4) << header << ":\n"; 646 for (PassRegistryEntry *entry : orderedEntries) 647 entry->printHelpStr(/*indent=*/6, globalWidth); 648 }; 649 650 // Print the available passes. 651 printOrderedEntries("Passes", *passRegistry); 652 653 // Print the available pass pipelines. 654 if (!passPipelineRegistry->empty()) 655 printOrderedEntries("Pass Pipelines", *passPipelineRegistry); 656 } 657 658 size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const { 659 size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(opt) + 2; 660 661 // Check for any wider pass or pipeline options. 662 for (auto &entry : *passRegistry) 663 maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4); 664 for (auto &entry : *passPipelineRegistry) 665 maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4); 666 return maxWidth; 667 } 668 669 bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName, 670 StringRef arg, PassArgData &value) { 671 // Handle the pipeline option explicitly. 672 if (argName == passPipelineArg) 673 return failed(value.pipeline.initialize(arg, llvm::errs())); 674 675 // Otherwise, default to the base for handling. 676 if (llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value)) 677 return true; 678 value.options = arg; 679 return false; 680 } 681 682 //===----------------------------------------------------------------------===// 683 // PassPipelineCLParser 684 //===----------------------------------------------------------------------===// 685 686 namespace mlir { 687 namespace detail { 688 struct PassPipelineCLParserImpl { 689 PassPipelineCLParserImpl(StringRef arg, StringRef description, 690 bool passNamesOnly) 691 : passList(arg, llvm::cl::desc(description)) { 692 passList.getParser().passNamesOnly = passNamesOnly; 693 passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional); 694 } 695 696 /// Returns true if the given pass registry entry was registered at the 697 /// top-level of the parser, i.e. not within an explicit textual pipeline. 698 bool contains(const PassRegistryEntry *entry) const { 699 return llvm::any_of(passList, [&](const PassArgData &data) { 700 return data.registryEntry == entry; 701 }); 702 } 703 704 /// The set of passes and pass pipelines to run. 705 llvm::cl::list<PassArgData, bool, PassNameParser> passList; 706 }; 707 } // namespace detail 708 } // namespace mlir 709 710 /// Construct a pass pipeline parser with the given command line description. 711 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description) 712 : impl(std::make_unique<detail::PassPipelineCLParserImpl>( 713 arg, description, /*passNamesOnly=*/false)) {} 714 PassPipelineCLParser::~PassPipelineCLParser() = default; 715 716 /// Returns true if this parser contains any valid options to add. 717 bool PassPipelineCLParser::hasAnyOccurrences() const { 718 return impl->passList.getNumOccurrences() != 0; 719 } 720 721 /// Returns true if the given pass registry entry was registered at the 722 /// top-level of the parser, i.e. not within an explicit textual pipeline. 723 bool PassPipelineCLParser::contains(const PassRegistryEntry *entry) const { 724 return impl->contains(entry); 725 } 726 727 /// Adds the passes defined by this parser entry to the given pass manager. 728 LogicalResult PassPipelineCLParser::addToPipeline( 729 OpPassManager &pm, 730 function_ref<LogicalResult(const Twine &)> errorHandler) const { 731 for (auto &passIt : impl->passList) { 732 if (passIt.registryEntry) { 733 if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options, 734 errorHandler))) 735 return failure(); 736 } else { 737 OpPassManager::Nesting nesting = pm.getNesting(); 738 pm.setNesting(OpPassManager::Nesting::Explicit); 739 LogicalResult status = passIt.pipeline.addToPipeline(pm, errorHandler); 740 pm.setNesting(nesting); 741 if (failed(status)) 742 return failure(); 743 } 744 } 745 return success(); 746 } 747 748 //===----------------------------------------------------------------------===// 749 // PassNameCLParser 750 751 /// Construct a pass pipeline parser with the given command line description. 752 PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description) 753 : impl(std::make_unique<detail::PassPipelineCLParserImpl>( 754 arg, description, /*passNamesOnly=*/true)) { 755 impl->passList.setMiscFlag(llvm::cl::CommaSeparated); 756 } 757 PassNameCLParser::~PassNameCLParser() = default; 758 759 /// Returns true if this parser contains any valid options to add. 760 bool PassNameCLParser::hasAnyOccurrences() const { 761 return impl->passList.getNumOccurrences() != 0; 762 } 763 764 /// Returns true if the given pass registry entry was registered at the 765 /// top-level of the parser, i.e. not within an explicit textual pipeline. 766 bool PassNameCLParser::contains(const PassRegistryEntry *entry) const { 767 return impl->contains(entry); 768 } 769