1 //===- PassRegistry.cpp - Pass Registration Utilities ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Pass/PassRegistry.h"
12 #include "mlir/Pass/Pass.h"
13 #include "mlir/Pass/PassManager.h"
14 #include "llvm/ADT/DenseMap.h"
15 #include "llvm/Support/Format.h"
16 #include "llvm/Support/ManagedStatic.h"
17 #include "llvm/Support/MemoryBuffer.h"
18 #include "llvm/Support/SourceMgr.h"
19 
20 using namespace mlir;
21 using namespace detail;
22 
/// Static mapping of all of the registered passes, keyed by the pass argument
/// (the name returned by `Pass::getArgument()` when the pass is registered).
static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;

/// A mapping of the above pass registry entries to the corresponding TypeID
/// of the pass that they generate. `registerPass` uses it to reject a second,
/// different pass type being registered under an argument already in use.
static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;

/// Static mapping of all of the registered pass pipelines, keyed by the
/// pipeline argument passed to `registerPassPipeline`.
static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
    passPipelineRegistry;
33 
34 /// Utility to create a default registry function from a pass instance.
35 static PassRegistryFunction
36 buildDefaultRegistryFn(const PassAllocatorFunction &allocator) {
37   return [=](OpPassManager &pm, StringRef options,
38              function_ref<LogicalResult(const Twine &)> errorHandler) {
39     std::unique_ptr<Pass> pass = allocator();
40     LogicalResult result = pass->initializeOptions(options);
41     if ((pm.getNesting() == OpPassManager::Nesting::Explicit) &&
42         pass->getOpName() && *pass->getOpName() != pm.getOpName())
43       return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() +
44                           "' restricted to '" + *pass->getOpName() +
45                           "' on a PassManager intended to run on '" +
46                           pm.getOpName() + "', did you intend to nest?");
47     pm.addPass(std::move(pass));
48     return result;
49   };
50 }
51 
52 /// Utility to print the help string for a specific option.
53 static void printOptionHelp(StringRef arg, StringRef desc, size_t indent,
54                             size_t descIndent, bool isTopLevel) {
55   size_t numSpaces = descIndent - indent - 4;
56   llvm::outs().indent(indent)
57       << "--" << llvm::left_justify(arg, numSpaces) << "-   " << desc << '\n';
58 }
59 
60 //===----------------------------------------------------------------------===//
61 // PassRegistry
62 //===----------------------------------------------------------------------===//
63 
64 /// Print the help information for this pass. This includes the argument,
65 /// description, and any pass options. `descIndent` is the indent that the
66 /// descriptions should be aligned.
67 void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const {
68   printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
69                   /*isTopLevel=*/true);
70   // If this entry has options, print the help for those as well.
71   optHandler([=](const PassOptions &options) {
72     options.printHelp(indent, descIndent);
73   });
74 }
75 
76 /// Return the maximum width required when printing the options of this
77 /// entry.
78 size_t PassRegistryEntry::getOptionWidth() const {
79   size_t maxLen = 0;
80   optHandler([&](const PassOptions &options) mutable {
81     maxLen = options.getOptionWidth() + 2;
82   });
83   return maxLen;
84 }
85 
86 //===----------------------------------------------------------------------===//
87 // PassPipelineInfo
88 //===----------------------------------------------------------------------===//
89 
90 void mlir::registerPassPipeline(
91     StringRef arg, StringRef description, const PassRegistryFunction &function,
92     std::function<void(function_ref<void(const PassOptions &)>)> optHandler) {
93   PassPipelineInfo pipelineInfo(arg, description, function,
94                                 std::move(optHandler));
95   bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second;
96   assert(inserted && "Pass pipeline registered multiple times");
97   (void)inserted;
98 }
99 
100 //===----------------------------------------------------------------------===//
101 // PassInfo
102 //===----------------------------------------------------------------------===//
103 
104 PassInfo::PassInfo(StringRef arg, StringRef description,
105                    const PassAllocatorFunction &allocator)
106     : PassRegistryEntry(
107           arg, description, buildDefaultRegistryFn(allocator),
108           // Use a temporary pass to provide an options instance.
109           [=](function_ref<void(const PassOptions &)> optHandler) {
110             optHandler(allocator()->passOptions);
111           }) {}
112 
113 void mlir::registerPass(const PassAllocatorFunction &function) {
114   std::unique_ptr<Pass> pass = function();
115   StringRef arg = pass->getArgument();
116   if (arg.empty())
117     llvm::report_fatal_error(llvm::Twine("Trying to register '") +
118                              pass->getName() +
119                              "' pass that does not override `getArgument()`");
120   StringRef description = pass->getDescription();
121   PassInfo passInfo(arg, description, function);
122   passRegistry->try_emplace(arg, passInfo);
123 
124   // Verify that the registered pass has the same ID as any registered to this
125   // arg before it.
126   TypeID entryTypeID = pass->getTypeID();
127   auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first;
128   if (it->second != entryTypeID)
129     llvm::report_fatal_error(
130         "pass allocator creates a different pass than previously "
131         "registered for pass " +
132         arg);
133 }
134 
135 /// Returns the pass info for the specified pass argument or null if unknown.
136 const PassInfo *mlir::Pass::lookupPassInfo(StringRef passArg) {
137   auto it = passRegistry->find(passArg);
138   return it == passRegistry->end() ? nullptr : &it->second;
139 }
140 
141 //===----------------------------------------------------------------------===//
142 // PassOptions
143 //===----------------------------------------------------------------------===//
144 
/// Out-of-line definition of a virtual function. Following LLVM's "anchor"
/// convention, this gives the class a single home translation unit.
void detail::PassOptions::OptionBase::anchor() {}
147 
148 /// Copy the option values from 'other'.
149 void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) {
150   assert(options.size() == other.options.size());
151   if (options.empty())
152     return;
153   for (auto optionsIt : llvm::zip(options, other.options))
154     std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
155 }
156 
157 LogicalResult detail::PassOptions::parseFromString(StringRef options) {
158   // TODO: Handle escaping strings.
159   // NOTE: `options` is modified in place to always refer to the unprocessed
160   // part of the string.
161   while (!options.empty()) {
162     size_t spacePos = options.find(' ');
163     StringRef arg = options;
164     if (spacePos != StringRef::npos) {
165       arg = options.substr(0, spacePos);
166       options = options.substr(spacePos + 1);
167     } else {
168       options = StringRef();
169     }
170     if (arg.empty())
171       continue;
172 
173     // At this point, arg refers to everything that is non-space in options
174     // upto the next space, and options refers to the rest of the string after
175     // that point.
176 
177     // Split the individual option on '=' to form key and value. If there is no
178     // '=', then value is `StringRef()`.
179     size_t equalPos = arg.find('=');
180     StringRef key = arg;
181     StringRef value;
182     if (equalPos != StringRef::npos) {
183       key = arg.substr(0, equalPos);
184       value = arg.substr(equalPos + 1);
185     }
186     auto it = OptionsMap.find(key);
187     if (it == OptionsMap.end()) {
188       llvm::errs() << "<Pass-Options-Parser>: no such option " << key << "\n";
189       return failure();
190     }
191     if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
192       return failure();
193   }
194 
195   return success();
196 }
197 
198 /// Print the options held by this struct in a form that can be parsed via
199 /// 'parseFromString'.
200 void detail::PassOptions::print(raw_ostream &os) {
201   // If there are no options, there is nothing left to do.
202   if (OptionsMap.empty())
203     return;
204 
205   // Sort the options to make the ordering deterministic.
206   SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
207   auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
208     return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
209   };
210   llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
211 
212   // Interleave the options with ' '.
213   os << '{';
214   llvm::interleave(
215       orderedOps, os, [&](OptionBase *option) { option->print(os); }, " ");
216   os << '}';
217 }
218 
219 /// Print the help string for the options held by this struct. `descIndent` is
220 /// the indent within the stream that the descriptions should be aligned.
221 void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const {
222   // Sort the options to make the ordering deterministic.
223   SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
224   auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
225     return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
226   };
227   llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
228   for (OptionBase *option : orderedOps) {
229     // TODO: printOptionInfo assumes a specific indent and will
230     // print options with values with incorrect indentation. We should add
231     // support to llvm::cl::Option for passing in a base indent to use when
232     // printing.
233     llvm::outs().indent(indent);
234     option->getOption()->printOptionInfo(descIndent - indent);
235   }
236 }
237 
238 /// Return the maximum width required when printing the help string.
239 size_t detail::PassOptions::getOptionWidth() const {
240   size_t max = 0;
241   for (auto *option : options)
242     max = std::max(max, option->getOption()->getOptionWidth());
243   return max;
244 }
245 
246 //===----------------------------------------------------------------------===//
247 // TextualPassPipeline Parser
248 //===----------------------------------------------------------------------===//
249 
250 namespace {
251 /// This class represents a textual description of a pass pipeline.
252 class TextualPipeline {
253 public:
254   /// Try to initialize this pipeline with the given pipeline text.
255   /// `errorStream` is the output stream to emit errors to.
256   LogicalResult initialize(StringRef text, raw_ostream &errorStream);
257 
258   /// Add the internal pipeline elements to the provided pass manager.
259   LogicalResult
260   addToPipeline(OpPassManager &pm,
261                 function_ref<LogicalResult(const Twine &)> errorHandler) const;
262 
263 private:
264   /// A functor used to emit errors found during pipeline handling. The first
265   /// parameter corresponds to the raw location within the pipeline string. This
266   /// should always return failure.
267   using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>;
268 
269   /// A struct to capture parsed pass pipeline names.
270   ///
271   /// A pipeline is defined as a series of names, each of which may in itself
272   /// recursively contain a nested pipeline. A name is either the name of a pass
273   /// (e.g. "cse") or the name of an operation type (e.g. "builtin.func"). If
274   /// the name is the name of a pass, the InnerPipeline is empty, since passes
275   /// cannot contain inner pipelines.
276   struct PipelineElement {
277     PipelineElement(StringRef name) : name(name), registryEntry(nullptr) {}
278 
279     StringRef name;
280     StringRef options;
281     const PassRegistryEntry *registryEntry;
282     std::vector<PipelineElement> innerPipeline;
283   };
284 
285   /// Parse the given pipeline text into the internal pipeline vector. This
286   /// function only parses the structure of the pipeline, and does not resolve
287   /// its elements.
288   LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);
289 
290   /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
291   /// the corresponding registry entry.
292   LogicalResult
293   resolvePipelineElements(MutableArrayRef<PipelineElement> elements,
294                           ErrorHandlerT errorHandler);
295 
296   /// Resolve a single element of the pipeline.
297   LogicalResult resolvePipelineElement(PipelineElement &element,
298                                        ErrorHandlerT errorHandler);
299 
300   /// Add the given pipeline elements to the provided pass manager.
301   LogicalResult
302   addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm,
303                 function_ref<LogicalResult(const Twine &)> errorHandler) const;
304 
305   std::vector<PipelineElement> pipeline;
306 };
307 
308 } // namespace
309 
310 /// Try to initialize this pipeline with the given pipeline text. An option is
311 /// given to enable accurate error reporting.
312 LogicalResult TextualPipeline::initialize(StringRef text,
313                                           raw_ostream &errorStream) {
314   if (text.empty())
315     return success();
316 
317   // Build a source manager to use for error reporting.
318   llvm::SourceMgr pipelineMgr;
319   pipelineMgr.AddNewSourceBuffer(
320       llvm::MemoryBuffer::getMemBuffer(text, "MLIR Textual PassPipeline Parser",
321                                        /*RequiresNullTerminator=*/false),
322       llvm::SMLoc());
323   auto errorHandler = [&](const char *rawLoc, Twine msg) {
324     pipelineMgr.PrintMessage(errorStream, llvm::SMLoc::getFromPointer(rawLoc),
325                              llvm::SourceMgr::DK_Error, msg);
326     return failure();
327   };
328 
329   // Parse the provided pipeline string.
330   if (failed(parsePipelineText(text, errorHandler)))
331     return failure();
332   return resolvePipelineElements(pipeline, errorHandler);
333 }
334 
335 /// Add the internal pipeline elements to the provided pass manager.
336 LogicalResult TextualPipeline::addToPipeline(
337     OpPassManager &pm,
338     function_ref<LogicalResult(const Twine &)> errorHandler) const {
339   return addToPipeline(pipeline, pm, errorHandler);
340 }
341 
342 /// Parse the given pipeline text into the internal pipeline vector. This
343 /// function only parses the structure of the pipeline, and does not resolve
344 /// its elements.
345 LogicalResult TextualPipeline::parsePipelineText(StringRef text,
346                                                  ErrorHandlerT errorHandler) {
347   SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline};
348   for (;;) {
349     std::vector<PipelineElement> &pipeline = *pipelineStack.back();
350     size_t pos = text.find_first_of(",(){");
351     pipeline.emplace_back(/*name=*/text.substr(0, pos).trim());
352 
353     // If we have a single terminating name, we're done.
354     if (pos == StringRef::npos)
355       break;
356 
357     text = text.substr(pos);
358     char sep = text[0];
359 
360     // Handle pulling ... from 'pass{...}' out as PipelineElement.options.
361     if (sep == '{') {
362       text = text.substr(1);
363 
364       // Skip over everything until the closing '}' and store as options.
365       size_t close = StringRef::npos;
366       for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
367         if (text[i] == '{') {
368           ++braceCount;
369           continue;
370         }
371         if (text[i] == '}' && --braceCount == 0) {
372           close = i;
373           break;
374         }
375       }
376 
377       // Check to see if a closing options brace was found.
378       if (close == StringRef::npos) {
379         return errorHandler(
380             /*rawLoc=*/text.data() - 1,
381             "missing closing '}' while processing pass options");
382       }
383       pipeline.back().options = text.substr(0, close);
384       text = text.substr(close + 1);
385 
386       // Skip checking for '(' because nested pipelines cannot have options.
387     } else if (sep == '(') {
388       text = text.substr(1);
389 
390       // Push the inner pipeline onto the stack to continue processing.
391       pipelineStack.push_back(&pipeline.back().innerPipeline);
392       continue;
393     }
394 
395     // When handling the close parenthesis, we greedily consume them to avoid
396     // empty strings in the pipeline.
397     while (text.consume_front(")")) {
398       // If we try to pop the outer pipeline we have unbalanced parentheses.
399       if (pipelineStack.size() == 1)
400         return errorHandler(/*rawLoc=*/text.data() - 1,
401                             "encountered extra closing ')' creating unbalanced "
402                             "parentheses while parsing pipeline");
403 
404       pipelineStack.pop_back();
405     }
406 
407     // Check if we've finished parsing.
408     if (text.empty())
409       break;
410 
411     // Otherwise, the end of an inner pipeline always has to be followed by
412     // a comma, and then we can continue.
413     if (!text.consume_front(","))
414       return errorHandler(text.data(), "expected ',' after parsing pipeline");
415   }
416 
417   // Check for unbalanced parentheses.
418   if (pipelineStack.size() > 1)
419     return errorHandler(
420         text.data(),
421         "encountered unbalanced parentheses while parsing pipeline");
422 
423   assert(pipelineStack.back() == &pipeline &&
424          "wrong pipeline at the bottom of the stack");
425   return success();
426 }
427 
428 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
429 /// the corresponding registry entry.
430 LogicalResult TextualPipeline::resolvePipelineElements(
431     MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) {
432   for (auto &elt : elements)
433     if (failed(resolvePipelineElement(elt, errorHandler)))
434       return failure();
435   return success();
436 }
437 
438 /// Resolve a single element of the pipeline.
439 LogicalResult
440 TextualPipeline::resolvePipelineElement(PipelineElement &element,
441                                         ErrorHandlerT errorHandler) {
442   // If the inner pipeline of this element is not empty, this is an operation
443   // pipeline.
444   if (!element.innerPipeline.empty())
445     return resolvePipelineElements(element.innerPipeline, errorHandler);
446   // Otherwise, this must be a pass or pass pipeline.
447   // Check to see if a pipeline was registered with this name.
448   auto pipelineRegistryIt = passPipelineRegistry->find(element.name);
449   if (pipelineRegistryIt != passPipelineRegistry->end()) {
450     element.registryEntry = &pipelineRegistryIt->second;
451     return success();
452   }
453 
454   // If not, then this must be a specific pass name.
455   if ((element.registryEntry = Pass::lookupPassInfo(element.name)))
456     return success();
457 
458   // Emit an error for the unknown pass.
459   auto *rawLoc = element.name.data();
460   return errorHandler(rawLoc, "'" + element.name +
461                                   "' does not refer to a "
462                                   "registered pass or pass pipeline");
463 }
464 
465 /// Add the given pipeline elements to the provided pass manager.
466 LogicalResult TextualPipeline::addToPipeline(
467     ArrayRef<PipelineElement> elements, OpPassManager &pm,
468     function_ref<LogicalResult(const Twine &)> errorHandler) const {
469   for (auto &elt : elements) {
470     if (elt.registryEntry) {
471       if (failed(elt.registryEntry->addToPipeline(pm, elt.options,
472                                                   errorHandler))) {
473         return errorHandler("failed to add `" + elt.name + "` with options `" +
474                             elt.options + "`");
475       }
476     } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name),
477                                     errorHandler))) {
478       return errorHandler("failed to add `" + elt.name + "` with options `" +
479                           elt.options + "` to inner pipeline");
480     }
481   }
482   return success();
483 }
484 
485 LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm,
486                                       raw_ostream &errorStream) {
487   TextualPipeline pipelineParser;
488   if (failed(pipelineParser.initialize(pipeline, errorStream)))
489     return failure();
490   auto errorHandler = [&](Twine msg) {
491     errorStream << msg << "\n";
492     return failure();
493   };
494   if (failed(pipelineParser.addToPipeline(pm, errorHandler)))
495     return failure();
496   return success();
497 }
498 
499 FailureOr<OpPassManager> mlir::parsePassPipeline(StringRef pipeline,
500                                                  raw_ostream &errorStream) {
501   // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
502   size_t pipelineStart = pipeline.find_first_of('(');
503   if (pipelineStart == 0 || pipelineStart == StringRef::npos ||
504       !pipeline.consume_back(")")) {
505     errorStream << "expected pass pipeline to be wrapped with the anchor "
506                    "operation type, e.g. `builtin.module(...)";
507     return failure();
508   }
509 
510   StringRef opName = pipeline.take_front(pipelineStart);
511   OpPassManager pm(opName);
512   if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm)))
513     return failure();
514   return pm;
515 }
516 
517 //===----------------------------------------------------------------------===//
518 // PassNameParser
519 //===----------------------------------------------------------------------===//
520 
521 namespace {
522 /// This struct represents the possible data entries in a parsed pass pipeline
523 /// list.
524 struct PassArgData {
525   PassArgData() = default;
526   PassArgData(const PassRegistryEntry *registryEntry)
527       : registryEntry(registryEntry) {}
528 
529   /// This field is used when the parsed option corresponds to a registered pass
530   /// or pass pipeline.
531   const PassRegistryEntry *registryEntry{nullptr};
532 
533   /// This field is set when instance specific pass options have been provided
534   /// on the command line.
535   StringRef options;
536 
537   /// This field is used when the parsed option corresponds to an explicit
538   /// pipeline.
539   TextualPipeline pipeline;
540 };
541 } // namespace
542 
namespace llvm {
namespace cl {
/// Define a valid OptionValue for the command line pass argument, so that
/// `PassArgData` can serve as the value type of the `llvm::cl` parser below.
/// A value is always considered present (`hasValue()` returns true).
template <>
struct OptionValue<PassArgData> final
    : OptionValueBase<PassArgData, /*isClass=*/true> {
  OptionValue(const PassArgData &value) { this->setValue(value); }
  OptionValue() = default;
  void anchor() override {}

  bool hasValue() const { return true; }
  const PassArgData &getValue() const { return value; }
  void setValue(const PassArgData &value) { this->value = value; }

  /// The stored entry.
  PassArgData value;
};
} // namespace cl
} // namespace llvm
561 
562 namespace {
563 
564 /// The name for the command line option used for parsing the textual pass
565 /// pipeline.
566 static constexpr StringLiteral passPipelineArg = "pass-pipeline";
567 
568 /// Adds command line option for each registered pass or pass pipeline, as well
569 /// as textual pass pipelines.
570 struct PassNameParser : public llvm::cl::parser<PassArgData> {
571   PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {}
572 
573   void initialize();
574   void printOptionInfo(const llvm::cl::Option &opt,
575                        size_t globalWidth) const override;
576   size_t getOptionWidth(const llvm::cl::Option &opt) const override;
577   bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
578              PassArgData &value);
579 
580   /// If true, this parser only parses entries that correspond to a concrete
581   /// pass registry entry, and does not add a `pass-pipeline` argument, does not
582   /// include the options for pass entries, and does not include pass pipelines
583   /// entries.
584   bool passNamesOnly = false;
585 };
586 } // namespace
587 
588 void PassNameParser::initialize() {
589   llvm::cl::parser<PassArgData>::initialize();
590 
591   /// Add an entry for the textual pass pipeline option.
592   if (!passNamesOnly) {
593     addLiteralOption(passPipelineArg, PassArgData(),
594                      "A textual description of a pass pipeline to run");
595   }
596 
597   /// Add the pass entries.
598   for (const auto &kv : *passRegistry) {
599     addLiteralOption(kv.second.getPassArgument(), &kv.second,
600                      kv.second.getPassDescription());
601   }
602   /// Add the pass pipeline entries.
603   if (!passNamesOnly) {
604     for (const auto &kv : *passPipelineRegistry) {
605       addLiteralOption(kv.second.getPassArgument(), &kv.second,
606                        kv.second.getPassDescription());
607     }
608   }
609 }
610 
611 void PassNameParser::printOptionInfo(const llvm::cl::Option &opt,
612                                      size_t globalWidth) const {
613   // If this parser is just parsing pass names, print a simplified option
614   // string.
615   if (passNamesOnly) {
616     llvm::outs() << "  --" << opt.ArgStr << "=<pass-arg>";
617     opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18);
618     return;
619   }
620 
621   // Print the information for the top-level option.
622   if (opt.hasArgStr()) {
623     llvm::outs() << "  --" << opt.ArgStr;
624     opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
625   } else {
626     llvm::outs() << "  " << opt.HelpStr << '\n';
627   }
628 
629   // Print the top-level pipeline argument.
630   printOptionHelp(passPipelineArg,
631                   "A textual description of a pass pipeline to run",
632                   /*indent=*/4, globalWidth, /*isTopLevel=*/!opt.hasArgStr());
633 
634   // Functor used to print the ordered entries of a registration map.
635   auto printOrderedEntries = [&](StringRef header, auto &map) {
636     llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries;
637     for (auto &kv : map)
638       orderedEntries.push_back(&kv.second);
639     llvm::array_pod_sort(
640         orderedEntries.begin(), orderedEntries.end(),
641         [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
642           return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
643         });
644 
645     llvm::outs().indent(4) << header << ":\n";
646     for (PassRegistryEntry *entry : orderedEntries)
647       entry->printHelpStr(/*indent=*/6, globalWidth);
648   };
649 
650   // Print the available passes.
651   printOrderedEntries("Passes", *passRegistry);
652 
653   // Print the available pass pipelines.
654   if (!passPipelineRegistry->empty())
655     printOrderedEntries("Pass Pipelines", *passPipelineRegistry);
656 }
657 
658 size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const {
659   size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(opt) + 2;
660 
661   // Check for any wider pass or pipeline options.
662   for (auto &entry : *passRegistry)
663     maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
664   for (auto &entry : *passPipelineRegistry)
665     maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
666   return maxWidth;
667 }
668 
669 bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,
670                            StringRef arg, PassArgData &value) {
671   // Handle the pipeline option explicitly.
672   if (argName == passPipelineArg)
673     return failed(value.pipeline.initialize(arg, llvm::errs()));
674 
675   // Otherwise, default to the base for handling.
676   if (llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value))
677     return true;
678   value.options = arg;
679   return false;
680 }
681 
682 //===----------------------------------------------------------------------===//
683 // PassPipelineCLParser
684 //===----------------------------------------------------------------------===//
685 
686 namespace mlir {
687 namespace detail {
688 struct PassPipelineCLParserImpl {
689   PassPipelineCLParserImpl(StringRef arg, StringRef description,
690                            bool passNamesOnly)
691       : passList(arg, llvm::cl::desc(description)) {
692     passList.getParser().passNamesOnly = passNamesOnly;
693     passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
694   }
695 
696   /// Returns true if the given pass registry entry was registered at the
697   /// top-level of the parser, i.e. not within an explicit textual pipeline.
698   bool contains(const PassRegistryEntry *entry) const {
699     return llvm::any_of(passList, [&](const PassArgData &data) {
700       return data.registryEntry == entry;
701     });
702   }
703 
704   /// The set of passes and pass pipelines to run.
705   llvm::cl::list<PassArgData, bool, PassNameParser> passList;
706 };
707 } // namespace detail
708 } // namespace mlir
709 
/// Construct a pass pipeline parser with the given command line description.
/// With `passNamesOnly` false, the parser accepts registered passes,
/// registered pass pipelines, and the `pass-pipeline` textual entry.
PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description)
    : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
          arg, description, /*passNamesOnly=*/false)) {}
/// Defaulted here, where `detail::PassPipelineCLParserImpl` is a complete type.
PassPipelineCLParser::~PassPipelineCLParser() = default;
715 
/// Returns true if this parser contains any valid options to add, i.e. the
/// underlying command-line list occurred at least once.
bool PassPipelineCLParser::hasAnyOccurrences() const {
  return impl->passList.getNumOccurrences() != 0;
}
720 
/// Returns true if the given pass registry entry was registered at the
/// top-level of the parser, i.e. not within an explicit textual pipeline.
/// Only the parsed list entries are searched.
bool PassPipelineCLParser::contains(const PassRegistryEntry *entry) const {
  return impl->contains(entry);
}
726 
727 /// Adds the passes defined by this parser entry to the given pass manager.
728 LogicalResult PassPipelineCLParser::addToPipeline(
729     OpPassManager &pm,
730     function_ref<LogicalResult(const Twine &)> errorHandler) const {
731   for (auto &passIt : impl->passList) {
732     if (passIt.registryEntry) {
733       if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
734                                                      errorHandler)))
735         return failure();
736     } else {
737       OpPassManager::Nesting nesting = pm.getNesting();
738       pm.setNesting(OpPassManager::Nesting::Explicit);
739       LogicalResult status = passIt.pipeline.addToPipeline(pm, errorHandler);
740       pm.setNesting(nesting);
741       if (failed(status))
742         return failure();
743     }
744   }
745   return success();
746 }
747 
//===----------------------------------------------------------------------===//
// PassNameCLParser
//===----------------------------------------------------------------------===//
750 
/// Construct a pass name parser with the given command line description.
/// With `passNamesOnly` set, only registered passes are offered as values.
/// Several names may be given in one occurrence as a comma-separated list.
PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description)
    : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
          arg, description, /*passNamesOnly=*/true)) {
  impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
}
/// Defaulted here, where `detail::PassPipelineCLParserImpl` is a complete type.
PassNameCLParser::~PassNameCLParser() = default;
758 
/// Returns true if this parser contains any valid options to add, i.e. the
/// underlying command-line list occurred at least once.
bool PassNameCLParser::hasAnyOccurrences() const {
  return impl->passList.getNumOccurrences() != 0;
}
763 
/// Returns true if the given pass registry entry was registered at the
/// top-level of the parser, i.e. not within an explicit textual pipeline.
/// Only the parsed list entries are searched.
bool PassNameCLParser::contains(const PassRegistryEntry *entry) const {
  return impl->contains(entry);
}
769