1 //===- PassRegistry.cpp - Pass Registration Utilities ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Pass/Pass.h"
12 #include "mlir/Pass/PassManager.h"
13 #include "mlir/Pass/PassRegistry.h"
14 #include "llvm/ADT/DenseMap.h"
15 #include "llvm/Support/Format.h"
16 #include "llvm/Support/ManagedStatic.h"
17 #include "llvm/Support/MemoryBuffer.h"
18 #include "llvm/Support/SourceMgr.h"
19 
20 using namespace mlir;
21 using namespace detail;
22 
/// Static mapping of all of the registered passes, keyed by pass argument.
static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;

/// A mapping of the above pass registry entries to the corresponding TypeID
/// of the pass that they generate. `registerPass` uses it to reject two
/// different passes being registered under the same argument.
static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;

/// Static mapping of all of the registered pass pipelines, keyed by pipeline
/// argument.
static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
    passPipelineRegistry;
33 
34 /// Utility to create a default registry function from a pass instance.
35 static PassRegistryFunction
buildDefaultRegistryFn(const PassAllocatorFunction & allocator)36 buildDefaultRegistryFn(const PassAllocatorFunction &allocator) {
37   return [=](OpPassManager &pm, StringRef options,
38              function_ref<LogicalResult(const Twine &)> errorHandler) {
39     std::unique_ptr<Pass> pass = allocator();
40     LogicalResult result = pass->initializeOptions(options);
41 
42     Optional<StringRef> pmOpName = pm.getOpName();
43     Optional<StringRef> passOpName = pass->getOpName();
44     if ((pm.getNesting() == OpPassManager::Nesting::Explicit) && pmOpName &&
45         passOpName && *pmOpName != *passOpName) {
46       return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() +
47                           "' restricted to '" + *pass->getOpName() +
48                           "' on a PassManager intended to run on '" +
49                           pm.getOpAnchorName() + "', did you intend to nest?");
50     }
51     pm.addPass(std::move(pass));
52     return result;
53   };
54 }
55 
56 /// Utility to print the help string for a specific option.
printOptionHelp(StringRef arg,StringRef desc,size_t indent,size_t descIndent,bool isTopLevel)57 static void printOptionHelp(StringRef arg, StringRef desc, size_t indent,
58                             size_t descIndent, bool isTopLevel) {
59   size_t numSpaces = descIndent - indent - 4;
60   llvm::outs().indent(indent)
61       << "--" << llvm::left_justify(arg, numSpaces) << "-   " << desc << '\n';
62 }
63 
64 //===----------------------------------------------------------------------===//
65 // PassRegistry
66 //===----------------------------------------------------------------------===//
67 
68 /// Print the help information for this pass. This includes the argument,
69 /// description, and any pass options. `descIndent` is the indent that the
70 /// descriptions should be aligned.
printHelpStr(size_t indent,size_t descIndent) const71 void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const {
72   printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
73                   /*isTopLevel=*/true);
74   // If this entry has options, print the help for those as well.
75   optHandler([=](const PassOptions &options) {
76     options.printHelp(indent, descIndent);
77   });
78 }
79 
80 /// Return the maximum width required when printing the options of this
81 /// entry.
getOptionWidth() const82 size_t PassRegistryEntry::getOptionWidth() const {
83   size_t maxLen = 0;
84   optHandler([&](const PassOptions &options) mutable {
85     maxLen = options.getOptionWidth() + 2;
86   });
87   return maxLen;
88 }
89 
90 //===----------------------------------------------------------------------===//
91 // PassPipelineInfo
92 //===----------------------------------------------------------------------===//
93 
registerPassPipeline(StringRef arg,StringRef description,const PassRegistryFunction & function,std::function<void (function_ref<void (const PassOptions &)>)> optHandler)94 void mlir::registerPassPipeline(
95     StringRef arg, StringRef description, const PassRegistryFunction &function,
96     std::function<void(function_ref<void(const PassOptions &)>)> optHandler) {
97   PassPipelineInfo pipelineInfo(arg, description, function,
98                                 std::move(optHandler));
99   bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second;
100   assert(inserted && "Pass pipeline registered multiple times");
101   (void)inserted;
102 }
103 
104 //===----------------------------------------------------------------------===//
105 // PassInfo
106 //===----------------------------------------------------------------------===//
107 
PassInfo(StringRef arg,StringRef description,const PassAllocatorFunction & allocator)108 PassInfo::PassInfo(StringRef arg, StringRef description,
109                    const PassAllocatorFunction &allocator)
110     : PassRegistryEntry(
111           arg, description, buildDefaultRegistryFn(allocator),
112           // Use a temporary pass to provide an options instance.
113           [=](function_ref<void(const PassOptions &)> optHandler) {
114             optHandler(allocator()->passOptions);
115           }) {}
116 
registerPass(const PassAllocatorFunction & function)117 void mlir::registerPass(const PassAllocatorFunction &function) {
118   std::unique_ptr<Pass> pass = function();
119   StringRef arg = pass->getArgument();
120   if (arg.empty())
121     llvm::report_fatal_error(llvm::Twine("Trying to register '") +
122                              pass->getName() +
123                              "' pass that does not override `getArgument()`");
124   StringRef description = pass->getDescription();
125   PassInfo passInfo(arg, description, function);
126   passRegistry->try_emplace(arg, passInfo);
127 
128   // Verify that the registered pass has the same ID as any registered to this
129   // arg before it.
130   TypeID entryTypeID = pass->getTypeID();
131   auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first;
132   if (it->second != entryTypeID)
133     llvm::report_fatal_error(
134         "pass allocator creates a different pass than previously "
135         "registered for pass " +
136         arg);
137 }
138 
139 /// Returns the pass info for the specified pass argument or null if unknown.
lookupPassInfo(StringRef passArg)140 const PassInfo *mlir::Pass::lookupPassInfo(StringRef passArg) {
141   auto it = passRegistry->find(passArg);
142   return it == passRegistry->end() ? nullptr : &it->second;
143 }
144 
145 //===----------------------------------------------------------------------===//
146 // PassOptions
147 //===----------------------------------------------------------------------===//
148 
parseCommaSeparatedList(llvm::cl::Option & opt,StringRef argName,StringRef optionStr,function_ref<LogicalResult (StringRef)> elementParseFn)149 LogicalResult detail::pass_options::parseCommaSeparatedList(
150     llvm::cl::Option &opt, StringRef argName, StringRef optionStr,
151     function_ref<LogicalResult(StringRef)> elementParseFn) {
152   // Functor used for finding a character in a string, and skipping over
153   // various "range" characters.
154   llvm::unique_function<size_t(StringRef, size_t, char)> findChar =
155       [&](StringRef str, size_t index, char c) -> size_t {
156     for (size_t i = index, e = str.size(); i < e; ++i) {
157       if (str[i] == c)
158         return i;
159       // Check for various range characters.
160       if (str[i] == '{')
161         i = findChar(str, i + 1, '}');
162       else if (str[i] == '(')
163         i = findChar(str, i + 1, ')');
164       else if (str[i] == '[')
165         i = findChar(str, i + 1, ']');
166       else if (str[i] == '\"')
167         i = str.find_first_of('\"', i + 1);
168       else if (str[i] == '\'')
169         i = str.find_first_of('\'', i + 1);
170     }
171     return StringRef::npos;
172   };
173 
174   size_t nextElePos = findChar(optionStr, 0, ',');
175   while (nextElePos != StringRef::npos) {
176     // Process the portion before the comma.
177     if (failed(elementParseFn(optionStr.substr(0, nextElePos))))
178       return failure();
179 
180     optionStr = optionStr.substr(nextElePos + 1);
181     nextElePos = findChar(optionStr, 0, ',');
182   }
183   return elementParseFn(optionStr.substr(0, nextElePos));
184 }
185 
/// Out-of-line virtual method that gives `OptionBase` a home in this file
/// (its definition is intentionally empty).
void detail::PassOptions::OptionBase::anchor() {}
188 
189 /// Copy the option values from 'other'.
copyOptionValuesFrom(const PassOptions & other)190 void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) {
191   assert(options.size() == other.options.size());
192   if (options.empty())
193     return;
194   for (auto optionsIt : llvm::zip(options, other.options))
195     std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
196 }
197 
/// Parse in the next argument from the given options string. Returns a tuple
/// containing [the key of the option, the value of the option, updated
/// `options` string pointing after the parsed option].
///
/// An option is either `key` (value is empty) or `key=value`, terminated by a
/// space or the end of the string. Key and value are trimmed of surrounding
/// whitespace. Quotes are kept in the returned value. Within a value, quoted
/// substrings and balanced `{...}` groups may contain spaces without ending
/// the option.
static std::tuple<StringRef, StringRef, StringRef>
parseNextArg(StringRef options) {
  // Functor used to extract an argument from 'options' and update it to point
  // after the arg.
  auto extractArgAndUpdateOptions = [&](size_t argSize) {
    StringRef str = options.take_front(argSize).trim();
    options = options.drop_front(argSize).ltrim();
    return str;
  };
  // Try to process the given punctuation, properly escaping any contained
  // characters.
  // If `options[currentPos]` is `punct`, advance `currentPos` to the matching
  // closing `punct` and return true. When no closing `punct` exists,
  // `currentPos` is left unchanged but true is still returned.
  auto tryProcessPunct = [&](size_t &currentPos, char punct) {
    if (options[currentPos] != punct)
      return false;
    size_t nextIt = options.find_first_of(punct, currentPos + 1);
    if (nextIt != StringRef::npos)
      currentPos = nextIt;
    return true;
  };

  // Parse the argument name of the option.
  StringRef argName;
  for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
    // Check for the end of the full option.
    // A name with no '=' produces an empty value. A leading space produces an
    // empty name, which callers skip.
    if (argEndIt == optionsE || options[argEndIt] == ' ') {
      argName = extractArgAndUpdateOptions(argEndIt);
      return std::make_tuple(argName, StringRef(), options);
    }

    // Check for the end of the name and the start of the value.
    if (options[argEndIt] == '=') {
      argName = extractArgAndUpdateOptions(argEndIt);
      // `options` now starts at the '='; drop it so the value parse begins
      // at the first value character.
      options = options.drop_front();
      break;
    }
  }

  // Parse the value of the option.
  for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
    // Handle the end of the options string.
    if (argEndIt == optionsE || options[argEndIt] == ' ') {
      StringRef value = extractArgAndUpdateOptions(argEndIt);
      return std::make_tuple(argName, value, options);
    }

    // Skip over escaped sequences.
    char c = options[argEndIt];
    if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
      continue;
    // '{...}' is used to specify options to passes, properly escape it so
    // that we don't accidentally split any nested options.
    if (c == '{') {
      size_t braceCount = 1;
      for (++argEndIt; argEndIt != optionsE; ++argEndIt) {
        // Allow nested punctuation.
        if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
          continue;
        if (options[argEndIt] == '{')
          ++braceCount;
        else if (options[argEndIt] == '}' && --braceCount == 0)
          break;
      }
      // Account for the increment at the top of the loop.
      // The outer loop then revisits the closing '}', or reaches `optionsE`
      // if the brace was never closed.
      --argEndIt;
    }
  }
  llvm_unreachable("unexpected control flow in pass option parsing");
}
269 
parseFromString(StringRef options)270 LogicalResult detail::PassOptions::parseFromString(StringRef options) {
271   // NOTE: `options` is modified in place to always refer to the unprocessed
272   // part of the string.
273   while (!options.empty()) {
274     StringRef key, value;
275     std::tie(key, value, options) = parseNextArg(options);
276     if (key.empty())
277       continue;
278 
279     auto it = OptionsMap.find(key);
280     if (it == OptionsMap.end()) {
281       llvm::errs() << "<Pass-Options-Parser>: no such option " << key << "\n";
282       return failure();
283     }
284     if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
285       return failure();
286   }
287 
288   return success();
289 }
290 
291 /// Print the options held by this struct in a form that can be parsed via
292 /// 'parseFromString'.
print(raw_ostream & os)293 void detail::PassOptions::print(raw_ostream &os) {
294   // If there are no options, there is nothing left to do.
295   if (OptionsMap.empty())
296     return;
297 
298   // Sort the options to make the ordering deterministic.
299   SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
300   auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
301     return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
302   };
303   llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
304 
305   // Interleave the options with ' '.
306   os << '{';
307   llvm::interleave(
308       orderedOps, os, [&](OptionBase *option) { option->print(os); }, " ");
309   os << '}';
310 }
311 
312 /// Print the help string for the options held by this struct. `descIndent` is
313 /// the indent within the stream that the descriptions should be aligned.
printHelp(size_t indent,size_t descIndent) const314 void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const {
315   // Sort the options to make the ordering deterministic.
316   SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
317   auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
318     return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
319   };
320   llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
321   for (OptionBase *option : orderedOps) {
322     // TODO: printOptionInfo assumes a specific indent and will
323     // print options with values with incorrect indentation. We should add
324     // support to llvm::cl::Option for passing in a base indent to use when
325     // printing.
326     llvm::outs().indent(indent);
327     option->getOption()->printOptionInfo(descIndent - indent);
328   }
329 }
330 
331 /// Return the maximum width required when printing the help string.
getOptionWidth() const332 size_t detail::PassOptions::getOptionWidth() const {
333   size_t max = 0;
334   for (auto *option : options)
335     max = std::max(max, option->getOption()->getOptionWidth());
336   return max;
337 }
338 
339 //===----------------------------------------------------------------------===//
340 // MLIR Options
341 //===----------------------------------------------------------------------===//
342 
343 //===----------------------------------------------------------------------===//
344 // OpPassManager: OptionValue
345 
346 llvm::cl::OptionValue<OpPassManager>::OptionValue() = default;
OptionValue(const mlir::OpPassManager & value)347 llvm::cl::OptionValue<OpPassManager>::OptionValue(
348     const mlir::OpPassManager &value) {
349   setValue(value);
350 }
351 llvm::cl::OptionValue<OpPassManager> &
operator =(const mlir::OpPassManager & rhs)352 llvm::cl::OptionValue<OpPassManager>::operator=(
353     const mlir::OpPassManager &rhs) {
354   setValue(rhs);
355   return *this;
356 }
357 
358 llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() = default;
359 
setValue(const OpPassManager & newValue)360 void llvm::cl::OptionValue<OpPassManager>::setValue(
361     const OpPassManager &newValue) {
362   if (hasValue())
363     *value = newValue;
364   else
365     value = std::make_unique<mlir::OpPassManager>(newValue);
366 }
setValue(StringRef pipelineStr)367 void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {
368   FailureOr<OpPassManager> pipeline = parsePassPipeline(pipelineStr);
369   assert(succeeded(pipeline) && "invalid pass pipeline");
370   setValue(*pipeline);
371 }
372 
compare(const mlir::OpPassManager & rhs) const373 bool llvm::cl::OptionValue<OpPassManager>::compare(
374     const mlir::OpPassManager &rhs) const {
375   std::string lhsStr, rhsStr;
376   {
377     raw_string_ostream lhsStream(lhsStr);
378     value->printAsTextualPipeline(lhsStream);
379 
380     raw_string_ostream rhsStream(rhsStr);
381     rhs.printAsTextualPipeline(rhsStream);
382   }
383 
384   // Use the textual format for pipeline comparisons.
385   return lhsStr == rhsStr;
386 }
387 
/// Out-of-line virtual method anchoring this specialization in this file.
void llvm::cl::OptionValue<OpPassManager>::anchor() {}
389 
390 //===----------------------------------------------------------------------===//
391 // OpPassManager: Parser
392 
// Explicitly instantiate `basic_parser<OpPassManager>` in this translation
// unit.
namespace llvm {
namespace cl {
template class basic_parser<OpPassManager>;
} // namespace cl
} // namespace llvm
398 
parse(Option &,StringRef,StringRef arg,ParsedPassManager & value)399 bool llvm::cl::parser<OpPassManager>::parse(Option &, StringRef, StringRef arg,
400                                             ParsedPassManager &value) {
401   FailureOr<OpPassManager> pipeline = parsePassPipeline(arg);
402   if (failed(pipeline))
403     return true;
404   value.value = std::make_unique<OpPassManager>(std::move(*pipeline));
405   return false;
406 }
407 
/// Print `value` to `os` as a textual pass pipeline.
void llvm::cl::parser<OpPassManager>::print(raw_ostream &os,
                                            const OpPassManager &value) {
  value.printAsTextualPipeline(os);
}
412 
printOptionDiff(const Option & opt,OpPassManager & pm,const OptVal & defaultValue,size_t globalWidth) const413 void llvm::cl::parser<OpPassManager>::printOptionDiff(
414     const Option &opt, OpPassManager &pm, const OptVal &defaultValue,
415     size_t globalWidth) const {
416   printOptionName(opt, globalWidth);
417   outs() << "= ";
418   pm.printAsTextualPipeline(outs());
419 
420   if (defaultValue.hasValue()) {
421     outs().indent(2) << " (default: ";
422     defaultValue.getValue().printAsTextualPipeline(outs());
423     outs() << ")";
424   }
425   outs() << "\n";
426 }
427 
/// Out-of-line virtual method anchoring this parser specialization in this
/// file.
void llvm::cl::parser<OpPassManager>::anchor() {}
429 
// Out-of-line special members of ParsedPassManager, defaulted here where
// OpPassManager is a complete type. NOTE(review): presumably the header only
// forward-declares OpPassManager; confirm before moving these inline.
llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager() =
    default;
llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager(
    ParsedPassManager &&) = default;
llvm::cl::parser<OpPassManager>::ParsedPassManager::~ParsedPassManager() =
    default;
436 
437 //===----------------------------------------------------------------------===//
438 // TextualPassPipeline Parser
439 //===----------------------------------------------------------------------===//
440 
namespace {
/// This class represents a textual description of a pass pipeline.
/// Element names and options are StringRefs into the parsed text, so that text
/// must outlive this object.
class TextualPipeline {
public:
  /// Try to initialize this pipeline with the given pipeline text.
  /// `errorStream` is the output stream to emit errors to.
  LogicalResult initialize(StringRef text, raw_ostream &errorStream);

  /// Add the internal pipeline elements to the provided pass manager.
  LogicalResult
  addToPipeline(OpPassManager &pm,
                function_ref<LogicalResult(const Twine &)> errorHandler) const;

private:
  /// A functor used to emit errors found during pipeline handling. The first
  /// parameter corresponds to the raw location within the pipeline string. This
  /// should always return failure.
  using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>;

  /// A struct to capture parsed pass pipeline names.
  ///
  /// A pipeline is defined as a series of names, each of which may in itself
  /// recursively contain a nested pipeline. A name is either the name of a pass
  /// (e.g. "cse") or the name of an operation type (e.g. "builtin.module"). If
  /// the name is the name of a pass, the innerPipeline is empty, since passes
  /// cannot contain inner pipelines.
  struct PipelineElement {
    PipelineElement(StringRef name) : name(name) {}

    /// Name of the pass, pass pipeline, or operation.
    StringRef name;
    /// Raw text between the `{` and `}` following the name, if any.
    StringRef options;
    /// Resolved registry entry; stays null for operation (nesting) elements.
    const PassRegistryEntry *registryEntry = nullptr;
    /// Nested elements when `name` is an operation name.
    std::vector<PipelineElement> innerPipeline;
  };

  /// Parse the given pipeline text into the internal pipeline vector. This
  /// function only parses the structure of the pipeline, and does not resolve
  /// its elements.
  LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);

  /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
  /// the corresponding registry entry.
  LogicalResult
  resolvePipelineElements(MutableArrayRef<PipelineElement> elements,
                          ErrorHandlerT errorHandler);

  /// Resolve a single element of the pipeline.
  LogicalResult resolvePipelineElement(PipelineElement &element,
                                       ErrorHandlerT errorHandler);

  /// Add the given pipeline elements to the provided pass manager.
  LogicalResult
  addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm,
                function_ref<LogicalResult(const Twine &)> errorHandler) const;

  /// The top-level parsed pipeline.
  std::vector<PipelineElement> pipeline;
};

} // namespace
500 
501 /// Try to initialize this pipeline with the given pipeline text. An option is
502 /// given to enable accurate error reporting.
initialize(StringRef text,raw_ostream & errorStream)503 LogicalResult TextualPipeline::initialize(StringRef text,
504                                           raw_ostream &errorStream) {
505   if (text.empty())
506     return success();
507 
508   // Build a source manager to use for error reporting.
509   llvm::SourceMgr pipelineMgr;
510   pipelineMgr.AddNewSourceBuffer(
511       llvm::MemoryBuffer::getMemBuffer(text, "MLIR Textual PassPipeline Parser",
512                                        /*RequiresNullTerminator=*/false),
513       SMLoc());
514   auto errorHandler = [&](const char *rawLoc, Twine msg) {
515     pipelineMgr.PrintMessage(errorStream, SMLoc::getFromPointer(rawLoc),
516                              llvm::SourceMgr::DK_Error, msg);
517     return failure();
518   };
519 
520   // Parse the provided pipeline string.
521   if (failed(parsePipelineText(text, errorHandler)))
522     return failure();
523   return resolvePipelineElements(pipeline, errorHandler);
524 }
525 
/// Add the internal pipeline elements to the provided pass manager.
/// Forwards the top-level parsed pipeline to the element-list overload.
LogicalResult TextualPipeline::addToPipeline(
    OpPassManager &pm,
    function_ref<LogicalResult(const Twine &)> errorHandler) const {
  return addToPipeline(pipeline, pm, errorHandler);
}
532 
533 /// Parse the given pipeline text into the internal pipeline vector. This
534 /// function only parses the structure of the pipeline, and does not resolve
535 /// its elements.
parsePipelineText(StringRef text,ErrorHandlerT errorHandler)536 LogicalResult TextualPipeline::parsePipelineText(StringRef text,
537                                                  ErrorHandlerT errorHandler) {
538   SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline};
539   for (;;) {
540     std::vector<PipelineElement> &pipeline = *pipelineStack.back();
541     size_t pos = text.find_first_of(",(){");
542     pipeline.emplace_back(/*name=*/text.substr(0, pos).trim());
543 
544     // If we have a single terminating name, we're done.
545     if (pos == StringRef::npos)
546       break;
547 
548     text = text.substr(pos);
549     char sep = text[0];
550 
551     // Handle pulling ... from 'pass{...}' out as PipelineElement.options.
552     if (sep == '{') {
553       text = text.substr(1);
554 
555       // Skip over everything until the closing '}' and store as options.
556       size_t close = StringRef::npos;
557       for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
558         if (text[i] == '{') {
559           ++braceCount;
560           continue;
561         }
562         if (text[i] == '}' && --braceCount == 0) {
563           close = i;
564           break;
565         }
566       }
567 
568       // Check to see if a closing options brace was found.
569       if (close == StringRef::npos) {
570         return errorHandler(
571             /*rawLoc=*/text.data() - 1,
572             "missing closing '}' while processing pass options");
573       }
574       pipeline.back().options = text.substr(0, close);
575       text = text.substr(close + 1);
576 
577       // Skip checking for '(' because nested pipelines cannot have options.
578     } else if (sep == '(') {
579       text = text.substr(1);
580 
581       // Push the inner pipeline onto the stack to continue processing.
582       pipelineStack.push_back(&pipeline.back().innerPipeline);
583       continue;
584     }
585 
586     // When handling the close parenthesis, we greedily consume them to avoid
587     // empty strings in the pipeline.
588     while (text.consume_front(")")) {
589       // If we try to pop the outer pipeline we have unbalanced parentheses.
590       if (pipelineStack.size() == 1)
591         return errorHandler(/*rawLoc=*/text.data() - 1,
592                             "encountered extra closing ')' creating unbalanced "
593                             "parentheses while parsing pipeline");
594 
595       pipelineStack.pop_back();
596     }
597 
598     // Check if we've finished parsing.
599     if (text.empty())
600       break;
601 
602     // Otherwise, the end of an inner pipeline always has to be followed by
603     // a comma, and then we can continue.
604     if (!text.consume_front(","))
605       return errorHandler(text.data(), "expected ',' after parsing pipeline");
606   }
607 
608   // Check for unbalanced parentheses.
609   if (pipelineStack.size() > 1)
610     return errorHandler(
611         text.data(),
612         "encountered unbalanced parentheses while parsing pipeline");
613 
614   assert(pipelineStack.back() == &pipeline &&
615          "wrong pipeline at the bottom of the stack");
616   return success();
617 }
618 
619 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
620 /// the corresponding registry entry.
resolvePipelineElements(MutableArrayRef<PipelineElement> elements,ErrorHandlerT errorHandler)621 LogicalResult TextualPipeline::resolvePipelineElements(
622     MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) {
623   for (auto &elt : elements)
624     if (failed(resolvePipelineElement(elt, errorHandler)))
625       return failure();
626   return success();
627 }
628 
629 /// Resolve a single element of the pipeline.
630 LogicalResult
resolvePipelineElement(PipelineElement & element,ErrorHandlerT errorHandler)631 TextualPipeline::resolvePipelineElement(PipelineElement &element,
632                                         ErrorHandlerT errorHandler) {
633   // If the inner pipeline of this element is not empty, this is an operation
634   // pipeline.
635   if (!element.innerPipeline.empty())
636     return resolvePipelineElements(element.innerPipeline, errorHandler);
637   // Otherwise, this must be a pass or pass pipeline.
638   // Check to see if a pipeline was registered with this name.
639   auto pipelineRegistryIt = passPipelineRegistry->find(element.name);
640   if (pipelineRegistryIt != passPipelineRegistry->end()) {
641     element.registryEntry = &pipelineRegistryIt->second;
642     return success();
643   }
644 
645   // If not, then this must be a specific pass name.
646   if ((element.registryEntry = Pass::lookupPassInfo(element.name)))
647     return success();
648 
649   // Emit an error for the unknown pass.
650   auto *rawLoc = element.name.data();
651   return errorHandler(rawLoc, "'" + element.name +
652                                   "' does not refer to a "
653                                   "registered pass or pass pipeline");
654 }
655 
656 /// Add the given pipeline elements to the provided pass manager.
addToPipeline(ArrayRef<PipelineElement> elements,OpPassManager & pm,function_ref<LogicalResult (const Twine &)> errorHandler) const657 LogicalResult TextualPipeline::addToPipeline(
658     ArrayRef<PipelineElement> elements, OpPassManager &pm,
659     function_ref<LogicalResult(const Twine &)> errorHandler) const {
660   for (auto &elt : elements) {
661     if (elt.registryEntry) {
662       if (failed(elt.registryEntry->addToPipeline(pm, elt.options,
663                                                   errorHandler))) {
664         return errorHandler("failed to add `" + elt.name + "` with options `" +
665                             elt.options + "`");
666       }
667     } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name),
668                                     errorHandler))) {
669       return errorHandler("failed to add `" + elt.name + "` with options `" +
670                           elt.options + "` to inner pipeline");
671     }
672   }
673   return success();
674 }
675 
parsePassPipeline(StringRef pipeline,OpPassManager & pm,raw_ostream & errorStream)676 LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm,
677                                       raw_ostream &errorStream) {
678   TextualPipeline pipelineParser;
679   if (failed(pipelineParser.initialize(pipeline, errorStream)))
680     return failure();
681   auto errorHandler = [&](Twine msg) {
682     errorStream << msg << "\n";
683     return failure();
684   };
685   if (failed(pipelineParser.addToPipeline(pm, errorHandler)))
686     return failure();
687   return success();
688 }
689 
parsePassPipeline(StringRef pipeline,raw_ostream & errorStream)690 FailureOr<OpPassManager> mlir::parsePassPipeline(StringRef pipeline,
691                                                  raw_ostream &errorStream) {
692   // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
693   size_t pipelineStart = pipeline.find_first_of('(');
694   if (pipelineStart == 0 || pipelineStart == StringRef::npos ||
695       !pipeline.consume_back(")")) {
696     errorStream << "expected pass pipeline to be wrapped with the anchor "
697                    "operation type, e.g. `builtin.module(...)";
698     return failure();
699   }
700 
701   StringRef opName = pipeline.take_front(pipelineStart);
702   OpPassManager pm(opName);
703   if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm)))
704     return failure();
705   return pm;
706 }
707 
708 //===----------------------------------------------------------------------===//
709 // PassNameParser
710 //===----------------------------------------------------------------------===//
711 
namespace {
/// This struct represents the possible data entries in a parsed pass pipeline
/// list.
struct PassArgData {
  PassArgData() = default;
  // NOTE(review): intentionally non-explicit? Callers beyond this view may
  // rely on the implicit conversion from a registry entry; confirm before
  // marking it `explicit`.
  PassArgData(const PassRegistryEntry *registryEntry)
      : registryEntry(registryEntry) {}

  /// This field is used when the parsed option corresponds to a registered pass
  /// or pass pipeline.
  const PassRegistryEntry *registryEntry{nullptr};

  /// This field is set when instance specific pass options have been provided
  /// on the command line.
  StringRef options;

  /// This field is used when the parsed option corresponds to an explicit
  /// pipeline.
  TextualPipeline pipeline;
};
} // namespace
733 
734 namespace llvm {
735 namespace cl {
736 /// Define a valid OptionValue for the command line pass argument.
737 template <>
738 struct OptionValue<PassArgData> final
739     : OptionValueBase<PassArgData, /*isClass=*/true> {
OptionValuellvm::cl::OptionValue740   OptionValue(const PassArgData &value) { this->setValue(value); }
741   OptionValue() = default;
anchorllvm::cl::OptionValue742   void anchor() override {}
743 
hasValuellvm::cl::OptionValue744   bool hasValue() const { return true; }
getValuellvm::cl::OptionValue745   const PassArgData &getValue() const { return value; }
setValuellvm::cl::OptionValue746   void setValue(const PassArgData &value) { this->value = value; }
747 
748   PassArgData value;
749 };
750 } // namespace cl
751 } // namespace llvm
752 
753 namespace {
754 
755 /// The name for the command line option used for parsing the textual pass
756 /// pipeline.
757 static constexpr StringLiteral passPipelineArg = "pass-pipeline";
758 
759 /// Adds command line option for each registered pass or pass pipeline, as well
760 /// as textual pass pipelines.
761 struct PassNameParser : public llvm::cl::parser<PassArgData> {
PassNameParser__anon67970f8f0f11::PassNameParser762   PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {}
763 
764   void initialize();
765   void printOptionInfo(const llvm::cl::Option &opt,
766                        size_t globalWidth) const override;
767   size_t getOptionWidth(const llvm::cl::Option &opt) const override;
768   bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
769              PassArgData &value);
770 
771   /// If true, this parser only parses entries that correspond to a concrete
772   /// pass registry entry, and does not add a `pass-pipeline` argument, does not
773   /// include the options for pass entries, and does not include pass pipelines
774   /// entries.
775   bool passNamesOnly = false;
776 };
777 } // namespace
778 
initialize()779 void PassNameParser::initialize() {
780   llvm::cl::parser<PassArgData>::initialize();
781 
782   /// Add an entry for the textual pass pipeline option.
783   if (!passNamesOnly) {
784     addLiteralOption(passPipelineArg, PassArgData(),
785                      "A textual description of a pass pipeline to run");
786   }
787 
788   /// Add the pass entries.
789   for (const auto &kv : *passRegistry) {
790     addLiteralOption(kv.second.getPassArgument(), &kv.second,
791                      kv.second.getPassDescription());
792   }
793   /// Add the pass pipeline entries.
794   if (!passNamesOnly) {
795     for (const auto &kv : *passPipelineRegistry) {
796       addLiteralOption(kv.second.getPassArgument(), &kv.second,
797                        kv.second.getPassDescription());
798     }
799   }
800 }
801 
printOptionInfo(const llvm::cl::Option & opt,size_t globalWidth) const802 void PassNameParser::printOptionInfo(const llvm::cl::Option &opt,
803                                      size_t globalWidth) const {
804   // If this parser is just parsing pass names, print a simplified option
805   // string.
806   if (passNamesOnly) {
807     llvm::outs() << "  --" << opt.ArgStr << "=<pass-arg>";
808     opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18);
809     return;
810   }
811 
812   // Print the information for the top-level option.
813   if (opt.hasArgStr()) {
814     llvm::outs() << "  --" << opt.ArgStr;
815     opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
816   } else {
817     llvm::outs() << "  " << opt.HelpStr << '\n';
818   }
819 
820   // Print the top-level pipeline argument.
821   printOptionHelp(passPipelineArg,
822                   "A textual description of a pass pipeline to run",
823                   /*indent=*/4, globalWidth, /*isTopLevel=*/!opt.hasArgStr());
824 
825   // Functor used to print the ordered entries of a registration map.
826   auto printOrderedEntries = [&](StringRef header, auto &map) {
827     llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries;
828     for (auto &kv : map)
829       orderedEntries.push_back(&kv.second);
830     llvm::array_pod_sort(
831         orderedEntries.begin(), orderedEntries.end(),
832         [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
833           return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
834         });
835 
836     llvm::outs().indent(4) << header << ":\n";
837     for (PassRegistryEntry *entry : orderedEntries)
838       entry->printHelpStr(/*indent=*/6, globalWidth);
839   };
840 
841   // Print the available passes.
842   printOrderedEntries("Passes", *passRegistry);
843 
844   // Print the available pass pipelines.
845   if (!passPipelineRegistry->empty())
846     printOrderedEntries("Pass Pipelines", *passPipelineRegistry);
847 }
848 
getOptionWidth(const llvm::cl::Option & opt) const849 size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const {
850   size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(opt) + 2;
851 
852   // Check for any wider pass or pipeline options.
853   for (auto &entry : *passRegistry)
854     maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
855   for (auto &entry : *passPipelineRegistry)
856     maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
857   return maxWidth;
858 }
859 
parse(llvm::cl::Option & opt,StringRef argName,StringRef arg,PassArgData & value)860 bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,
861                            StringRef arg, PassArgData &value) {
862   // Handle the pipeline option explicitly.
863   if (argName == passPipelineArg)
864     return failed(value.pipeline.initialize(arg, llvm::errs()));
865 
866   // Otherwise, default to the base for handling.
867   if (llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value))
868     return true;
869   value.options = arg;
870   return false;
871 }
872 
873 //===----------------------------------------------------------------------===//
874 // PassPipelineCLParser
875 //===----------------------------------------------------------------------===//
876 
877 namespace mlir {
878 namespace detail {
879 struct PassPipelineCLParserImpl {
PassPipelineCLParserImplmlir::detail::PassPipelineCLParserImpl880   PassPipelineCLParserImpl(StringRef arg, StringRef description,
881                            bool passNamesOnly)
882       : passList(arg, llvm::cl::desc(description)) {
883     passList.getParser().passNamesOnly = passNamesOnly;
884     passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
885   }
886 
887   /// Returns true if the given pass registry entry was registered at the
888   /// top-level of the parser, i.e. not within an explicit textual pipeline.
containsmlir::detail::PassPipelineCLParserImpl889   bool contains(const PassRegistryEntry *entry) const {
890     return llvm::any_of(passList, [&](const PassArgData &data) {
891       return data.registryEntry == entry;
892     });
893   }
894 
895   /// The set of passes and pass pipelines to run.
896   llvm::cl::list<PassArgData, bool, PassNameParser> passList;
897 };
898 } // namespace detail
899 } // namespace mlir
900 
901 /// Construct a pass pipeline parser with the given command line description.
PassPipelineCLParser(StringRef arg,StringRef description)902 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description)
903     : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
904           arg, description, /*passNamesOnly=*/false)) {}
905 PassPipelineCLParser::~PassPipelineCLParser() = default;
906 
907 /// Returns true if this parser contains any valid options to add.
hasAnyOccurrences() const908 bool PassPipelineCLParser::hasAnyOccurrences() const {
909   return impl->passList.getNumOccurrences() != 0;
910 }
911 
912 /// Returns true if the given pass registry entry was registered at the
913 /// top-level of the parser, i.e. not within an explicit textual pipeline.
contains(const PassRegistryEntry * entry) const914 bool PassPipelineCLParser::contains(const PassRegistryEntry *entry) const {
915   return impl->contains(entry);
916 }
917 
918 /// Adds the passes defined by this parser entry to the given pass manager.
addToPipeline(OpPassManager & pm,function_ref<LogicalResult (const Twine &)> errorHandler) const919 LogicalResult PassPipelineCLParser::addToPipeline(
920     OpPassManager &pm,
921     function_ref<LogicalResult(const Twine &)> errorHandler) const {
922   for (auto &passIt : impl->passList) {
923     if (passIt.registryEntry) {
924       if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
925                                                      errorHandler)))
926         return failure();
927     } else {
928       OpPassManager::Nesting nesting = pm.getNesting();
929       pm.setNesting(OpPassManager::Nesting::Explicit);
930       LogicalResult status = passIt.pipeline.addToPipeline(pm, errorHandler);
931       pm.setNesting(nesting);
932       if (failed(status))
933         return failure();
934     }
935   }
936   return success();
937 }
938 
//===----------------------------------------------------------------------===//
// PassNameCLParser
//===----------------------------------------------------------------------===//
941 
942 /// Construct a pass pipeline parser with the given command line description.
PassNameCLParser(StringRef arg,StringRef description)943 PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description)
944     : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
945           arg, description, /*passNamesOnly=*/true)) {
946   impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
947 }
948 PassNameCLParser::~PassNameCLParser() = default;
949 
950 /// Returns true if this parser contains any valid options to add.
hasAnyOccurrences() const951 bool PassNameCLParser::hasAnyOccurrences() const {
952   return impl->passList.getNumOccurrences() != 0;
953 }
954 
955 /// Returns true if the given pass registry entry was registered at the
956 /// top-level of the parser, i.e. not within an explicit textual pipeline.
contains(const PassRegistryEntry * entry) const957 bool PassNameCLParser::contains(const PassRegistryEntry *entry) const {
958   return impl->contains(entry);
959 }
960