1 //===- PassCrashRecovery.cpp - Pass Crash Recovery Implementation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/IR/Diagnostics.h" 11 #include "mlir/IR/Dialect.h" 12 #include "mlir/IR/SymbolTable.h" 13 #include "mlir/IR/Verifier.h" 14 #include "mlir/Parser/Parser.h" 15 #include "mlir/Pass/Pass.h" 16 #include "mlir/Support/FileUtilities.h" 17 #include "llvm/ADT/STLExtras.h" 18 #include "llvm/ADT/ScopeExit.h" 19 #include "llvm/ADT/SetVector.h" 20 #include "llvm/Support/CommandLine.h" 21 #include "llvm/Support/CrashRecoveryContext.h" 22 #include "llvm/Support/Mutex.h" 23 #include "llvm/Support/Signals.h" 24 #include "llvm/Support/Threading.h" 25 #include "llvm/Support/ToolOutputFile.h" 26 27 using namespace mlir; 28 using namespace mlir::detail; 29 30 //===----------------------------------------------------------------------===// 31 // RecoveryReproducerContext 32 //===----------------------------------------------------------------------===// 33 34 namespace mlir { 35 namespace detail { 36 /// This class contains all of the context for generating a recovery reproducer. 37 /// Each recovery context is registered globally to allow for generating 38 /// reproducers when a signal is raised, such as a segfault. 39 struct RecoveryReproducerContext { 40 RecoveryReproducerContext(std::string passPipelineStr, Operation *op, 41 PassManager::ReproducerStreamFactory &streamFactory, 42 bool verifyPasses); 43 ~RecoveryReproducerContext(); 44 45 /// Generate a reproducer with the current context. 46 void generate(std::string &description); 47 48 /// Disable this reproducer context. This prevents the context from generating 49 /// a reproducer in the result of a crash. 50 void disable(); 51 52 /// Enable a previously disabled reproducer context. 53 void enable(); 54 55 private: 56 /// This function is invoked in the event of a crash. 57 static void crashHandler(void *); 58 59 /// Register a signal handler to run in the event of a crash. 60 static void registerSignalHandler(); 61 62 /// The textual description of the currently executing pipeline. 63 std::string pipeline; 64 65 /// The MLIR operation representing the IR before the crash. 66 Operation *preCrashOperation; 67 68 /// The factory for the reproducer output stream to use when generating the 69 /// reproducer. 70 PassManager::ReproducerStreamFactory &streamFactory; 71 72 /// Various pass manager and context flags. 73 bool disableThreads; 74 bool verifyPasses; 75 76 /// The current set of active reproducer contexts. This is used in the event 77 /// of a crash. This is not thread_local as the pass manager may produce any 78 /// number of child threads. This uses a set to allow for multiple MLIR pass 79 /// managers to be running at the same time. 80 static llvm::ManagedStatic<llvm::sys::SmartMutex<true>> reproducerMutex; 81 static llvm::ManagedStatic< 82 llvm::SmallSetVector<RecoveryReproducerContext *, 1>> 83 reproducerSet; 84 }; 85 } // namespace detail 86 } // namespace mlir 87 88 llvm::ManagedStatic<llvm::sys::SmartMutex<true>> 89 RecoveryReproducerContext::reproducerMutex; 90 llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>> 91 RecoveryReproducerContext::reproducerSet; 92 93 RecoveryReproducerContext::RecoveryReproducerContext( 94 std::string passPipelineStr, Operation *op, 95 PassManager::ReproducerStreamFactory &streamFactory, bool verifyPasses) 96 : pipeline(std::move(passPipelineStr)), preCrashOperation(op->clone()), 97 streamFactory(streamFactory), 98 disableThreads(!op->getContext()->isMultithreadingEnabled()), 99 verifyPasses(verifyPasses) { 100 enable(); 101 } 102 103 RecoveryReproducerContext::~RecoveryReproducerContext() { 104 // Erase the cloned preCrash IR that we cached. 105 preCrashOperation->erase(); 106 disable(); 107 } 108 109 void RecoveryReproducerContext::generate(std::string &description) { 110 llvm::raw_string_ostream descOS(description); 111 112 // Try to create a new output stream for this crash reproducer. 113 std::string error; 114 std::unique_ptr<PassManager::ReproducerStream> stream = streamFactory(error); 115 if (!stream) { 116 descOS << "failed to create output stream: " << error; 117 return; 118 } 119 descOS << "reproducer generated at `" << stream->description() << "`"; 120 121 AsmState state(preCrashOperation); 122 state.attachResourcePrinter( 123 "mlir_reproducer", [&](Operation *op, AsmResourceBuilder &builder) { 124 builder.buildString("pipeline", pipeline); 125 builder.buildBool("disable_threading", disableThreads); 126 builder.buildBool("verify_each", verifyPasses); 127 }); 128 129 // Output the .mlir module. 130 preCrashOperation->print(stream->os(), state); 131 } 132 133 void RecoveryReproducerContext::disable() { 134 llvm::sys::SmartScopedLock<true> lock(*reproducerMutex); 135 reproducerSet->remove(this); 136 if (reproducerSet->empty()) 137 llvm::CrashRecoveryContext::Disable(); 138 } 139 140 void RecoveryReproducerContext::enable() { 141 llvm::sys::SmartScopedLock<true> lock(*reproducerMutex); 142 if (reproducerSet->empty()) 143 llvm::CrashRecoveryContext::Enable(); 144 registerSignalHandler(); 145 reproducerSet->insert(this); 146 } 147 148 void RecoveryReproducerContext::crashHandler(void *) { 149 // Walk the current stack of contexts and generate a reproducer for each one. 150 // We can't know for certain which one was the cause, so we need to generate 151 // a reproducer for all of them. 152 for (RecoveryReproducerContext *context : *reproducerSet) { 153 std::string description; 154 context->generate(description); 155 156 // Emit an error using information only available within the context. 157 emitError(context->preCrashOperation->getLoc()) 158 << "A failure has been detected while processing the MLIR module:" 159 << description; 160 } 161 } 162 163 void RecoveryReproducerContext::registerSignalHandler() { 164 // Ensure that the handler is only registered once. 165 static bool registered = 166 (llvm::sys::AddSignalHandler(crashHandler, nullptr), false); 167 (void)registered; 168 } 169 170 //===----------------------------------------------------------------------===// 171 // PassCrashReproducerGenerator 172 //===----------------------------------------------------------------------===// 173 174 struct PassCrashReproducerGenerator::Impl { 175 Impl(PassManager::ReproducerStreamFactory &streamFactory, 176 bool localReproducer) 177 : streamFactory(streamFactory), localReproducer(localReproducer) {} 178 179 /// The factory to use when generating a crash reproducer. 180 PassManager::ReproducerStreamFactory streamFactory; 181 182 /// Flag indicating if reproducer generation should be localized to the 183 /// failing pass. 184 bool localReproducer = false; 185 186 /// A record of all of the currently active reproducer contexts. 187 SmallVector<std::unique_ptr<RecoveryReproducerContext>> activeContexts; 188 189 /// The set of all currently running passes. Note: This is not populated when 190 /// `localReproducer` is true, as each pass will get its own recovery context. 191 SetVector<std::pair<Pass *, Operation *>> runningPasses; 192 193 /// Various pass manager flags that get emitted when generating a reproducer. 194 bool pmFlagVerifyPasses = false; 195 }; 196 197 PassCrashReproducerGenerator::PassCrashReproducerGenerator( 198 PassManager::ReproducerStreamFactory &streamFactory, bool localReproducer) 199 : impl(std::make_unique<Impl>(streamFactory, localReproducer)) {} 200 PassCrashReproducerGenerator::~PassCrashReproducerGenerator() = default; 201 202 void PassCrashReproducerGenerator::initialize( 203 iterator_range<PassManager::pass_iterator> passes, Operation *op, 204 bool pmFlagVerifyPasses) { 205 assert((!impl->localReproducer || 206 !op->getContext()->isMultithreadingEnabled()) && 207 "expected multi-threading to be disabled when generating a local " 208 "reproducer"); 209 210 llvm::CrashRecoveryContext::Enable(); 211 impl->pmFlagVerifyPasses = pmFlagVerifyPasses; 212 213 // If we aren't generating a local reproducer, prepare a reproducer for the 214 // given top-level operation. 215 if (!impl->localReproducer) 216 prepareReproducerFor(passes, op); 217 } 218 219 static void 220 formatPassOpReproducerMessage(Diagnostic &os, 221 std::pair<Pass *, Operation *> passOpPair) { 222 os << "`" << passOpPair.first->getName() << "` on " 223 << "'" << passOpPair.second->getName() << "' operation"; 224 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(passOpPair.second)) 225 os << ": @" << symbol.getName(); 226 } 227 228 void PassCrashReproducerGenerator::finalize(Operation *rootOp, 229 LogicalResult executionResult) { 230 // Don't generate a reproducer if we have no active contexts. 231 if (impl->activeContexts.empty()) 232 return; 233 234 // If the pass manager execution succeeded, we don't generate any reproducers. 235 if (succeeded(executionResult)) 236 return impl->activeContexts.clear(); 237 238 InFlightDiagnostic diag = emitError(rootOp->getLoc()) 239 << "Failures have been detected while " 240 "processing an MLIR pass pipeline"; 241 242 // If we are generating a global reproducer, we include all of the running 243 // passes in the error message for the only active context. 244 if (!impl->localReproducer) { 245 assert(impl->activeContexts.size() == 1 && "expected one active context"); 246 247 // Generate the reproducer. 248 std::string description; 249 impl->activeContexts.front()->generate(description); 250 251 // Emit an error to the user. 252 Diagnostic ¬e = diag.attachNote() << "Pipeline failed while executing ["; 253 llvm::interleaveComma(impl->runningPasses, note, 254 [&](const std::pair<Pass *, Operation *> &value) { 255 formatPassOpReproducerMessage(note, value); 256 }); 257 note << "]: " << description; 258 return; 259 } 260 261 // If we were generating a local reproducer, we generate a reproducer for the 262 // most recently executing pass using the matching entry from `runningPasses` 263 // to generate a localized diagnostic message. 264 assert(impl->activeContexts.size() == impl->runningPasses.size() && 265 "expected running passes to match active contexts"); 266 267 // Generate the reproducer. 268 RecoveryReproducerContext &reproducerContext = *impl->activeContexts.back(); 269 std::string description; 270 reproducerContext.generate(description); 271 272 // Emit an error to the user. 273 Diagnostic ¬e = diag.attachNote() << "Pipeline failed while executing "; 274 formatPassOpReproducerMessage(note, impl->runningPasses.back()); 275 note << ": " << description; 276 277 impl->activeContexts.clear(); 278 } 279 280 void PassCrashReproducerGenerator::prepareReproducerFor(Pass *pass, 281 Operation *op) { 282 // If not tracking local reproducers, we simply remember that this pass is 283 // running. 284 impl->runningPasses.insert(std::make_pair(pass, op)); 285 if (!impl->localReproducer) 286 return; 287 288 // Disable the current pass recovery context, if there is one. This may happen 289 // in the case of dynamic pass pipelines. 290 if (!impl->activeContexts.empty()) 291 impl->activeContexts.back()->disable(); 292 293 // Collect all of the parent scopes of this operation. 294 SmallVector<OperationName> scopes; 295 while (Operation *parentOp = op->getParentOp()) { 296 scopes.push_back(op->getName()); 297 op = parentOp; 298 } 299 300 // Emit a pass pipeline string for the current pass running on the current 301 // operation type. 302 std::string passStr; 303 llvm::raw_string_ostream passOS(passStr); 304 for (OperationName scope : llvm::reverse(scopes)) 305 passOS << scope << "("; 306 pass->printAsTextualPipeline(passOS); 307 for (unsigned i = 0, e = scopes.size(); i < e; ++i) 308 passOS << ")"; 309 310 impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>( 311 passOS.str(), op, impl->streamFactory, impl->pmFlagVerifyPasses)); 312 } 313 void PassCrashReproducerGenerator::prepareReproducerFor( 314 iterator_range<PassManager::pass_iterator> passes, Operation *op) { 315 std::string passStr; 316 llvm::raw_string_ostream passOS(passStr); 317 llvm::interleaveComma( 318 passes, passOS, [&](Pass &pass) { pass.printAsTextualPipeline(passOS); }); 319 320 impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>( 321 passOS.str(), op, impl->streamFactory, impl->pmFlagVerifyPasses)); 322 } 323 324 void PassCrashReproducerGenerator::removeLastReproducerFor(Pass *pass, 325 Operation *op) { 326 // We only pop the active context if we are tracking local reproducers. 327 impl->runningPasses.remove(std::make_pair(pass, op)); 328 if (impl->localReproducer) { 329 impl->activeContexts.pop_back(); 330 331 // Re-enable the previous pass recovery context, if there was one. This may 332 // happen in the case of dynamic pass pipelines. 333 if (!impl->activeContexts.empty()) 334 impl->activeContexts.back()->enable(); 335 } 336 } 337 338 //===----------------------------------------------------------------------===// 339 // CrashReproducerInstrumentation 340 //===----------------------------------------------------------------------===// 341 342 namespace { 343 struct CrashReproducerInstrumentation : public PassInstrumentation { 344 CrashReproducerInstrumentation(PassCrashReproducerGenerator &generator) 345 : generator(generator) {} 346 ~CrashReproducerInstrumentation() override = default; 347 348 void runBeforePass(Pass *pass, Operation *op) override { 349 if (!isa<OpToOpPassAdaptor>(pass)) 350 generator.prepareReproducerFor(pass, op); 351 } 352 353 void runAfterPass(Pass *pass, Operation *op) override { 354 if (!isa<OpToOpPassAdaptor>(pass)) 355 generator.removeLastReproducerFor(pass, op); 356 } 357 358 void runAfterPassFailed(Pass *pass, Operation *op) override { 359 generator.finalize(op, /*executionResult=*/failure()); 360 } 361 362 private: 363 /// The generator used to create crash reproducers. 364 PassCrashReproducerGenerator &generator; 365 }; 366 } // namespace 367 368 //===----------------------------------------------------------------------===// 369 // FileReproducerStream 370 //===----------------------------------------------------------------------===// 371 372 namespace { 373 /// This class represents a default instance of PassManager::ReproducerStream 374 /// that is backed by a file. 375 struct FileReproducerStream : public PassManager::ReproducerStream { 376 FileReproducerStream(std::unique_ptr<llvm::ToolOutputFile> outputFile) 377 : outputFile(std::move(outputFile)) {} 378 ~FileReproducerStream() override { outputFile->keep(); } 379 380 /// Returns a description of the reproducer stream. 381 StringRef description() override { return outputFile->getFilename(); } 382 383 /// Returns the stream on which to output the reproducer. 384 raw_ostream &os() override { return outputFile->os(); } 385 386 private: 387 /// ToolOutputFile corresponding to opened `filename`. 388 std::unique_ptr<llvm::ToolOutputFile> outputFile = nullptr; 389 }; 390 } // namespace 391 392 //===----------------------------------------------------------------------===// 393 // PassManager 394 //===----------------------------------------------------------------------===// 395 396 LogicalResult PassManager::runWithCrashRecovery(Operation *op, 397 AnalysisManager am) { 398 crashReproGenerator->initialize(getPasses(), op, verifyPasses); 399 400 // Safely invoke the passes within a recovery context. 401 LogicalResult passManagerResult = failure(); 402 llvm::CrashRecoveryContext recoveryContext; 403 recoveryContext.RunSafelyOnThread( 404 [&] { passManagerResult = runPasses(op, am); }); 405 crashReproGenerator->finalize(op, passManagerResult); 406 return passManagerResult; 407 } 408 409 void PassManager::enableCrashReproducerGeneration(StringRef outputFile, 410 bool genLocalReproducer) { 411 // Capture the filename by value in case outputFile is out of scope when 412 // invoked. 413 std::string filename = outputFile.str(); 414 enableCrashReproducerGeneration( 415 [filename](std::string &error) -> std::unique_ptr<ReproducerStream> { 416 std::unique_ptr<llvm::ToolOutputFile> outputFile = 417 mlir::openOutputFile(filename, &error); 418 if (!outputFile) { 419 error = "Failed to create reproducer stream: " + error; 420 return nullptr; 421 } 422 return std::make_unique<FileReproducerStream>(std::move(outputFile)); 423 }, 424 genLocalReproducer); 425 } 426 427 void PassManager::enableCrashReproducerGeneration( 428 ReproducerStreamFactory factory, bool genLocalReproducer) { 429 assert(!crashReproGenerator && 430 "crash reproducer has already been initialized"); 431 if (genLocalReproducer && getContext()->isMultithreadingEnabled()) 432 llvm::report_fatal_error( 433 "Local crash reproduction can't be setup on a " 434 "pass-manager without disabling multi-threading first."); 435 436 crashReproGenerator = std::make_unique<PassCrashReproducerGenerator>( 437 factory, genLocalReproducer); 438 addInstrumentation( 439 std::make_unique<CrashReproducerInstrumentation>(*crashReproGenerator)); 440 } 441 442 //===----------------------------------------------------------------------===// 443 // Asm Resource 444 //===----------------------------------------------------------------------===// 445 446 void mlir::attachPassReproducerAsmResource(ParserConfig &config, 447 PassManager &pm, 448 bool &enableThreading) { 449 auto parseFn = [&](AsmParsedResourceEntry &entry) -> LogicalResult { 450 if (entry.getKey() == "pipeline") { 451 FailureOr<std::string> pipeline = entry.parseAsString(); 452 if (failed(pipeline)) 453 return failure(); 454 return parsePassPipeline(*pipeline, pm); 455 } 456 if (entry.getKey() == "disable_threading") { 457 FailureOr<bool> value = entry.parseAsBool(); 458 459 // FIXME: We should just update the context directly, but some places 460 // force disable threading during parsing. 461 if (succeeded(value)) 462 enableThreading = !(*value); 463 return value; 464 } 465 if (entry.getKey() == "verify_each") { 466 FailureOr<bool> value = entry.parseAsBool(); 467 if (succeeded(value)) 468 pm.enableVerifier(*value); 469 return value; 470 } 471 return entry.emitError() << "unknown 'mlir_reproducer' resource key '" 472 << entry.getKey() << "'"; 473 }; 474 config.attachResourceParser("mlir_reproducer", parseFn); 475 } 476