1 //===- PassRegistry.cpp - Pass Registration Utilities ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include <utility>
10
11 #include "mlir/Pass/Pass.h"
12 #include "mlir/Pass/PassManager.h"
13 #include "mlir/Pass/PassRegistry.h"
14 #include "llvm/ADT/DenseMap.h"
15 #include "llvm/Support/Format.h"
16 #include "llvm/Support/ManagedStatic.h"
17 #include "llvm/Support/MemoryBuffer.h"
18 #include "llvm/Support/SourceMgr.h"
19
20 using namespace mlir;
21 using namespace detail;
22
/// Static mapping of all of the registered passes, keyed by the pass argument
/// (the name used to refer to a pass in textual pipelines and as a literal
/// command-line option).
static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;

/// A mapping of the above pass registry entries to the corresponding TypeID
/// of the pass that they generate. `registerPass` uses it to reject two
/// different pass types being registered under the same argument.
static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;

/// Static mapping of all of the registered pass pipelines, keyed by the
/// pipeline argument.
static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
    passPipelineRegistry;
33
34 /// Utility to create a default registry function from a pass instance.
35 static PassRegistryFunction
buildDefaultRegistryFn(const PassAllocatorFunction & allocator)36 buildDefaultRegistryFn(const PassAllocatorFunction &allocator) {
37 return [=](OpPassManager &pm, StringRef options,
38 function_ref<LogicalResult(const Twine &)> errorHandler) {
39 std::unique_ptr<Pass> pass = allocator();
40 LogicalResult result = pass->initializeOptions(options);
41
42 Optional<StringRef> pmOpName = pm.getOpName();
43 Optional<StringRef> passOpName = pass->getOpName();
44 if ((pm.getNesting() == OpPassManager::Nesting::Explicit) && pmOpName &&
45 passOpName && *pmOpName != *passOpName) {
46 return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() +
47 "' restricted to '" + *pass->getOpName() +
48 "' on a PassManager intended to run on '" +
49 pm.getOpAnchorName() + "', did you intend to nest?");
50 }
51 pm.addPass(std::move(pass));
52 return result;
53 };
54 }
55
56 /// Utility to print the help string for a specific option.
printOptionHelp(StringRef arg,StringRef desc,size_t indent,size_t descIndent,bool isTopLevel)57 static void printOptionHelp(StringRef arg, StringRef desc, size_t indent,
58 size_t descIndent, bool isTopLevel) {
59 size_t numSpaces = descIndent - indent - 4;
60 llvm::outs().indent(indent)
61 << "--" << llvm::left_justify(arg, numSpaces) << "- " << desc << '\n';
62 }
63
64 //===----------------------------------------------------------------------===//
65 // PassRegistry
66 //===----------------------------------------------------------------------===//
67
68 /// Print the help information for this pass. This includes the argument,
69 /// description, and any pass options. `descIndent` is the indent that the
70 /// descriptions should be aligned.
printHelpStr(size_t indent,size_t descIndent) const71 void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const {
72 printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
73 /*isTopLevel=*/true);
74 // If this entry has options, print the help for those as well.
75 optHandler([=](const PassOptions &options) {
76 options.printHelp(indent, descIndent);
77 });
78 }
79
80 /// Return the maximum width required when printing the options of this
81 /// entry.
getOptionWidth() const82 size_t PassRegistryEntry::getOptionWidth() const {
83 size_t maxLen = 0;
84 optHandler([&](const PassOptions &options) mutable {
85 maxLen = options.getOptionWidth() + 2;
86 });
87 return maxLen;
88 }
89
90 //===----------------------------------------------------------------------===//
91 // PassPipelineInfo
92 //===----------------------------------------------------------------------===//
93
registerPassPipeline(StringRef arg,StringRef description,const PassRegistryFunction & function,std::function<void (function_ref<void (const PassOptions &)>)> optHandler)94 void mlir::registerPassPipeline(
95 StringRef arg, StringRef description, const PassRegistryFunction &function,
96 std::function<void(function_ref<void(const PassOptions &)>)> optHandler) {
97 PassPipelineInfo pipelineInfo(arg, description, function,
98 std::move(optHandler));
99 bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second;
100 assert(inserted && "Pass pipeline registered multiple times");
101 (void)inserted;
102 }
103
104 //===----------------------------------------------------------------------===//
105 // PassInfo
106 //===----------------------------------------------------------------------===//
107
PassInfo(StringRef arg,StringRef description,const PassAllocatorFunction & allocator)108 PassInfo::PassInfo(StringRef arg, StringRef description,
109 const PassAllocatorFunction &allocator)
110 : PassRegistryEntry(
111 arg, description, buildDefaultRegistryFn(allocator),
112 // Use a temporary pass to provide an options instance.
113 [=](function_ref<void(const PassOptions &)> optHandler) {
114 optHandler(allocator()->passOptions);
115 }) {}
116
registerPass(const PassAllocatorFunction & function)117 void mlir::registerPass(const PassAllocatorFunction &function) {
118 std::unique_ptr<Pass> pass = function();
119 StringRef arg = pass->getArgument();
120 if (arg.empty())
121 llvm::report_fatal_error(llvm::Twine("Trying to register '") +
122 pass->getName() +
123 "' pass that does not override `getArgument()`");
124 StringRef description = pass->getDescription();
125 PassInfo passInfo(arg, description, function);
126 passRegistry->try_emplace(arg, passInfo);
127
128 // Verify that the registered pass has the same ID as any registered to this
129 // arg before it.
130 TypeID entryTypeID = pass->getTypeID();
131 auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first;
132 if (it->second != entryTypeID)
133 llvm::report_fatal_error(
134 "pass allocator creates a different pass than previously "
135 "registered for pass " +
136 arg);
137 }
138
139 /// Returns the pass info for the specified pass argument or null if unknown.
lookupPassInfo(StringRef passArg)140 const PassInfo *mlir::Pass::lookupPassInfo(StringRef passArg) {
141 auto it = passRegistry->find(passArg);
142 return it == passRegistry->end() ? nullptr : &it->second;
143 }
144
145 //===----------------------------------------------------------------------===//
146 // PassOptions
147 //===----------------------------------------------------------------------===//
148
parseCommaSeparatedList(llvm::cl::Option & opt,StringRef argName,StringRef optionStr,function_ref<LogicalResult (StringRef)> elementParseFn)149 LogicalResult detail::pass_options::parseCommaSeparatedList(
150 llvm::cl::Option &opt, StringRef argName, StringRef optionStr,
151 function_ref<LogicalResult(StringRef)> elementParseFn) {
152 // Functor used for finding a character in a string, and skipping over
153 // various "range" characters.
154 llvm::unique_function<size_t(StringRef, size_t, char)> findChar =
155 [&](StringRef str, size_t index, char c) -> size_t {
156 for (size_t i = index, e = str.size(); i < e; ++i) {
157 if (str[i] == c)
158 return i;
159 // Check for various range characters.
160 if (str[i] == '{')
161 i = findChar(str, i + 1, '}');
162 else if (str[i] == '(')
163 i = findChar(str, i + 1, ')');
164 else if (str[i] == '[')
165 i = findChar(str, i + 1, ']');
166 else if (str[i] == '\"')
167 i = str.find_first_of('\"', i + 1);
168 else if (str[i] == '\'')
169 i = str.find_first_of('\'', i + 1);
170 }
171 return StringRef::npos;
172 };
173
174 size_t nextElePos = findChar(optionStr, 0, ',');
175 while (nextElePos != StringRef::npos) {
176 // Process the portion before the comma.
177 if (failed(elementParseFn(optionStr.substr(0, nextElePos))))
178 return failure();
179
180 optionStr = optionStr.substr(nextElePos + 1);
181 nextElePos = findChar(optionStr, 0, ',');
182 }
183 return elementParseFn(optionStr.substr(0, nextElePos));
184 }
185
186 /// Out of line virtual function to provide home for the class.
anchor()187 void detail::PassOptions::OptionBase::anchor() {}
188
189 /// Copy the option values from 'other'.
copyOptionValuesFrom(const PassOptions & other)190 void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) {
191 assert(options.size() == other.options.size());
192 if (options.empty())
193 return;
194 for (auto optionsIt : llvm::zip(options, other.options))
195 std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
196 }
197
198 /// Parse in the next argument from the given options string. Returns a tuple
199 /// containing [the key of the option, the value of the option, updated
200 /// `options` string pointing after the parsed option].
201 static std::tuple<StringRef, StringRef, StringRef>
parseNextArg(StringRef options)202 parseNextArg(StringRef options) {
203 // Functor used to extract an argument from 'options' and update it to point
204 // after the arg.
205 auto extractArgAndUpdateOptions = [&](size_t argSize) {
206 StringRef str = options.take_front(argSize).trim();
207 options = options.drop_front(argSize).ltrim();
208 return str;
209 };
210 // Try to process the given punctuation, properly escaping any contained
211 // characters.
212 auto tryProcessPunct = [&](size_t ¤tPos, char punct) {
213 if (options[currentPos] != punct)
214 return false;
215 size_t nextIt = options.find_first_of(punct, currentPos + 1);
216 if (nextIt != StringRef::npos)
217 currentPos = nextIt;
218 return true;
219 };
220
221 // Parse the argument name of the option.
222 StringRef argName;
223 for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
224 // Check for the end of the full option.
225 if (argEndIt == optionsE || options[argEndIt] == ' ') {
226 argName = extractArgAndUpdateOptions(argEndIt);
227 return std::make_tuple(argName, StringRef(), options);
228 }
229
230 // Check for the end of the name and the start of the value.
231 if (options[argEndIt] == '=') {
232 argName = extractArgAndUpdateOptions(argEndIt);
233 options = options.drop_front();
234 break;
235 }
236 }
237
238 // Parse the value of the option.
239 for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
240 // Handle the end of the options string.
241 if (argEndIt == optionsE || options[argEndIt] == ' ') {
242 StringRef value = extractArgAndUpdateOptions(argEndIt);
243 return std::make_tuple(argName, value, options);
244 }
245
246 // Skip over escaped sequences.
247 char c = options[argEndIt];
248 if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
249 continue;
250 // '{...}' is used to specify options to passes, properly escape it so
251 // that we don't accidentally split any nested options.
252 if (c == '{') {
253 size_t braceCount = 1;
254 for (++argEndIt; argEndIt != optionsE; ++argEndIt) {
255 // Allow nested punctuation.
256 if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
257 continue;
258 if (options[argEndIt] == '{')
259 ++braceCount;
260 else if (options[argEndIt] == '}' && --braceCount == 0)
261 break;
262 }
263 // Account for the increment at the top of the loop.
264 --argEndIt;
265 }
266 }
267 llvm_unreachable("unexpected control flow in pass option parsing");
268 }
269
parseFromString(StringRef options)270 LogicalResult detail::PassOptions::parseFromString(StringRef options) {
271 // NOTE: `options` is modified in place to always refer to the unprocessed
272 // part of the string.
273 while (!options.empty()) {
274 StringRef key, value;
275 std::tie(key, value, options) = parseNextArg(options);
276 if (key.empty())
277 continue;
278
279 auto it = OptionsMap.find(key);
280 if (it == OptionsMap.end()) {
281 llvm::errs() << "<Pass-Options-Parser>: no such option " << key << "\n";
282 return failure();
283 }
284 if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
285 return failure();
286 }
287
288 return success();
289 }
290
291 /// Print the options held by this struct in a form that can be parsed via
292 /// 'parseFromString'.
print(raw_ostream & os)293 void detail::PassOptions::print(raw_ostream &os) {
294 // If there are no options, there is nothing left to do.
295 if (OptionsMap.empty())
296 return;
297
298 // Sort the options to make the ordering deterministic.
299 SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
300 auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
301 return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
302 };
303 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
304
305 // Interleave the options with ' '.
306 os << '{';
307 llvm::interleave(
308 orderedOps, os, [&](OptionBase *option) { option->print(os); }, " ");
309 os << '}';
310 }
311
312 /// Print the help string for the options held by this struct. `descIndent` is
313 /// the indent within the stream that the descriptions should be aligned.
printHelp(size_t indent,size_t descIndent) const314 void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const {
315 // Sort the options to make the ordering deterministic.
316 SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
317 auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
318 return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
319 };
320 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
321 for (OptionBase *option : orderedOps) {
322 // TODO: printOptionInfo assumes a specific indent and will
323 // print options with values with incorrect indentation. We should add
324 // support to llvm::cl::Option for passing in a base indent to use when
325 // printing.
326 llvm::outs().indent(indent);
327 option->getOption()->printOptionInfo(descIndent - indent);
328 }
329 }
330
331 /// Return the maximum width required when printing the help string.
getOptionWidth() const332 size_t detail::PassOptions::getOptionWidth() const {
333 size_t max = 0;
334 for (auto *option : options)
335 max = std::max(max, option->getOption()->getOptionWidth());
336 return max;
337 }
338
339 //===----------------------------------------------------------------------===//
340 // MLIR Options
341 //===----------------------------------------------------------------------===//
342
343 //===----------------------------------------------------------------------===//
344 // OpPassManager: OptionValue
345
346 llvm::cl::OptionValue<OpPassManager>::OptionValue() = default;
OptionValue(const mlir::OpPassManager & value)347 llvm::cl::OptionValue<OpPassManager>::OptionValue(
348 const mlir::OpPassManager &value) {
349 setValue(value);
350 }
351 llvm::cl::OptionValue<OpPassManager> &
operator =(const mlir::OpPassManager & rhs)352 llvm::cl::OptionValue<OpPassManager>::operator=(
353 const mlir::OpPassManager &rhs) {
354 setValue(rhs);
355 return *this;
356 }
357
358 llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() = default;
359
setValue(const OpPassManager & newValue)360 void llvm::cl::OptionValue<OpPassManager>::setValue(
361 const OpPassManager &newValue) {
362 if (hasValue())
363 *value = newValue;
364 else
365 value = std::make_unique<mlir::OpPassManager>(newValue);
366 }
setValue(StringRef pipelineStr)367 void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {
368 FailureOr<OpPassManager> pipeline = parsePassPipeline(pipelineStr);
369 assert(succeeded(pipeline) && "invalid pass pipeline");
370 setValue(*pipeline);
371 }
372
compare(const mlir::OpPassManager & rhs) const373 bool llvm::cl::OptionValue<OpPassManager>::compare(
374 const mlir::OpPassManager &rhs) const {
375 std::string lhsStr, rhsStr;
376 {
377 raw_string_ostream lhsStream(lhsStr);
378 value->printAsTextualPipeline(lhsStream);
379
380 raw_string_ostream rhsStream(rhsStr);
381 rhs.printAsTextualPipeline(rhsStream);
382 }
383
384 // Use the textual format for pipeline comparisons.
385 return lhsStr == rhsStr;
386 }
387
anchor()388 void llvm::cl::OptionValue<OpPassManager>::anchor() {}
389
390 //===----------------------------------------------------------------------===//
391 // OpPassManager: Parser
392
namespace llvm {
namespace cl {
// Explicit instantiation of the basic command-line parser for OpPassManager
// options in this translation unit.
template class basic_parser<OpPassManager>;
} // namespace cl
} // namespace llvm
398
parse(Option &,StringRef,StringRef arg,ParsedPassManager & value)399 bool llvm::cl::parser<OpPassManager>::parse(Option &, StringRef, StringRef arg,
400 ParsedPassManager &value) {
401 FailureOr<OpPassManager> pipeline = parsePassPipeline(arg);
402 if (failed(pipeline))
403 return true;
404 value.value = std::make_unique<OpPassManager>(std::move(*pipeline));
405 return false;
406 }
407
/// Print `value` as a textual pass pipeline.
void llvm::cl::parser<OpPassManager>::print(raw_ostream &os,
                                            const OpPassManager &value) {
  value.printAsTextualPipeline(os);
}
412
printOptionDiff(const Option & opt,OpPassManager & pm,const OptVal & defaultValue,size_t globalWidth) const413 void llvm::cl::parser<OpPassManager>::printOptionDiff(
414 const Option &opt, OpPassManager &pm, const OptVal &defaultValue,
415 size_t globalWidth) const {
416 printOptionName(opt, globalWidth);
417 outs() << "= ";
418 pm.printAsTextualPipeline(outs());
419
420 if (defaultValue.hasValue()) {
421 outs().indent(2) << " (default: ";
422 defaultValue.getValue().printAsTextualPipeline(outs());
423 outs() << ")";
424 }
425 outs() << "\n";
426 }
427
anchor()428 void llvm::cl::parser<OpPassManager>::anchor() {}
429
// Special members of ParsedPassManager, defaulted here rather than in the
// class definition. NOTE(review): presumably this is so that the `value`
// member (assigned from std::make_unique<OpPassManager> in `parse`) does not
// need a complete OpPassManager type where the class is declared. TODO confirm.
llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager() =
    default;
llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager(
    ParsedPassManager &&) = default;
llvm::cl::parser<OpPassManager>::ParsedPassManager::~ParsedPassManager() =
    default;
436
437 //===----------------------------------------------------------------------===//
438 // TextualPassPipeline Parser
439 //===----------------------------------------------------------------------===//
440
namespace {
/// This class represents a textual description of a pass pipeline. It is used
/// in two steps:
///   1. `initialize` parses the text and resolves each element against the
///      pass and pass-pipeline registries.
///   2. `addToPipeline` adds the resolved elements to an OpPassManager.
class TextualPipeline {
public:
  /// Try to initialize this pipeline with the given pipeline text.
  /// `errorStream` is the output stream to emit errors to.
  LogicalResult initialize(StringRef text, raw_ostream &errorStream);

  /// Add the internal pipeline elements to the provided pass manager.
  LogicalResult
  addToPipeline(OpPassManager &pm,
                function_ref<LogicalResult(const Twine &)> errorHandler) const;

private:
  /// A functor used to emit errors found during pipeline handling. The first
  /// parameter corresponds to the raw location within the pipeline string. This
  /// should always return failure.
  using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>;

  /// A struct to capture parsed pass pipeline names.
  ///
  /// A pipeline is defined as a series of names, each of which may in itself
  /// recursively contain a nested pipeline. A name is either the name of a pass
  /// (e.g. "cse") or the name of an operation type (e.g. "builtin.module"). If
  /// the name is the name of a pass, `innerPipeline` is empty, since passes
  /// cannot contain inner pipelines.
  ///
  /// `name` and `options` are non-owning views into the parsed pipeline text,
  /// so that text must outlive any use of the element.
  struct PipelineElement {
    PipelineElement(StringRef name) : name(name) {}

    /// The pass, pass pipeline, or operation name.
    StringRef name;
    /// The text between the braces of `name{...}`, without the braces.
    StringRef options;
    /// The registry entry `name` resolved to. It is null until resolution,
    /// and stays null for operation elements (those with an inner pipeline).
    const PassRegistryEntry *registryEntry = nullptr;
    /// The nested elements of an operation element `name(...)`.
    std::vector<PipelineElement> innerPipeline;
  };

  /// Parse the given pipeline text into the internal pipeline vector. This
  /// function only parses the structure of the pipeline, and does not resolve
  /// its elements.
  LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);

  /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
  /// the corresponding registry entry.
  LogicalResult
  resolvePipelineElements(MutableArrayRef<PipelineElement> elements,
                          ErrorHandlerT errorHandler);

  /// Resolve a single element of the pipeline.
  LogicalResult resolvePipelineElement(PipelineElement &element,
                                       ErrorHandlerT errorHandler);

  /// Add the given pipeline elements to the provided pass manager.
  LogicalResult
  addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm,
                function_ref<LogicalResult(const Twine &)> errorHandler) const;

  /// The top-level elements of the parsed pipeline.
  std::vector<PipelineElement> pipeline;
};

} // namespace
500
501 /// Try to initialize this pipeline with the given pipeline text. An option is
502 /// given to enable accurate error reporting.
initialize(StringRef text,raw_ostream & errorStream)503 LogicalResult TextualPipeline::initialize(StringRef text,
504 raw_ostream &errorStream) {
505 if (text.empty())
506 return success();
507
508 // Build a source manager to use for error reporting.
509 llvm::SourceMgr pipelineMgr;
510 pipelineMgr.AddNewSourceBuffer(
511 llvm::MemoryBuffer::getMemBuffer(text, "MLIR Textual PassPipeline Parser",
512 /*RequiresNullTerminator=*/false),
513 SMLoc());
514 auto errorHandler = [&](const char *rawLoc, Twine msg) {
515 pipelineMgr.PrintMessage(errorStream, SMLoc::getFromPointer(rawLoc),
516 llvm::SourceMgr::DK_Error, msg);
517 return failure();
518 };
519
520 // Parse the provided pipeline string.
521 if (failed(parsePipelineText(text, errorHandler)))
522 return failure();
523 return resolvePipelineElements(pipeline, errorHandler);
524 }
525
526 /// Add the internal pipeline elements to the provided pass manager.
addToPipeline(OpPassManager & pm,function_ref<LogicalResult (const Twine &)> errorHandler) const527 LogicalResult TextualPipeline::addToPipeline(
528 OpPassManager &pm,
529 function_ref<LogicalResult(const Twine &)> errorHandler) const {
530 return addToPipeline(pipeline, pm, errorHandler);
531 }
532
533 /// Parse the given pipeline text into the internal pipeline vector. This
534 /// function only parses the structure of the pipeline, and does not resolve
535 /// its elements.
parsePipelineText(StringRef text,ErrorHandlerT errorHandler)536 LogicalResult TextualPipeline::parsePipelineText(StringRef text,
537 ErrorHandlerT errorHandler) {
538 SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline};
539 for (;;) {
540 std::vector<PipelineElement> &pipeline = *pipelineStack.back();
541 size_t pos = text.find_first_of(",(){");
542 pipeline.emplace_back(/*name=*/text.substr(0, pos).trim());
543
544 // If we have a single terminating name, we're done.
545 if (pos == StringRef::npos)
546 break;
547
548 text = text.substr(pos);
549 char sep = text[0];
550
551 // Handle pulling ... from 'pass{...}' out as PipelineElement.options.
552 if (sep == '{') {
553 text = text.substr(1);
554
555 // Skip over everything until the closing '}' and store as options.
556 size_t close = StringRef::npos;
557 for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
558 if (text[i] == '{') {
559 ++braceCount;
560 continue;
561 }
562 if (text[i] == '}' && --braceCount == 0) {
563 close = i;
564 break;
565 }
566 }
567
568 // Check to see if a closing options brace was found.
569 if (close == StringRef::npos) {
570 return errorHandler(
571 /*rawLoc=*/text.data() - 1,
572 "missing closing '}' while processing pass options");
573 }
574 pipeline.back().options = text.substr(0, close);
575 text = text.substr(close + 1);
576
577 // Skip checking for '(' because nested pipelines cannot have options.
578 } else if (sep == '(') {
579 text = text.substr(1);
580
581 // Push the inner pipeline onto the stack to continue processing.
582 pipelineStack.push_back(&pipeline.back().innerPipeline);
583 continue;
584 }
585
586 // When handling the close parenthesis, we greedily consume them to avoid
587 // empty strings in the pipeline.
588 while (text.consume_front(")")) {
589 // If we try to pop the outer pipeline we have unbalanced parentheses.
590 if (pipelineStack.size() == 1)
591 return errorHandler(/*rawLoc=*/text.data() - 1,
592 "encountered extra closing ')' creating unbalanced "
593 "parentheses while parsing pipeline");
594
595 pipelineStack.pop_back();
596 }
597
598 // Check if we've finished parsing.
599 if (text.empty())
600 break;
601
602 // Otherwise, the end of an inner pipeline always has to be followed by
603 // a comma, and then we can continue.
604 if (!text.consume_front(","))
605 return errorHandler(text.data(), "expected ',' after parsing pipeline");
606 }
607
608 // Check for unbalanced parentheses.
609 if (pipelineStack.size() > 1)
610 return errorHandler(
611 text.data(),
612 "encountered unbalanced parentheses while parsing pipeline");
613
614 assert(pipelineStack.back() == &pipeline &&
615 "wrong pipeline at the bottom of the stack");
616 return success();
617 }
618
619 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
620 /// the corresponding registry entry.
resolvePipelineElements(MutableArrayRef<PipelineElement> elements,ErrorHandlerT errorHandler)621 LogicalResult TextualPipeline::resolvePipelineElements(
622 MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) {
623 for (auto &elt : elements)
624 if (failed(resolvePipelineElement(elt, errorHandler)))
625 return failure();
626 return success();
627 }
628
629 /// Resolve a single element of the pipeline.
630 LogicalResult
resolvePipelineElement(PipelineElement & element,ErrorHandlerT errorHandler)631 TextualPipeline::resolvePipelineElement(PipelineElement &element,
632 ErrorHandlerT errorHandler) {
633 // If the inner pipeline of this element is not empty, this is an operation
634 // pipeline.
635 if (!element.innerPipeline.empty())
636 return resolvePipelineElements(element.innerPipeline, errorHandler);
637 // Otherwise, this must be a pass or pass pipeline.
638 // Check to see if a pipeline was registered with this name.
639 auto pipelineRegistryIt = passPipelineRegistry->find(element.name);
640 if (pipelineRegistryIt != passPipelineRegistry->end()) {
641 element.registryEntry = &pipelineRegistryIt->second;
642 return success();
643 }
644
645 // If not, then this must be a specific pass name.
646 if ((element.registryEntry = Pass::lookupPassInfo(element.name)))
647 return success();
648
649 // Emit an error for the unknown pass.
650 auto *rawLoc = element.name.data();
651 return errorHandler(rawLoc, "'" + element.name +
652 "' does not refer to a "
653 "registered pass or pass pipeline");
654 }
655
656 /// Add the given pipeline elements to the provided pass manager.
addToPipeline(ArrayRef<PipelineElement> elements,OpPassManager & pm,function_ref<LogicalResult (const Twine &)> errorHandler) const657 LogicalResult TextualPipeline::addToPipeline(
658 ArrayRef<PipelineElement> elements, OpPassManager &pm,
659 function_ref<LogicalResult(const Twine &)> errorHandler) const {
660 for (auto &elt : elements) {
661 if (elt.registryEntry) {
662 if (failed(elt.registryEntry->addToPipeline(pm, elt.options,
663 errorHandler))) {
664 return errorHandler("failed to add `" + elt.name + "` with options `" +
665 elt.options + "`");
666 }
667 } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name),
668 errorHandler))) {
669 return errorHandler("failed to add `" + elt.name + "` with options `" +
670 elt.options + "` to inner pipeline");
671 }
672 }
673 return success();
674 }
675
parsePassPipeline(StringRef pipeline,OpPassManager & pm,raw_ostream & errorStream)676 LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm,
677 raw_ostream &errorStream) {
678 TextualPipeline pipelineParser;
679 if (failed(pipelineParser.initialize(pipeline, errorStream)))
680 return failure();
681 auto errorHandler = [&](Twine msg) {
682 errorStream << msg << "\n";
683 return failure();
684 };
685 if (failed(pipelineParser.addToPipeline(pm, errorHandler)))
686 return failure();
687 return success();
688 }
689
parsePassPipeline(StringRef pipeline,raw_ostream & errorStream)690 FailureOr<OpPassManager> mlir::parsePassPipeline(StringRef pipeline,
691 raw_ostream &errorStream) {
692 // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
693 size_t pipelineStart = pipeline.find_first_of('(');
694 if (pipelineStart == 0 || pipelineStart == StringRef::npos ||
695 !pipeline.consume_back(")")) {
696 errorStream << "expected pass pipeline to be wrapped with the anchor "
697 "operation type, e.g. `builtin.module(...)";
698 return failure();
699 }
700
701 StringRef opName = pipeline.take_front(pipelineStart);
702 OpPassManager pm(opName);
703 if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm)))
704 return failure();
705 return pm;
706 }
707
708 //===----------------------------------------------------------------------===//
709 // PassNameParser
710 //===----------------------------------------------------------------------===//
711
712 namespace {
713 /// This struct represents the possible data entries in a parsed pass pipeline
714 /// list.
715 struct PassArgData {
716 PassArgData() = default;
PassArgData__anon67970f8f0e11::PassArgData717 PassArgData(const PassRegistryEntry *registryEntry)
718 : registryEntry(registryEntry) {}
719
720 /// This field is used when the parsed option corresponds to a registered pass
721 /// or pass pipeline.
722 const PassRegistryEntry *registryEntry{nullptr};
723
724 /// This field is set when instance specific pass options have been provided
725 /// on the command line.
726 StringRef options;
727
728 /// This field is used when the parsed option corresponds to an explicit
729 /// pipeline.
730 TextualPipeline pipeline;
731 };
732 } // namespace
733
734 namespace llvm {
735 namespace cl {
736 /// Define a valid OptionValue for the command line pass argument.
737 template <>
738 struct OptionValue<PassArgData> final
739 : OptionValueBase<PassArgData, /*isClass=*/true> {
OptionValuellvm::cl::OptionValue740 OptionValue(const PassArgData &value) { this->setValue(value); }
741 OptionValue() = default;
anchorllvm::cl::OptionValue742 void anchor() override {}
743
hasValuellvm::cl::OptionValue744 bool hasValue() const { return true; }
getValuellvm::cl::OptionValue745 const PassArgData &getValue() const { return value; }
setValuellvm::cl::OptionValue746 void setValue(const PassArgData &value) { this->value = value; }
747
748 PassArgData value;
749 };
750 } // namespace cl
751 } // namespace llvm
752
753 namespace {
754
755 /// The name for the command line option used for parsing the textual pass
756 /// pipeline.
757 static constexpr StringLiteral passPipelineArg = "pass-pipeline";
758
759 /// Adds command line option for each registered pass or pass pipeline, as well
760 /// as textual pass pipelines.
761 struct PassNameParser : public llvm::cl::parser<PassArgData> {
PassNameParser__anon67970f8f0f11::PassNameParser762 PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {}
763
764 void initialize();
765 void printOptionInfo(const llvm::cl::Option &opt,
766 size_t globalWidth) const override;
767 size_t getOptionWidth(const llvm::cl::Option &opt) const override;
768 bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
769 PassArgData &value);
770
771 /// If true, this parser only parses entries that correspond to a concrete
772 /// pass registry entry, and does not add a `pass-pipeline` argument, does not
773 /// include the options for pass entries, and does not include pass pipelines
774 /// entries.
775 bool passNamesOnly = false;
776 };
777 } // namespace
778
initialize()779 void PassNameParser::initialize() {
780 llvm::cl::parser<PassArgData>::initialize();
781
782 /// Add an entry for the textual pass pipeline option.
783 if (!passNamesOnly) {
784 addLiteralOption(passPipelineArg, PassArgData(),
785 "A textual description of a pass pipeline to run");
786 }
787
788 /// Add the pass entries.
789 for (const auto &kv : *passRegistry) {
790 addLiteralOption(kv.second.getPassArgument(), &kv.second,
791 kv.second.getPassDescription());
792 }
793 /// Add the pass pipeline entries.
794 if (!passNamesOnly) {
795 for (const auto &kv : *passPipelineRegistry) {
796 addLiteralOption(kv.second.getPassArgument(), &kv.second,
797 kv.second.getPassDescription());
798 }
799 }
800 }
801
printOptionInfo(const llvm::cl::Option & opt,size_t globalWidth) const802 void PassNameParser::printOptionInfo(const llvm::cl::Option &opt,
803 size_t globalWidth) const {
804 // If this parser is just parsing pass names, print a simplified option
805 // string.
806 if (passNamesOnly) {
807 llvm::outs() << " --" << opt.ArgStr << "=<pass-arg>";
808 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18);
809 return;
810 }
811
812 // Print the information for the top-level option.
813 if (opt.hasArgStr()) {
814 llvm::outs() << " --" << opt.ArgStr;
815 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
816 } else {
817 llvm::outs() << " " << opt.HelpStr << '\n';
818 }
819
820 // Print the top-level pipeline argument.
821 printOptionHelp(passPipelineArg,
822 "A textual description of a pass pipeline to run",
823 /*indent=*/4, globalWidth, /*isTopLevel=*/!opt.hasArgStr());
824
825 // Functor used to print the ordered entries of a registration map.
826 auto printOrderedEntries = [&](StringRef header, auto &map) {
827 llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries;
828 for (auto &kv : map)
829 orderedEntries.push_back(&kv.second);
830 llvm::array_pod_sort(
831 orderedEntries.begin(), orderedEntries.end(),
832 [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
833 return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
834 });
835
836 llvm::outs().indent(4) << header << ":\n";
837 for (PassRegistryEntry *entry : orderedEntries)
838 entry->printHelpStr(/*indent=*/6, globalWidth);
839 };
840
841 // Print the available passes.
842 printOrderedEntries("Passes", *passRegistry);
843
844 // Print the available pass pipelines.
845 if (!passPipelineRegistry->empty())
846 printOrderedEntries("Pass Pipelines", *passPipelineRegistry);
847 }
848
getOptionWidth(const llvm::cl::Option & opt) const849 size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const {
850 size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(opt) + 2;
851
852 // Check for any wider pass or pipeline options.
853 for (auto &entry : *passRegistry)
854 maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
855 for (auto &entry : *passPipelineRegistry)
856 maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
857 return maxWidth;
858 }
859
parse(llvm::cl::Option & opt,StringRef argName,StringRef arg,PassArgData & value)860 bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,
861 StringRef arg, PassArgData &value) {
862 // Handle the pipeline option explicitly.
863 if (argName == passPipelineArg)
864 return failed(value.pipeline.initialize(arg, llvm::errs()));
865
866 // Otherwise, default to the base for handling.
867 if (llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value))
868 return true;
869 value.options = arg;
870 return false;
871 }
872
873 //===----------------------------------------------------------------------===//
874 // PassPipelineCLParser
875 //===----------------------------------------------------------------------===//
876
877 namespace mlir {
878 namespace detail {
879 struct PassPipelineCLParserImpl {
PassPipelineCLParserImplmlir::detail::PassPipelineCLParserImpl880 PassPipelineCLParserImpl(StringRef arg, StringRef description,
881 bool passNamesOnly)
882 : passList(arg, llvm::cl::desc(description)) {
883 passList.getParser().passNamesOnly = passNamesOnly;
884 passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
885 }
886
887 /// Returns true if the given pass registry entry was registered at the
888 /// top-level of the parser, i.e. not within an explicit textual pipeline.
containsmlir::detail::PassPipelineCLParserImpl889 bool contains(const PassRegistryEntry *entry) const {
890 return llvm::any_of(passList, [&](const PassArgData &data) {
891 return data.registryEntry == entry;
892 });
893 }
894
895 /// The set of passes and pass pipelines to run.
896 llvm::cl::list<PassArgData, bool, PassNameParser> passList;
897 };
898 } // namespace detail
899 } // namespace mlir
900
901 /// Construct a pass pipeline parser with the given command line description.
PassPipelineCLParser(StringRef arg,StringRef description)902 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description)
903 : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
904 arg, description, /*passNamesOnly=*/false)) {}
905 PassPipelineCLParser::~PassPipelineCLParser() = default;
906
907 /// Returns true if this parser contains any valid options to add.
hasAnyOccurrences() const908 bool PassPipelineCLParser::hasAnyOccurrences() const {
909 return impl->passList.getNumOccurrences() != 0;
910 }
911
912 /// Returns true if the given pass registry entry was registered at the
913 /// top-level of the parser, i.e. not within an explicit textual pipeline.
contains(const PassRegistryEntry * entry) const914 bool PassPipelineCLParser::contains(const PassRegistryEntry *entry) const {
915 return impl->contains(entry);
916 }
917
918 /// Adds the passes defined by this parser entry to the given pass manager.
addToPipeline(OpPassManager & pm,function_ref<LogicalResult (const Twine &)> errorHandler) const919 LogicalResult PassPipelineCLParser::addToPipeline(
920 OpPassManager &pm,
921 function_ref<LogicalResult(const Twine &)> errorHandler) const {
922 for (auto &passIt : impl->passList) {
923 if (passIt.registryEntry) {
924 if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
925 errorHandler)))
926 return failure();
927 } else {
928 OpPassManager::Nesting nesting = pm.getNesting();
929 pm.setNesting(OpPassManager::Nesting::Explicit);
930 LogicalResult status = passIt.pipeline.addToPipeline(pm, errorHandler);
931 pm.setNesting(nesting);
932 if (failed(status))
933 return failure();
934 }
935 }
936 return success();
937 }
938
//===----------------------------------------------------------------------===//
// PassNameCLParser
//===----------------------------------------------------------------------===//

942 /// Construct a pass pipeline parser with the given command line description.
PassNameCLParser(StringRef arg,StringRef description)943 PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description)
944 : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
945 arg, description, /*passNamesOnly=*/true)) {
946 impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
947 }
948 PassNameCLParser::~PassNameCLParser() = default;
949
950 /// Returns true if this parser contains any valid options to add.
hasAnyOccurrences() const951 bool PassNameCLParser::hasAnyOccurrences() const {
952 return impl->passList.getNumOccurrences() != 0;
953 }
954
955 /// Returns true if the given pass registry entry was registered at the
956 /// top-level of the parser, i.e. not within an explicit textual pipeline.
contains(const PassRegistryEntry * entry) const957 bool PassNameCLParser::contains(const PassRegistryEntry *entry) const {
958 return impl->contains(entry);
959 }
960