1 //===- PassCrashRecovery.cpp - Pass Crash Recovery Implementation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/IR/Diagnostics.h" 11 #include "mlir/IR/Dialect.h" 12 #include "mlir/IR/Verifier.h" 13 #include "mlir/Pass/Pass.h" 14 #include "mlir/Support/FileUtilities.h" 15 #include "llvm/ADT/STLExtras.h" 16 #include "llvm/ADT/ScopeExit.h" 17 #include "llvm/ADT/SetVector.h" 18 #include "llvm/Support/CommandLine.h" 19 #include "llvm/Support/CrashRecoveryContext.h" 20 #include "llvm/Support/Mutex.h" 21 #include "llvm/Support/Parallel.h" 22 #include "llvm/Support/Signals.h" 23 #include "llvm/Support/Threading.h" 24 #include "llvm/Support/ToolOutputFile.h" 25 26 using namespace mlir; 27 using namespace mlir::detail; 28 29 //===----------------------------------------------------------------------===// 30 // RecoveryReproducerContext 31 //===----------------------------------------------------------------------===// 32 33 namespace mlir { 34 namespace detail { 35 /// This class contains all of the context for generating a recovery reproducer. 36 /// Each recovery context is registered globally to allow for generating 37 /// reproducers when a signal is raised, such as a segfault. 38 struct RecoveryReproducerContext { 39 RecoveryReproducerContext(std::string passPipelineStr, Operation *op, 40 PassManager::ReproducerStreamFactory &streamFactory, 41 bool verifyPasses); 42 ~RecoveryReproducerContext(); 43 44 /// Generate a reproducer with the current context. 45 void generate(std::string &description); 46 47 /// Disable this reproducer context. This prevents the context from generating 48 /// a reproducer in the result of a crash. 49 void disable(); 50 51 /// Enable a previously disabled reproducer context. 52 void enable(); 53 54 private: 55 /// This function is invoked in the event of a crash. 56 static void crashHandler(void *); 57 58 /// Register a signal handler to run in the event of a crash. 59 static void registerSignalHandler(); 60 61 /// The textual description of the currently executing pipeline. 62 std::string pipeline; 63 64 /// The MLIR operation representing the IR before the crash. 65 Operation *preCrashOperation; 66 67 /// The factory for the reproducer output stream to use when generating the 68 /// reproducer. 69 PassManager::ReproducerStreamFactory &streamFactory; 70 71 /// Various pass manager and context flags. 72 bool disableThreads; 73 bool verifyPasses; 74 75 /// The current set of active reproducer contexts. This is used in the event 76 /// of a crash. This is not thread_local as the pass manager may produce any 77 /// number of child threads. This uses a set to allow for multiple MLIR pass 78 /// managers to be running at the same time. 79 static llvm::ManagedStatic<llvm::sys::SmartMutex<true>> reproducerMutex; 80 static llvm::ManagedStatic< 81 llvm::SmallSetVector<RecoveryReproducerContext *, 1>> 82 reproducerSet; 83 }; 84 } // namespace detail 85 } // namespace mlir 86 87 llvm::ManagedStatic<llvm::sys::SmartMutex<true>> 88 RecoveryReproducerContext::reproducerMutex; 89 llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>> 90 RecoveryReproducerContext::reproducerSet; 91 92 RecoveryReproducerContext::RecoveryReproducerContext( 93 std::string passPipelineStr, Operation *op, 94 PassManager::ReproducerStreamFactory &streamFactory, bool verifyPasses) 95 : pipeline(std::move(passPipelineStr)), preCrashOperation(op->clone()), 96 streamFactory(streamFactory), 97 disableThreads(!op->getContext()->isMultithreadingEnabled()), 98 verifyPasses(verifyPasses) { 99 enable(); 100 } 101 102 RecoveryReproducerContext::~RecoveryReproducerContext() { 103 // Erase the cloned preCrash IR that we cached. 104 preCrashOperation->erase(); 105 disable(); 106 } 107 108 void RecoveryReproducerContext::generate(std::string &description) { 109 llvm::raw_string_ostream descOS(description); 110 111 // Try to create a new output stream for this crash reproducer. 112 std::string error; 113 std::unique_ptr<PassManager::ReproducerStream> stream = streamFactory(error); 114 if (!stream) { 115 descOS << "failed to create output stream: " << error; 116 return; 117 } 118 descOS << "reproducer generated at `" << stream->description() << "`"; 119 120 // Output the current pass manager configuration to the crash stream. 121 auto &os = stream->os(); 122 os << "// configuration: -pass-pipeline='" << pipeline << "'"; 123 if (disableThreads) 124 os << " -mlir-disable-threading"; 125 if (verifyPasses) 126 os << " -verify-each"; 127 os << '\n'; 128 129 // Output the .mlir module. 130 preCrashOperation->print(os); 131 } 132 133 void RecoveryReproducerContext::disable() { 134 llvm::sys::SmartScopedLock<true> lock(*reproducerMutex); 135 reproducerSet->remove(this); 136 if (reproducerSet->empty()) 137 llvm::CrashRecoveryContext::Disable(); 138 } 139 140 void RecoveryReproducerContext::enable() { 141 llvm::sys::SmartScopedLock<true> lock(*reproducerMutex); 142 if (reproducerSet->empty()) 143 llvm::CrashRecoveryContext::Enable(); 144 registerSignalHandler(); 145 reproducerSet->insert(this); 146 } 147 148 void RecoveryReproducerContext::crashHandler(void *) { 149 // Walk the current stack of contexts and generate a reproducer for each one. 150 // We can't know for certain which one was the cause, so we need to generate 151 // a reproducer for all of them. 152 for (RecoveryReproducerContext *context : *reproducerSet) { 153 std::string description; 154 context->generate(description); 155 156 // Emit an error using information only available within the context. 157 context->preCrashOperation->getContext()->printOpOnDiagnostic(false); 158 context->preCrashOperation->emitError() 159 << "A failure has been detected while processing the MLIR module:" 160 << description; 161 } 162 } 163 164 void RecoveryReproducerContext::registerSignalHandler() { 165 // Ensure that the handler is only registered once. 166 static bool registered = 167 (llvm::sys::AddSignalHandler(crashHandler, nullptr), false); 168 (void)registered; 169 } 170 171 //===----------------------------------------------------------------------===// 172 // PassCrashReproducerGenerator 173 //===----------------------------------------------------------------------===// 174 175 struct PassCrashReproducerGenerator::Impl { 176 Impl(PassManager::ReproducerStreamFactory &streamFactory, 177 bool localReproducer) 178 : streamFactory(streamFactory), localReproducer(localReproducer) {} 179 180 /// The factory to use when generating a crash reproducer. 181 PassManager::ReproducerStreamFactory streamFactory; 182 183 /// Flag indicating if reproducer generation should be localized to the 184 /// failing pass. 185 bool localReproducer; 186 187 /// A record of all of the currently active reproducer contexts. 188 SmallVector<std::unique_ptr<RecoveryReproducerContext>> activeContexts; 189 190 /// The set of all currently running passes. Note: This is not populated when 191 /// `localReproducer` is true, as each pass will get its own recovery context. 192 SetVector<std::pair<Pass *, Operation *>> runningPasses; 193 194 /// Various pass manager flags that get emitted when generating a reproducer. 195 bool pmFlagVerifyPasses; 196 }; 197 198 PassCrashReproducerGenerator::PassCrashReproducerGenerator( 199 PassManager::ReproducerStreamFactory &streamFactory, bool localReproducer) 200 : impl(std::make_unique<Impl>(streamFactory, localReproducer)) {} 201 PassCrashReproducerGenerator::~PassCrashReproducerGenerator() {} 202 203 void PassCrashReproducerGenerator::initialize( 204 iterator_range<PassManager::pass_iterator> passes, Operation *op, 205 bool pmFlagVerifyPasses) { 206 assert((!impl->localReproducer || 207 !op->getContext()->isMultithreadingEnabled()) && 208 "expected multi-threading to be disabled when generating a local " 209 "reproducer"); 210 211 llvm::CrashRecoveryContext::Enable(); 212 impl->pmFlagVerifyPasses = pmFlagVerifyPasses; 213 214 // If we aren't generating a local reproducer, prepare a reproducer for the 215 // given top-level operation. 216 if (!impl->localReproducer) 217 prepareReproducerFor(passes, op); 218 } 219 220 static void 221 formatPassOpReproducerMessage(Diagnostic &os, 222 std::pair<Pass *, Operation *> passOpPair) { 223 os << "`" << passOpPair.first->getName() << "` on " 224 << "'" << passOpPair.second->getName() << "' operation"; 225 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(passOpPair.second)) 226 os << ": @" << symbol.getName(); 227 } 228 229 void PassCrashReproducerGenerator::finalize(Operation *rootOp, 230 LogicalResult executionResult) { 231 // If the pass manager execution succeeded, we don't generate any reproducers. 232 if (succeeded(executionResult)) 233 return impl->activeContexts.clear(); 234 235 MLIRContext *context = rootOp->getContext(); 236 bool shouldPrintOnOp = context->shouldPrintOpOnDiagnostic(); 237 context->printOpOnDiagnostic(false); 238 InFlightDiagnostic diag = rootOp->emitError() 239 << "Failures have been detected while " 240 "processing an MLIR pass pipeline"; 241 context->printOpOnDiagnostic(shouldPrintOnOp); 242 243 // If we are generating a global reproducer, we include all of the running 244 // passes in the error message for the only active context. 245 if (!impl->localReproducer) { 246 assert(impl->activeContexts.size() == 1 && "expected one active context"); 247 248 // Generate the reproducer. 249 std::string description; 250 impl->activeContexts.front()->generate(description); 251 252 // Emit an error to the user. 253 Diagnostic ¬e = diag.attachNote() << "Pipeline failed while executing ["; 254 llvm::interleaveComma(impl->runningPasses, note, 255 [&](const std::pair<Pass *, Operation *> &value) { 256 formatPassOpReproducerMessage(note, value); 257 }); 258 note << "]: " << description; 259 return; 260 } 261 262 // If we were generating a local reproducer, we generate a reproducer for the 263 // most recently executing pass using the matching entry from `runningPasses` 264 // to generate a localized diagnostic message. 265 assert(impl->activeContexts.size() == impl->runningPasses.size() && 266 "expected running passes to match active contexts"); 267 268 // Generate the reproducer. 269 RecoveryReproducerContext &reproducerContext = *impl->activeContexts.back(); 270 std::string description; 271 reproducerContext.generate(description); 272 273 // Emit an error to the user. 274 Diagnostic ¬e = diag.attachNote() << "Pipeline failed while executing "; 275 formatPassOpReproducerMessage(note, impl->runningPasses.back()); 276 note << ": " << description; 277 278 impl->activeContexts.clear(); 279 } 280 281 void PassCrashReproducerGenerator::prepareReproducerFor(Pass *pass, 282 Operation *op) { 283 // If not tracking local reproducers, we simply remember that this pass is 284 // running. 285 impl->runningPasses.insert(std::make_pair(pass, op)); 286 if (!impl->localReproducer) 287 return; 288 289 // Disable the current pass recovery context, if there is one. This may happen 290 // in the case of dynamic pass pipelines. 291 if (!impl->activeContexts.empty()) 292 impl->activeContexts.back()->disable(); 293 294 // Collect all of the parent scopes of this operation. 295 SmallVector<OperationName> scopes; 296 while (Operation *parentOp = op->getParentOp()) { 297 scopes.push_back(op->getName()); 298 op = parentOp; 299 } 300 301 // Emit a pass pipeline string for the current pass running on the current 302 // operation type. 303 std::string passStr; 304 llvm::raw_string_ostream passOS(passStr); 305 for (OperationName scope : llvm::reverse(scopes)) 306 passOS << scope << "("; 307 pass->printAsTextualPipeline(passOS); 308 for (unsigned i = 0, e = scopes.size(); i < e; ++i) 309 passOS << ")"; 310 311 impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>( 312 passOS.str(), op, impl->streamFactory, impl->pmFlagVerifyPasses)); 313 } 314 void PassCrashReproducerGenerator::prepareReproducerFor( 315 iterator_range<PassManager::pass_iterator> passes, Operation *op) { 316 std::string passStr; 317 llvm::raw_string_ostream passOS(passStr); 318 llvm::interleaveComma( 319 passes, passOS, [&](Pass &pass) { pass.printAsTextualPipeline(passOS); }); 320 321 impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>( 322 passOS.str(), op, impl->streamFactory, impl->pmFlagVerifyPasses)); 323 } 324 325 void PassCrashReproducerGenerator::removeLastReproducerFor(Pass *pass, 326 Operation *op) { 327 // We only pop the active context if we are tracking local reproducers. 328 impl->runningPasses.remove(std::make_pair(pass, op)); 329 if (impl->localReproducer) { 330 impl->activeContexts.pop_back(); 331 332 // Re-enable the previous pass recovery context, if there was one. This may 333 // happen in the case of dynamic pass pipelines. 334 if (!impl->activeContexts.empty()) 335 impl->activeContexts.back()->enable(); 336 } 337 } 338 339 //===----------------------------------------------------------------------===// 340 // CrashReproducerInstrumentation 341 //===----------------------------------------------------------------------===// 342 343 namespace { 344 struct CrashReproducerInstrumentation : public PassInstrumentation { 345 CrashReproducerInstrumentation(PassCrashReproducerGenerator &generator) 346 : generator(generator) {} 347 ~CrashReproducerInstrumentation() override = default; 348 349 /// A callback to run before a pass is executed. 350 void runBeforePass(Pass *pass, Operation *op) override { 351 if (!isa<OpToOpPassAdaptor>(pass)) 352 generator.prepareReproducerFor(pass, op); 353 } 354 355 /// A callback to run after a pass is successfully executed. This function 356 /// takes a pointer to the pass to be executed, as well as the current 357 /// operation being operated on. 358 void runAfterPass(Pass *pass, Operation *op) override { 359 if (!isa<OpToOpPassAdaptor>(pass)) 360 generator.removeLastReproducerFor(pass, op); 361 } 362 363 private: 364 /// The generator used to create crash reproducers. 365 PassCrashReproducerGenerator &generator; 366 }; 367 } // end anonymous namespace 368 369 //===----------------------------------------------------------------------===// 370 // FileReproducerStream 371 //===----------------------------------------------------------------------===// 372 373 namespace { 374 /// This class represents a default instance of PassManager::ReproducerStream 375 /// that is backed by a file. 376 struct FileReproducerStream : public PassManager::ReproducerStream { 377 FileReproducerStream(std::unique_ptr<llvm::ToolOutputFile> outputFile) 378 : outputFile(std::move(outputFile)) {} 379 ~FileReproducerStream() override { outputFile->keep(); } 380 381 /// Returns a description of the reproducer stream. 382 StringRef description() override { return outputFile->getFilename(); } 383 384 /// Returns the stream on which to output the reproducer. 385 raw_ostream &os() override { return outputFile->os(); } 386 387 private: 388 /// ToolOutputFile corresponding to opened `filename`. 389 std::unique_ptr<llvm::ToolOutputFile> outputFile = nullptr; 390 }; 391 } // end anonymous namespace 392 393 //===----------------------------------------------------------------------===// 394 // PassManager 395 //===----------------------------------------------------------------------===// 396 397 LogicalResult PassManager::runWithCrashRecovery(Operation *op, 398 AnalysisManager am) { 399 crashReproGenerator->initialize(getPasses(), op, verifyPasses); 400 401 // Safely invoke the passes within a recovery context. 402 LogicalResult passManagerResult = failure(); 403 llvm::CrashRecoveryContext recoveryContext; 404 recoveryContext.RunSafelyOnThread( 405 [&] { passManagerResult = runPasses(op, am); }); 406 crashReproGenerator->finalize(op, passManagerResult); 407 return passManagerResult; 408 } 409 410 void PassManager::enableCrashReproducerGeneration(StringRef outputFile, 411 bool genLocalReproducer) { 412 // Capture the filename by value in case outputFile is out of scope when 413 // invoked. 414 std::string filename = outputFile.str(); 415 enableCrashReproducerGeneration( 416 [filename](std::string &error) -> std::unique_ptr<ReproducerStream> { 417 std::unique_ptr<llvm::ToolOutputFile> outputFile = 418 mlir::openOutputFile(filename, &error); 419 if (!outputFile) { 420 error = "Failed to create reproducer stream: " + error; 421 return nullptr; 422 } 423 return std::make_unique<FileReproducerStream>(std::move(outputFile)); 424 }, 425 genLocalReproducer); 426 } 427 428 void PassManager::enableCrashReproducerGeneration( 429 ReproducerStreamFactory factory, bool genLocalReproducer) { 430 assert(!crashReproGenerator && 431 "crash reproducer has already been initialized"); 432 if (genLocalReproducer && getContext()->isMultithreadingEnabled()) 433 llvm::report_fatal_error( 434 "Local crash reproduction can't be setup on a " 435 "pass-manager without disabling multi-threading first."); 436 437 crashReproGenerator = std::make_unique<PassCrashReproducerGenerator>( 438 factory, genLocalReproducer); 439 addInstrumentation( 440 std::make_unique<CrashReproducerInstrumentation>(*crashReproGenerator)); 441 } 442