//===- PassRegistry.cpp - Pass Registration Utilities ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements the global registries for passes and pass pipelines, parsing and
// printing of pass options, and the textual pass pipeline parser used by the
// command-line pass options.
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "mlir/Pass/PassRegistry.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"

using namespace mlir;
using namespace detail;

/// Static mapping of all of the registered passes, keyed by the pass argument
/// (the name returned by `Pass::getArgument()` at registration time).
static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;

/// A mapping of the above pass registry entries to the corresponding TypeID
/// of the pass that they generate. Registering a second allocator under the
/// same argument that produces a pass with a different TypeID is treated as a
/// fatal error.
static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;

/// Static mapping of all of the registered pass pipelines, keyed by the
/// pipeline argument.
static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
    passPipelineRegistry;

/// Utility to create a default registry function from a pass allocator.
35 static PassRegistryFunction 36 buildDefaultRegistryFn(const PassAllocatorFunction &allocator) { 37 return [=](OpPassManager &pm, StringRef options, 38 function_ref<LogicalResult(const Twine &)> errorHandler) { 39 std::unique_ptr<Pass> pass = allocator(); 40 LogicalResult result = pass->initializeOptions(options); 41 42 Optional<StringRef> pmOpName = pm.getOpName(); 43 Optional<StringRef> passOpName = pass->getOpName(); 44 if ((pm.getNesting() == OpPassManager::Nesting::Explicit) && pmOpName && 45 passOpName && *pmOpName != *passOpName) { 46 return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() + 47 "' restricted to '" + *pass->getOpName() + 48 "' on a PassManager intended to run on '" + 49 pm.getOpAnchorName() + "', did you intend to nest?"); 50 } 51 pm.addPass(std::move(pass)); 52 return result; 53 }; 54 } 55 56 /// Utility to print the help string for a specific option. 57 static void printOptionHelp(StringRef arg, StringRef desc, size_t indent, 58 size_t descIndent, bool isTopLevel) { 59 size_t numSpaces = descIndent - indent - 4; 60 llvm::outs().indent(indent) 61 << "--" << llvm::left_justify(arg, numSpaces) << "- " << desc << '\n'; 62 } 63 64 //===----------------------------------------------------------------------===// 65 // PassRegistry 66 //===----------------------------------------------------------------------===// 67 68 /// Print the help information for this pass. This includes the argument, 69 /// description, and any pass options. `descIndent` is the indent that the 70 /// descriptions should be aligned. 71 void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const { 72 printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent, 73 /*isTopLevel=*/true); 74 // If this entry has options, print the help for those as well. 
75 optHandler([=](const PassOptions &options) { 76 options.printHelp(indent, descIndent); 77 }); 78 } 79 80 /// Return the maximum width required when printing the options of this 81 /// entry. 82 size_t PassRegistryEntry::getOptionWidth() const { 83 size_t maxLen = 0; 84 optHandler([&](const PassOptions &options) mutable { 85 maxLen = options.getOptionWidth() + 2; 86 }); 87 return maxLen; 88 } 89 90 //===----------------------------------------------------------------------===// 91 // PassPipelineInfo 92 //===----------------------------------------------------------------------===// 93 94 void mlir::registerPassPipeline( 95 StringRef arg, StringRef description, const PassRegistryFunction &function, 96 std::function<void(function_ref<void(const PassOptions &)>)> optHandler) { 97 PassPipelineInfo pipelineInfo(arg, description, function, 98 std::move(optHandler)); 99 bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second; 100 assert(inserted && "Pass pipeline registered multiple times"); 101 (void)inserted; 102 } 103 104 //===----------------------------------------------------------------------===// 105 // PassInfo 106 //===----------------------------------------------------------------------===// 107 108 PassInfo::PassInfo(StringRef arg, StringRef description, 109 const PassAllocatorFunction &allocator) 110 : PassRegistryEntry( 111 arg, description, buildDefaultRegistryFn(allocator), 112 // Use a temporary pass to provide an options instance. 
113 [=](function_ref<void(const PassOptions &)> optHandler) { 114 optHandler(allocator()->passOptions); 115 }) {} 116 117 void mlir::registerPass(const PassAllocatorFunction &function) { 118 std::unique_ptr<Pass> pass = function(); 119 StringRef arg = pass->getArgument(); 120 if (arg.empty()) 121 llvm::report_fatal_error(llvm::Twine("Trying to register '") + 122 pass->getName() + 123 "' pass that does not override `getArgument()`"); 124 StringRef description = pass->getDescription(); 125 PassInfo passInfo(arg, description, function); 126 passRegistry->try_emplace(arg, passInfo); 127 128 // Verify that the registered pass has the same ID as any registered to this 129 // arg before it. 130 TypeID entryTypeID = pass->getTypeID(); 131 auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first; 132 if (it->second != entryTypeID) 133 llvm::report_fatal_error( 134 "pass allocator creates a different pass than previously " 135 "registered for pass " + 136 arg); 137 } 138 139 /// Returns the pass info for the specified pass argument or null if unknown. 140 const PassInfo *mlir::Pass::lookupPassInfo(StringRef passArg) { 141 auto it = passRegistry->find(passArg); 142 return it == passRegistry->end() ? nullptr : &it->second; 143 } 144 145 //===----------------------------------------------------------------------===// 146 // PassOptions 147 //===----------------------------------------------------------------------===// 148 149 LogicalResult detail::pass_options::parseCommaSeparatedList( 150 llvm::cl::Option &opt, StringRef argName, StringRef optionStr, 151 function_ref<LogicalResult(StringRef)> elementParseFn) { 152 // Functor used for finding a character in a string, and skipping over 153 // various "range" characters. 
154 llvm::unique_function<size_t(StringRef, size_t, char)> findChar = 155 [&](StringRef str, size_t index, char c) -> size_t { 156 for (size_t i = index, e = str.size(); i < e; ++i) { 157 if (str[i] == c) 158 return i; 159 // Check for various range characters. 160 if (str[i] == '{') 161 i = findChar(str, i + 1, '}'); 162 else if (str[i] == '(') 163 i = findChar(str, i + 1, ')'); 164 else if (str[i] == '[') 165 i = findChar(str, i + 1, ']'); 166 else if (str[i] == '\"') 167 i = str.find_first_of('\"', i + 1); 168 else if (str[i] == '\'') 169 i = str.find_first_of('\'', i + 1); 170 } 171 return StringRef::npos; 172 }; 173 174 size_t nextElePos = findChar(optionStr, 0, ','); 175 while (nextElePos != StringRef::npos) { 176 // Process the portion before the comma. 177 if (failed(elementParseFn(optionStr.substr(0, nextElePos)))) 178 return failure(); 179 180 optionStr = optionStr.substr(nextElePos + 1); 181 nextElePos = findChar(optionStr, 0, ','); 182 } 183 return elementParseFn(optionStr.substr(0, nextElePos)); 184 } 185 186 /// Out of line virtual function to provide home for the class. 187 void detail::PassOptions::OptionBase::anchor() {} 188 189 /// Copy the option values from 'other'. 190 void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) { 191 assert(options.size() == other.options.size()); 192 if (options.empty()) 193 return; 194 for (auto optionsIt : llvm::zip(options, other.options)) 195 std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt)); 196 } 197 198 /// Parse in the next argument from the given options string. Returns a tuple 199 /// containing [the key of the option, the value of the option, updated 200 /// `options` string pointing after the parsed option]. 201 static std::tuple<StringRef, StringRef, StringRef> 202 parseNextArg(StringRef options) { 203 // Functor used to extract an argument from 'options' and update it to point 204 // after the arg. 
205 auto extractArgAndUpdateOptions = [&](size_t argSize) { 206 StringRef str = options.take_front(argSize).trim(); 207 options = options.drop_front(argSize).ltrim(); 208 return str; 209 }; 210 // Try to process the given punctuation, properly escaping any contained 211 // characters. 212 auto tryProcessPunct = [&](size_t ¤tPos, char punct) { 213 if (options[currentPos] != punct) 214 return false; 215 size_t nextIt = options.find_first_of(punct, currentPos + 1); 216 if (nextIt != StringRef::npos) 217 currentPos = nextIt; 218 return true; 219 }; 220 221 // Parse the argument name of the option. 222 StringRef argName; 223 for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) { 224 // Check for the end of the full option. 225 if (argEndIt == optionsE || options[argEndIt] == ' ') { 226 argName = extractArgAndUpdateOptions(argEndIt); 227 return std::make_tuple(argName, StringRef(), options); 228 } 229 230 // Check for the end of the name and the start of the value. 231 if (options[argEndIt] == '=') { 232 argName = extractArgAndUpdateOptions(argEndIt); 233 options = options.drop_front(); 234 break; 235 } 236 } 237 238 // Parse the value of the option. 239 for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) { 240 // Handle the end of the options string. 241 if (argEndIt == optionsE || options[argEndIt] == ' ') { 242 StringRef value = extractArgAndUpdateOptions(argEndIt); 243 return std::make_tuple(argName, value, options); 244 } 245 246 // Skip over escaped sequences. 247 char c = options[argEndIt]; 248 if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"')) 249 continue; 250 // '{...}' is used to specify options to passes, properly escape it so 251 // that we don't accidentally split any nested options. 252 if (c == '{') { 253 size_t braceCount = 1; 254 for (++argEndIt; argEndIt != optionsE; ++argEndIt) { 255 // Allow nested punctuation. 
256 if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"')) 257 continue; 258 if (options[argEndIt] == '{') 259 ++braceCount; 260 else if (options[argEndIt] == '}' && --braceCount == 0) 261 break; 262 } 263 // Account for the increment at the top of the loop. 264 --argEndIt; 265 } 266 } 267 llvm_unreachable("unexpected control flow in pass option parsing"); 268 } 269 270 LogicalResult detail::PassOptions::parseFromString(StringRef options) { 271 // NOTE: `options` is modified in place to always refer to the unprocessed 272 // part of the string. 273 while (!options.empty()) { 274 StringRef key, value; 275 std::tie(key, value, options) = parseNextArg(options); 276 if (key.empty()) 277 continue; 278 279 auto it = OptionsMap.find(key); 280 if (it == OptionsMap.end()) { 281 llvm::errs() << "<Pass-Options-Parser>: no such option " << key << "\n"; 282 return failure(); 283 } 284 if (llvm::cl::ProvidePositionalOption(it->second, value, 0)) 285 return failure(); 286 } 287 288 return success(); 289 } 290 291 /// Print the options held by this struct in a form that can be parsed via 292 /// 'parseFromString'. 293 void detail::PassOptions::print(raw_ostream &os) { 294 // If there are no options, there is nothing left to do. 295 if (OptionsMap.empty()) 296 return; 297 298 // Sort the options to make the ordering deterministic. 299 SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end()); 300 auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) { 301 return (*lhs)->getArgStr().compare((*rhs)->getArgStr()); 302 }; 303 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs); 304 305 // Interleave the options with ' '. 306 os << '{'; 307 llvm::interleave( 308 orderedOps, os, [&](OptionBase *option) { option->print(os); }, " "); 309 os << '}'; 310 } 311 312 /// Print the help string for the options held by this struct. 
`descIndent` is 313 /// the indent within the stream that the descriptions should be aligned. 314 void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const { 315 // Sort the options to make the ordering deterministic. 316 SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end()); 317 auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) { 318 return (*lhs)->getArgStr().compare((*rhs)->getArgStr()); 319 }; 320 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs); 321 for (OptionBase *option : orderedOps) { 322 // TODO: printOptionInfo assumes a specific indent and will 323 // print options with values with incorrect indentation. We should add 324 // support to llvm::cl::Option for passing in a base indent to use when 325 // printing. 326 llvm::outs().indent(indent); 327 option->getOption()->printOptionInfo(descIndent - indent); 328 } 329 } 330 331 /// Return the maximum width required when printing the help string. 
332 size_t detail::PassOptions::getOptionWidth() const { 333 size_t max = 0; 334 for (auto *option : options) 335 max = std::max(max, option->getOption()->getOptionWidth()); 336 return max; 337 } 338 339 //===----------------------------------------------------------------------===// 340 // MLIR Options 341 //===----------------------------------------------------------------------===// 342 343 //===----------------------------------------------------------------------===// 344 // OpPassManager: OptionValue 345 346 llvm::cl::OptionValue<OpPassManager>::OptionValue() = default; 347 llvm::cl::OptionValue<OpPassManager>::OptionValue( 348 const mlir::OpPassManager &value) { 349 setValue(value); 350 } 351 llvm::cl::OptionValue<OpPassManager> & 352 llvm::cl::OptionValue<OpPassManager>::operator=( 353 const mlir::OpPassManager &rhs) { 354 setValue(rhs); 355 return *this; 356 } 357 358 llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() = default; 359 360 void llvm::cl::OptionValue<OpPassManager>::setValue( 361 const OpPassManager &newValue) { 362 if (hasValue()) 363 *value = newValue; 364 else 365 value = std::make_unique<mlir::OpPassManager>(newValue); 366 } 367 void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) { 368 FailureOr<OpPassManager> pipeline = parsePassPipeline(pipelineStr); 369 assert(succeeded(pipeline) && "invalid pass pipeline"); 370 setValue(*pipeline); 371 } 372 373 bool llvm::cl::OptionValue<OpPassManager>::compare( 374 const mlir::OpPassManager &rhs) const { 375 std::string lhsStr, rhsStr; 376 { 377 raw_string_ostream lhsStream(lhsStr); 378 value->printAsTextualPipeline(lhsStream); 379 380 raw_string_ostream rhsStream(rhsStr); 381 rhs.printAsTextualPipeline(rhsStream); 382 } 383 384 // Use the textual format for pipeline comparisons. 
385 return lhsStr == rhsStr; 386 } 387 388 void llvm::cl::OptionValue<OpPassManager>::anchor() {} 389 390 //===----------------------------------------------------------------------===// 391 // OpPassManager: Parser 392 393 namespace llvm { 394 namespace cl { 395 template class basic_parser<OpPassManager>; 396 } // namespace cl 397 } // namespace llvm 398 399 bool llvm::cl::parser<OpPassManager>::parse(Option &, StringRef, StringRef arg, 400 ParsedPassManager &value) { 401 FailureOr<OpPassManager> pipeline = parsePassPipeline(arg); 402 if (failed(pipeline)) 403 return true; 404 value.value = std::make_unique<OpPassManager>(std::move(*pipeline)); 405 return false; 406 } 407 408 void llvm::cl::parser<OpPassManager>::print(raw_ostream &os, 409 const OpPassManager &value) { 410 value.printAsTextualPipeline(os); 411 } 412 413 void llvm::cl::parser<OpPassManager>::printOptionDiff( 414 const Option &opt, OpPassManager &pm, const OptVal &defaultValue, 415 size_t globalWidth) const { 416 printOptionName(opt, globalWidth); 417 outs() << "= "; 418 pm.printAsTextualPipeline(outs()); 419 420 if (defaultValue.hasValue()) { 421 outs().indent(2) << " (default: "; 422 defaultValue.getValue().printAsTextualPipeline(outs()); 423 outs() << ")"; 424 } 425 outs() << "\n"; 426 } 427 428 void llvm::cl::parser<OpPassManager>::anchor() {} 429 430 llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager() = 431 default; 432 llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager( 433 ParsedPassManager &&) = default; 434 llvm::cl::parser<OpPassManager>::ParsedPassManager::~ParsedPassManager() = 435 default; 436 437 //===----------------------------------------------------------------------===// 438 // TextualPassPipeline Parser 439 //===----------------------------------------------------------------------===// 440 441 namespace { 442 /// This class represents a textual description of a pass pipeline. 
443 class TextualPipeline { 444 public: 445 /// Try to initialize this pipeline with the given pipeline text. 446 /// `errorStream` is the output stream to emit errors to. 447 LogicalResult initialize(StringRef text, raw_ostream &errorStream); 448 449 /// Add the internal pipeline elements to the provided pass manager. 450 LogicalResult 451 addToPipeline(OpPassManager &pm, 452 function_ref<LogicalResult(const Twine &)> errorHandler) const; 453 454 private: 455 /// A functor used to emit errors found during pipeline handling. The first 456 /// parameter corresponds to the raw location within the pipeline string. This 457 /// should always return failure. 458 using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>; 459 460 /// A struct to capture parsed pass pipeline names. 461 /// 462 /// A pipeline is defined as a series of names, each of which may in itself 463 /// recursively contain a nested pipeline. A name is either the name of a pass 464 /// (e.g. "cse") or the name of an operation type (e.g. "buitin.module"). If 465 /// the name is the name of a pass, the InnerPipeline is empty, since passes 466 /// cannot contain inner pipelines. 467 struct PipelineElement { 468 PipelineElement(StringRef name) : name(name) {} 469 470 StringRef name; 471 StringRef options; 472 const PassRegistryEntry *registryEntry = nullptr; 473 std::vector<PipelineElement> innerPipeline; 474 }; 475 476 /// Parse the given pipeline text into the internal pipeline vector. This 477 /// function only parses the structure of the pipeline, and does not resolve 478 /// its elements. 479 LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler); 480 481 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to 482 /// the corresponding registry entry. 483 LogicalResult 484 resolvePipelineElements(MutableArrayRef<PipelineElement> elements, 485 ErrorHandlerT errorHandler); 486 487 /// Resolve a single element of the pipeline. 
488 LogicalResult resolvePipelineElement(PipelineElement &element, 489 ErrorHandlerT errorHandler); 490 491 /// Add the given pipeline elements to the provided pass manager. 492 LogicalResult 493 addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm, 494 function_ref<LogicalResult(const Twine &)> errorHandler) const; 495 496 std::vector<PipelineElement> pipeline; 497 }; 498 499 } // namespace 500 501 /// Try to initialize this pipeline with the given pipeline text. An option is 502 /// given to enable accurate error reporting. 503 LogicalResult TextualPipeline::initialize(StringRef text, 504 raw_ostream &errorStream) { 505 if (text.empty()) 506 return success(); 507 508 // Build a source manager to use for error reporting. 509 llvm::SourceMgr pipelineMgr; 510 pipelineMgr.AddNewSourceBuffer( 511 llvm::MemoryBuffer::getMemBuffer(text, "MLIR Textual PassPipeline Parser", 512 /*RequiresNullTerminator=*/false), 513 SMLoc()); 514 auto errorHandler = [&](const char *rawLoc, Twine msg) { 515 pipelineMgr.PrintMessage(errorStream, SMLoc::getFromPointer(rawLoc), 516 llvm::SourceMgr::DK_Error, msg); 517 return failure(); 518 }; 519 520 // Parse the provided pipeline string. 521 if (failed(parsePipelineText(text, errorHandler))) 522 return failure(); 523 return resolvePipelineElements(pipeline, errorHandler); 524 } 525 526 /// Add the internal pipeline elements to the provided pass manager. 527 LogicalResult TextualPipeline::addToPipeline( 528 OpPassManager &pm, 529 function_ref<LogicalResult(const Twine &)> errorHandler) const { 530 return addToPipeline(pipeline, pm, errorHandler); 531 } 532 533 /// Parse the given pipeline text into the internal pipeline vector. This 534 /// function only parses the structure of the pipeline, and does not resolve 535 /// its elements. 
536 LogicalResult TextualPipeline::parsePipelineText(StringRef text, 537 ErrorHandlerT errorHandler) { 538 SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline}; 539 for (;;) { 540 std::vector<PipelineElement> &pipeline = *pipelineStack.back(); 541 size_t pos = text.find_first_of(",(){"); 542 pipeline.emplace_back(/*name=*/text.substr(0, pos).trim()); 543 544 // If we have a single terminating name, we're done. 545 if (pos == StringRef::npos) 546 break; 547 548 text = text.substr(pos); 549 char sep = text[0]; 550 551 // Handle pulling ... from 'pass{...}' out as PipelineElement.options. 552 if (sep == '{') { 553 text = text.substr(1); 554 555 // Skip over everything until the closing '}' and store as options. 556 size_t close = StringRef::npos; 557 for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) { 558 if (text[i] == '{') { 559 ++braceCount; 560 continue; 561 } 562 if (text[i] == '}' && --braceCount == 0) { 563 close = i; 564 break; 565 } 566 } 567 568 // Check to see if a closing options brace was found. 569 if (close == StringRef::npos) { 570 return errorHandler( 571 /*rawLoc=*/text.data() - 1, 572 "missing closing '}' while processing pass options"); 573 } 574 pipeline.back().options = text.substr(0, close); 575 text = text.substr(close + 1); 576 577 // Skip checking for '(' because nested pipelines cannot have options. 578 } else if (sep == '(') { 579 text = text.substr(1); 580 581 // Push the inner pipeline onto the stack to continue processing. 582 pipelineStack.push_back(&pipeline.back().innerPipeline); 583 continue; 584 } 585 586 // When handling the close parenthesis, we greedily consume them to avoid 587 // empty strings in the pipeline. 588 while (text.consume_front(")")) { 589 // If we try to pop the outer pipeline we have unbalanced parentheses. 
590 if (pipelineStack.size() == 1) 591 return errorHandler(/*rawLoc=*/text.data() - 1, 592 "encountered extra closing ')' creating unbalanced " 593 "parentheses while parsing pipeline"); 594 595 pipelineStack.pop_back(); 596 } 597 598 // Check if we've finished parsing. 599 if (text.empty()) 600 break; 601 602 // Otherwise, the end of an inner pipeline always has to be followed by 603 // a comma, and then we can continue. 604 if (!text.consume_front(",")) 605 return errorHandler(text.data(), "expected ',' after parsing pipeline"); 606 } 607 608 // Check for unbalanced parentheses. 609 if (pipelineStack.size() > 1) 610 return errorHandler( 611 text.data(), 612 "encountered unbalanced parentheses while parsing pipeline"); 613 614 assert(pipelineStack.back() == &pipeline && 615 "wrong pipeline at the bottom of the stack"); 616 return success(); 617 } 618 619 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to 620 /// the corresponding registry entry. 621 LogicalResult TextualPipeline::resolvePipelineElements( 622 MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) { 623 for (auto &elt : elements) 624 if (failed(resolvePipelineElement(elt, errorHandler))) 625 return failure(); 626 return success(); 627 } 628 629 /// Resolve a single element of the pipeline. 630 LogicalResult 631 TextualPipeline::resolvePipelineElement(PipelineElement &element, 632 ErrorHandlerT errorHandler) { 633 // If the inner pipeline of this element is not empty, this is an operation 634 // pipeline. 635 if (!element.innerPipeline.empty()) 636 return resolvePipelineElements(element.innerPipeline, errorHandler); 637 // Otherwise, this must be a pass or pass pipeline. 638 // Check to see if a pipeline was registered with this name. 
639 auto pipelineRegistryIt = passPipelineRegistry->find(element.name); 640 if (pipelineRegistryIt != passPipelineRegistry->end()) { 641 element.registryEntry = &pipelineRegistryIt->second; 642 return success(); 643 } 644 645 // If not, then this must be a specific pass name. 646 if ((element.registryEntry = Pass::lookupPassInfo(element.name))) 647 return success(); 648 649 // Emit an error for the unknown pass. 650 auto *rawLoc = element.name.data(); 651 return errorHandler(rawLoc, "'" + element.name + 652 "' does not refer to a " 653 "registered pass or pass pipeline"); 654 } 655 656 /// Add the given pipeline elements to the provided pass manager. 657 LogicalResult TextualPipeline::addToPipeline( 658 ArrayRef<PipelineElement> elements, OpPassManager &pm, 659 function_ref<LogicalResult(const Twine &)> errorHandler) const { 660 for (auto &elt : elements) { 661 if (elt.registryEntry) { 662 if (failed(elt.registryEntry->addToPipeline(pm, elt.options, 663 errorHandler))) { 664 return errorHandler("failed to add `" + elt.name + "` with options `" + 665 elt.options + "`"); 666 } 667 } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name), 668 errorHandler))) { 669 return errorHandler("failed to add `" + elt.name + "` with options `" + 670 elt.options + "` to inner pipeline"); 671 } 672 } 673 return success(); 674 } 675 676 LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm, 677 raw_ostream &errorStream) { 678 TextualPipeline pipelineParser; 679 if (failed(pipelineParser.initialize(pipeline, errorStream))) 680 return failure(); 681 auto errorHandler = [&](Twine msg) { 682 errorStream << msg << "\n"; 683 return failure(); 684 }; 685 if (failed(pipelineParser.addToPipeline(pm, errorHandler))) 686 return failure(); 687 return success(); 688 } 689 690 FailureOr<OpPassManager> mlir::parsePassPipeline(StringRef pipeline, 691 raw_ostream &errorStream) { 692 // Pipelines are expected to be of the form `<op-name>(<pipeline>)`. 
693 size_t pipelineStart = pipeline.find_first_of('('); 694 if (pipelineStart == 0 || pipelineStart == StringRef::npos || 695 !pipeline.consume_back(")")) { 696 errorStream << "expected pass pipeline to be wrapped with the anchor " 697 "operation type, e.g. `builtin.module(...)"; 698 return failure(); 699 } 700 701 StringRef opName = pipeline.take_front(pipelineStart); 702 OpPassManager pm(opName); 703 if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm))) 704 return failure(); 705 return pm; 706 } 707 708 //===----------------------------------------------------------------------===// 709 // PassNameParser 710 //===----------------------------------------------------------------------===// 711 712 namespace { 713 /// This struct represents the possible data entries in a parsed pass pipeline 714 /// list. 715 struct PassArgData { 716 PassArgData() = default; 717 PassArgData(const PassRegistryEntry *registryEntry) 718 : registryEntry(registryEntry) {} 719 720 /// This field is used when the parsed option corresponds to a registered pass 721 /// or pass pipeline. 722 const PassRegistryEntry *registryEntry{nullptr}; 723 724 /// This field is set when instance specific pass options have been provided 725 /// on the command line. 726 StringRef options; 727 728 /// This field is used when the parsed option corresponds to an explicit 729 /// pipeline. 730 TextualPipeline pipeline; 731 }; 732 } // namespace 733 734 namespace llvm { 735 namespace cl { 736 /// Define a valid OptionValue for the command line pass argument. 
737 template <> 738 struct OptionValue<PassArgData> final 739 : OptionValueBase<PassArgData, /*isClass=*/true> { 740 OptionValue(const PassArgData &value) { this->setValue(value); } 741 OptionValue() = default; 742 void anchor() override {} 743 744 bool hasValue() const { return true; } 745 const PassArgData &getValue() const { return value; } 746 void setValue(const PassArgData &value) { this->value = value; } 747 748 PassArgData value; 749 }; 750 } // namespace cl 751 } // namespace llvm 752 753 namespace { 754 755 /// The name for the command line option used for parsing the textual pass 756 /// pipeline. 757 static constexpr StringLiteral passPipelineArg = "pass-pipeline"; 758 759 /// Adds command line option for each registered pass or pass pipeline, as well 760 /// as textual pass pipelines. 761 struct PassNameParser : public llvm::cl::parser<PassArgData> { 762 PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {} 763 764 void initialize(); 765 void printOptionInfo(const llvm::cl::Option &opt, 766 size_t globalWidth) const override; 767 size_t getOptionWidth(const llvm::cl::Option &opt) const override; 768 bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg, 769 PassArgData &value); 770 771 /// If true, this parser only parses entries that correspond to a concrete 772 /// pass registry entry, and does not add a `pass-pipeline` argument, does not 773 /// include the options for pass entries, and does not include pass pipelines 774 /// entries. 775 bool passNamesOnly = false; 776 }; 777 } // namespace 778 779 void PassNameParser::initialize() { 780 llvm::cl::parser<PassArgData>::initialize(); 781 782 /// Add an entry for the textual pass pipeline option. 783 if (!passNamesOnly) { 784 addLiteralOption(passPipelineArg, PassArgData(), 785 "A textual description of a pass pipeline to run"); 786 } 787 788 /// Add the pass entries. 
789 for (const auto &kv : *passRegistry) { 790 addLiteralOption(kv.second.getPassArgument(), &kv.second, 791 kv.second.getPassDescription()); 792 } 793 /// Add the pass pipeline entries. 794 if (!passNamesOnly) { 795 for (const auto &kv : *passPipelineRegistry) { 796 addLiteralOption(kv.second.getPassArgument(), &kv.second, 797 kv.second.getPassDescription()); 798 } 799 } 800 } 801 802 void PassNameParser::printOptionInfo(const llvm::cl::Option &opt, 803 size_t globalWidth) const { 804 // If this parser is just parsing pass names, print a simplified option 805 // string. 806 if (passNamesOnly) { 807 llvm::outs() << " --" << opt.ArgStr << "=<pass-arg>"; 808 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18); 809 return; 810 } 811 812 // Print the information for the top-level option. 813 if (opt.hasArgStr()) { 814 llvm::outs() << " --" << opt.ArgStr; 815 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7); 816 } else { 817 llvm::outs() << " " << opt.HelpStr << '\n'; 818 } 819 820 // Print the top-level pipeline argument. 821 printOptionHelp(passPipelineArg, 822 "A textual description of a pass pipeline to run", 823 /*indent=*/4, globalWidth, /*isTopLevel=*/!opt.hasArgStr()); 824 825 // Functor used to print the ordered entries of a registration map. 826 auto printOrderedEntries = [&](StringRef header, auto &map) { 827 llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries; 828 for (auto &kv : map) 829 orderedEntries.push_back(&kv.second); 830 llvm::array_pod_sort( 831 orderedEntries.begin(), orderedEntries.end(), 832 [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) { 833 return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument()); 834 }); 835 836 llvm::outs().indent(4) << header << ":\n"; 837 for (PassRegistryEntry *entry : orderedEntries) 838 entry->printHelpStr(/*indent=*/6, globalWidth); 839 }; 840 841 // Print the available passes. 
842 printOrderedEntries("Passes", *passRegistry); 843 844 // Print the available pass pipelines. 845 if (!passPipelineRegistry->empty()) 846 printOrderedEntries("Pass Pipelines", *passPipelineRegistry); 847 } 848 849 size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const { 850 size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(opt) + 2; 851 852 // Check for any wider pass or pipeline options. 853 for (auto &entry : *passRegistry) 854 maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4); 855 for (auto &entry : *passPipelineRegistry) 856 maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4); 857 return maxWidth; 858 } 859 860 bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName, 861 StringRef arg, PassArgData &value) { 862 // Handle the pipeline option explicitly. 863 if (argName == passPipelineArg) 864 return failed(value.pipeline.initialize(arg, llvm::errs())); 865 866 // Otherwise, default to the base for handling. 867 if (llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value)) 868 return true; 869 value.options = arg; 870 return false; 871 } 872 873 //===----------------------------------------------------------------------===// 874 // PassPipelineCLParser 875 //===----------------------------------------------------------------------===// 876 877 namespace mlir { 878 namespace detail { 879 struct PassPipelineCLParserImpl { 880 PassPipelineCLParserImpl(StringRef arg, StringRef description, 881 bool passNamesOnly) 882 : passList(arg, llvm::cl::desc(description)) { 883 passList.getParser().passNamesOnly = passNamesOnly; 884 passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional); 885 } 886 887 /// Returns true if the given pass registry entry was registered at the 888 /// top-level of the parser, i.e. not within an explicit textual pipeline. 
889 bool contains(const PassRegistryEntry *entry) const { 890 return llvm::any_of(passList, [&](const PassArgData &data) { 891 return data.registryEntry == entry; 892 }); 893 } 894 895 /// The set of passes and pass pipelines to run. 896 llvm::cl::list<PassArgData, bool, PassNameParser> passList; 897 }; 898 } // namespace detail 899 } // namespace mlir 900 901 /// Construct a pass pipeline parser with the given command line description. 902 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description) 903 : impl(std::make_unique<detail::PassPipelineCLParserImpl>( 904 arg, description, /*passNamesOnly=*/false)) {} 905 PassPipelineCLParser::~PassPipelineCLParser() = default; 906 907 /// Returns true if this parser contains any valid options to add. 908 bool PassPipelineCLParser::hasAnyOccurrences() const { 909 return impl->passList.getNumOccurrences() != 0; 910 } 911 912 /// Returns true if the given pass registry entry was registered at the 913 /// top-level of the parser, i.e. not within an explicit textual pipeline. 914 bool PassPipelineCLParser::contains(const PassRegistryEntry *entry) const { 915 return impl->contains(entry); 916 } 917 918 /// Adds the passes defined by this parser entry to the given pass manager. 
919 LogicalResult PassPipelineCLParser::addToPipeline( 920 OpPassManager &pm, 921 function_ref<LogicalResult(const Twine &)> errorHandler) const { 922 for (auto &passIt : impl->passList) { 923 if (passIt.registryEntry) { 924 if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options, 925 errorHandler))) 926 return failure(); 927 } else { 928 OpPassManager::Nesting nesting = pm.getNesting(); 929 pm.setNesting(OpPassManager::Nesting::Explicit); 930 LogicalResult status = passIt.pipeline.addToPipeline(pm, errorHandler); 931 pm.setNesting(nesting); 932 if (failed(status)) 933 return failure(); 934 } 935 } 936 return success(); 937 } 938 939 //===----------------------------------------------------------------------===// 940 // PassNameCLParser 941 942 /// Construct a pass pipeline parser with the given command line description. 943 PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description) 944 : impl(std::make_unique<detail::PassPipelineCLParserImpl>( 945 arg, description, /*passNamesOnly=*/true)) { 946 impl->passList.setMiscFlag(llvm::cl::CommaSeparated); 947 } 948 PassNameCLParser::~PassNameCLParser() = default; 949 950 /// Returns true if this parser contains any valid options to add. 951 bool PassNameCLParser::hasAnyOccurrences() const { 952 return impl->passList.getNumOccurrences() != 0; 953 } 954 955 /// Returns true if the given pass registry entry was registered at the 956 /// top-level of the parser, i.e. not within an explicit textual pipeline. 957 bool PassNameCLParser::contains(const PassRegistryEntry *entry) const { 958 return impl->contains(entry); 959 } 960