1 //===- PassCrashRecovery.cpp - Pass Crash Recovery Implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/IR/Diagnostics.h"
11 #include "mlir/IR/Dialect.h"
12 #include "mlir/IR/Verifier.h"
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Support/FileUtilities.h"
15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/ADT/ScopeExit.h"
17 #include "llvm/ADT/SetVector.h"
18 #include "llvm/Support/CommandLine.h"
19 #include "llvm/Support/CrashRecoveryContext.h"
20 #include "llvm/Support/Mutex.h"
21 #include "llvm/Support/Signals.h"
22 #include "llvm/Support/Threading.h"
23 #include "llvm/Support/ToolOutputFile.h"
24 
25 using namespace mlir;
26 using namespace mlir::detail;
27 
28 //===----------------------------------------------------------------------===//
29 // RecoveryReproducerContext
30 //===----------------------------------------------------------------------===//
31 
32 namespace mlir {
33 namespace detail {
34 /// This class contains all of the context for generating a recovery reproducer.
35 /// Each recovery context is registered globally to allow for generating
36 /// reproducers when a signal is raised, such as a segfault.
37 struct RecoveryReproducerContext {
38   RecoveryReproducerContext(std::string passPipelineStr, Operation *op,
39                             PassManager::ReproducerStreamFactory &streamFactory,
40                             bool verifyPasses);
41   ~RecoveryReproducerContext();
42 
43   /// Generate a reproducer with the current context.
44   void generate(std::string &description);
45 
46   /// Disable this reproducer context. This prevents the context from generating
47   /// a reproducer in the result of a crash.
48   void disable();
49 
50   /// Enable a previously disabled reproducer context.
51   void enable();
52 
53 private:
54   /// This function is invoked in the event of a crash.
55   static void crashHandler(void *);
56 
57   /// Register a signal handler to run in the event of a crash.
58   static void registerSignalHandler();
59 
60   /// The textual description of the currently executing pipeline.
61   std::string pipeline;
62 
63   /// The MLIR operation representing the IR before the crash.
64   Operation *preCrashOperation;
65 
66   /// The factory for the reproducer output stream to use when generating the
67   /// reproducer.
68   PassManager::ReproducerStreamFactory &streamFactory;
69 
70   /// Various pass manager and context flags.
71   bool disableThreads;
72   bool verifyPasses;
73 
74   /// The current set of active reproducer contexts. This is used in the event
75   /// of a crash. This is not thread_local as the pass manager may produce any
76   /// number of child threads. This uses a set to allow for multiple MLIR pass
77   /// managers to be running at the same time.
78   static llvm::ManagedStatic<llvm::sys::SmartMutex<true>> reproducerMutex;
79   static llvm::ManagedStatic<
80       llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
81       reproducerSet;
82 };
83 } // namespace detail
84 } // namespace mlir
85 
// Out-of-line definitions of the static members declared in
// RecoveryReproducerContext: the mutex taken by `enable`/`disable`, and the set
// of currently enabled contexts that `crashHandler` walks.
llvm::ManagedStatic<llvm::sys::SmartMutex<true>>
    RecoveryReproducerContext::reproducerMutex;
llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
    RecoveryReproducerContext::reproducerSet;
90 
91 RecoveryReproducerContext::RecoveryReproducerContext(
92     std::string passPipelineStr, Operation *op,
93     PassManager::ReproducerStreamFactory &streamFactory, bool verifyPasses)
94     : pipeline(std::move(passPipelineStr)), preCrashOperation(op->clone()),
95       streamFactory(streamFactory),
96       disableThreads(!op->getContext()->isMultithreadingEnabled()),
97       verifyPasses(verifyPasses) {
98   enable();
99 }
100 
101 RecoveryReproducerContext::~RecoveryReproducerContext() {
102   // Erase the cloned preCrash IR that we cached.
103   preCrashOperation->erase();
104   disable();
105 }
106 
107 void RecoveryReproducerContext::generate(std::string &description) {
108   llvm::raw_string_ostream descOS(description);
109 
110   // Try to create a new output stream for this crash reproducer.
111   std::string error;
112   std::unique_ptr<PassManager::ReproducerStream> stream = streamFactory(error);
113   if (!stream) {
114     descOS << "failed to create output stream: " << error;
115     return;
116   }
117   descOS << "reproducer generated at `" << stream->description() << "`";
118 
119   // Output the current pass manager configuration to the crash stream.
120   auto &os = stream->os();
121   os << "// configuration: -pass-pipeline='" << pipeline << "'";
122   if (disableThreads)
123     os << " -mlir-disable-threading";
124   if (verifyPasses)
125     os << " -verify-each";
126   os << '\n';
127 
128   // Output the .mlir module.
129   preCrashOperation->print(os);
130 }
131 
132 void RecoveryReproducerContext::disable() {
133   llvm::sys::SmartScopedLock<true> lock(*reproducerMutex);
134   reproducerSet->remove(this);
135   if (reproducerSet->empty())
136     llvm::CrashRecoveryContext::Disable();
137 }
138 
139 void RecoveryReproducerContext::enable() {
140   llvm::sys::SmartScopedLock<true> lock(*reproducerMutex);
141   if (reproducerSet->empty())
142     llvm::CrashRecoveryContext::Enable();
143   registerSignalHandler();
144   reproducerSet->insert(this);
145 }
146 
147 void RecoveryReproducerContext::crashHandler(void *) {
148   // Walk the current stack of contexts and generate a reproducer for each one.
149   // We can't know for certain which one was the cause, so we need to generate
150   // a reproducer for all of them.
151   for (RecoveryReproducerContext *context : *reproducerSet) {
152     std::string description;
153     context->generate(description);
154 
155     // Emit an error using information only available within the context.
156     emitError(context->preCrashOperation->getLoc())
157         << "A failure has been detected while processing the MLIR module:"
158         << description;
159   }
160 }
161 
162 void RecoveryReproducerContext::registerSignalHandler() {
163   // Ensure that the handler is only registered once.
164   static bool registered =
165       (llvm::sys::AddSignalHandler(crashHandler, nullptr), false);
166   (void)registered;
167 }
168 
169 //===----------------------------------------------------------------------===//
170 // PassCrashReproducerGenerator
171 //===----------------------------------------------------------------------===//
172 
173 struct PassCrashReproducerGenerator::Impl {
174   Impl(PassManager::ReproducerStreamFactory &streamFactory,
175        bool localReproducer)
176       : streamFactory(streamFactory), localReproducer(localReproducer) {}
177 
178   /// The factory to use when generating a crash reproducer.
179   PassManager::ReproducerStreamFactory streamFactory;
180 
181   /// Flag indicating if reproducer generation should be localized to the
182   /// failing pass.
183   bool localReproducer;
184 
185   /// A record of all of the currently active reproducer contexts.
186   SmallVector<std::unique_ptr<RecoveryReproducerContext>> activeContexts;
187 
188   /// The set of all currently running passes. Note: This is not populated when
189   /// `localReproducer` is true, as each pass will get its own recovery context.
190   SetVector<std::pair<Pass *, Operation *>> runningPasses;
191 
192   /// Various pass manager flags that get emitted when generating a reproducer.
193   bool pmFlagVerifyPasses;
194 };
195 
196 PassCrashReproducerGenerator::PassCrashReproducerGenerator(
197     PassManager::ReproducerStreamFactory &streamFactory, bool localReproducer)
198     : impl(std::make_unique<Impl>(streamFactory, localReproducer)) {}
199 PassCrashReproducerGenerator::~PassCrashReproducerGenerator() {}
200 
201 void PassCrashReproducerGenerator::initialize(
202     iterator_range<PassManager::pass_iterator> passes, Operation *op,
203     bool pmFlagVerifyPasses) {
204   assert((!impl->localReproducer ||
205           !op->getContext()->isMultithreadingEnabled()) &&
206          "expected multi-threading to be disabled when generating a local "
207          "reproducer");
208 
209   llvm::CrashRecoveryContext::Enable();
210   impl->pmFlagVerifyPasses = pmFlagVerifyPasses;
211 
212   // If we aren't generating a local reproducer, prepare a reproducer for the
213   // given top-level operation.
214   if (!impl->localReproducer)
215     prepareReproducerFor(passes, op);
216 }
217 
218 static void
219 formatPassOpReproducerMessage(Diagnostic &os,
220                               std::pair<Pass *, Operation *> passOpPair) {
221   os << "`" << passOpPair.first->getName() << "` on "
222      << "'" << passOpPair.second->getName() << "' operation";
223   if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(passOpPair.second))
224     os << ": @" << symbol.getName();
225 }
226 
227 void PassCrashReproducerGenerator::finalize(Operation *rootOp,
228                                             LogicalResult executionResult) {
229   // Don't generate a reproducer if we have no active contexts.
230   if (impl->activeContexts.empty())
231     return;
232 
233   // If the pass manager execution succeeded, we don't generate any reproducers.
234   if (succeeded(executionResult))
235     return impl->activeContexts.clear();
236 
237   InFlightDiagnostic diag = emitError(rootOp->getLoc())
238                             << "Failures have been detected while "
239                                "processing an MLIR pass pipeline";
240 
241   // If we are generating a global reproducer, we include all of the running
242   // passes in the error message for the only active context.
243   if (!impl->localReproducer) {
244     assert(impl->activeContexts.size() == 1 && "expected one active context");
245 
246     // Generate the reproducer.
247     std::string description;
248     impl->activeContexts.front()->generate(description);
249 
250     // Emit an error to the user.
251     Diagnostic &note = diag.attachNote() << "Pipeline failed while executing [";
252     llvm::interleaveComma(impl->runningPasses, note,
253                           [&](const std::pair<Pass *, Operation *> &value) {
254                             formatPassOpReproducerMessage(note, value);
255                           });
256     note << "]: " << description;
257     return;
258   }
259 
260   // If we were generating a local reproducer, we generate a reproducer for the
261   // most recently executing pass using the matching entry from  `runningPasses`
262   // to generate a localized diagnostic message.
263   assert(impl->activeContexts.size() == impl->runningPasses.size() &&
264          "expected running passes to match active contexts");
265 
266   // Generate the reproducer.
267   RecoveryReproducerContext &reproducerContext = *impl->activeContexts.back();
268   std::string description;
269   reproducerContext.generate(description);
270 
271   // Emit an error to the user.
272   Diagnostic &note = diag.attachNote() << "Pipeline failed while executing ";
273   formatPassOpReproducerMessage(note, impl->runningPasses.back());
274   note << ": " << description;
275 
276   impl->activeContexts.clear();
277 }
278 
279 void PassCrashReproducerGenerator::prepareReproducerFor(Pass *pass,
280                                                         Operation *op) {
281   // If not tracking local reproducers, we simply remember that this pass is
282   // running.
283   impl->runningPasses.insert(std::make_pair(pass, op));
284   if (!impl->localReproducer)
285     return;
286 
287   // Disable the current pass recovery context, if there is one. This may happen
288   // in the case of dynamic pass pipelines.
289   if (!impl->activeContexts.empty())
290     impl->activeContexts.back()->disable();
291 
292   // Collect all of the parent scopes of this operation.
293   SmallVector<OperationName> scopes;
294   while (Operation *parentOp = op->getParentOp()) {
295     scopes.push_back(op->getName());
296     op = parentOp;
297   }
298 
299   // Emit a pass pipeline string for the current pass running on the current
300   // operation type.
301   std::string passStr;
302   llvm::raw_string_ostream passOS(passStr);
303   for (OperationName scope : llvm::reverse(scopes))
304     passOS << scope << "(";
305   pass->printAsTextualPipeline(passOS);
306   for (unsigned i = 0, e = scopes.size(); i < e; ++i)
307     passOS << ")";
308 
309   impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>(
310       passOS.str(), op, impl->streamFactory, impl->pmFlagVerifyPasses));
311 }
312 void PassCrashReproducerGenerator::prepareReproducerFor(
313     iterator_range<PassManager::pass_iterator> passes, Operation *op) {
314   std::string passStr;
315   llvm::raw_string_ostream passOS(passStr);
316   llvm::interleaveComma(
317       passes, passOS, [&](Pass &pass) { pass.printAsTextualPipeline(passOS); });
318 
319   impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>(
320       passOS.str(), op, impl->streamFactory, impl->pmFlagVerifyPasses));
321 }
322 
323 void PassCrashReproducerGenerator::removeLastReproducerFor(Pass *pass,
324                                                            Operation *op) {
325   // We only pop the active context if we are tracking local reproducers.
326   impl->runningPasses.remove(std::make_pair(pass, op));
327   if (impl->localReproducer) {
328     impl->activeContexts.pop_back();
329 
330     // Re-enable the previous pass recovery context, if there was one. This may
331     // happen in the case of dynamic pass pipelines.
332     if (!impl->activeContexts.empty())
333       impl->activeContexts.back()->enable();
334   }
335 }
336 
337 //===----------------------------------------------------------------------===//
338 // CrashReproducerInstrumentation
339 //===----------------------------------------------------------------------===//
340 
341 namespace {
342 struct CrashReproducerInstrumentation : public PassInstrumentation {
343   CrashReproducerInstrumentation(PassCrashReproducerGenerator &generator)
344       : generator(generator) {}
345   ~CrashReproducerInstrumentation() override = default;
346 
347   void runBeforePass(Pass *pass, Operation *op) override {
348     if (!isa<OpToOpPassAdaptor>(pass))
349       generator.prepareReproducerFor(pass, op);
350   }
351 
352   void runAfterPass(Pass *pass, Operation *op) override {
353     if (!isa<OpToOpPassAdaptor>(pass))
354       generator.removeLastReproducerFor(pass, op);
355   }
356 
357   void runAfterPassFailed(Pass *pass, Operation *op) override {
358     generator.finalize(op, /*executionResult=*/failure());
359   }
360 
361 private:
362   /// The generator used to create crash reproducers.
363   PassCrashReproducerGenerator &generator;
364 };
365 } // end anonymous namespace
366 
367 //===----------------------------------------------------------------------===//
368 // FileReproducerStream
369 //===----------------------------------------------------------------------===//
370 
371 namespace {
372 /// This class represents a default instance of PassManager::ReproducerStream
373 /// that is backed by a file.
374 struct FileReproducerStream : public PassManager::ReproducerStream {
375   FileReproducerStream(std::unique_ptr<llvm::ToolOutputFile> outputFile)
376       : outputFile(std::move(outputFile)) {}
377   ~FileReproducerStream() override { outputFile->keep(); }
378 
379   /// Returns a description of the reproducer stream.
380   StringRef description() override { return outputFile->getFilename(); }
381 
382   /// Returns the stream on which to output the reproducer.
383   raw_ostream &os() override { return outputFile->os(); }
384 
385 private:
386   /// ToolOutputFile corresponding to opened `filename`.
387   std::unique_ptr<llvm::ToolOutputFile> outputFile = nullptr;
388 };
389 } // end anonymous namespace
390 
391 //===----------------------------------------------------------------------===//
392 // PassManager
393 //===----------------------------------------------------------------------===//
394 
395 LogicalResult PassManager::runWithCrashRecovery(Operation *op,
396                                                 AnalysisManager am) {
397   crashReproGenerator->initialize(getPasses(), op, verifyPasses);
398 
399   // Safely invoke the passes within a recovery context.
400   LogicalResult passManagerResult = failure();
401   llvm::CrashRecoveryContext recoveryContext;
402   recoveryContext.RunSafelyOnThread(
403       [&] { passManagerResult = runPasses(op, am); });
404   crashReproGenerator->finalize(op, passManagerResult);
405   return passManagerResult;
406 }
407 
408 void PassManager::enableCrashReproducerGeneration(StringRef outputFile,
409                                                   bool genLocalReproducer) {
410   // Capture the filename by value in case outputFile is out of scope when
411   // invoked.
412   std::string filename = outputFile.str();
413   enableCrashReproducerGeneration(
414       [filename](std::string &error) -> std::unique_ptr<ReproducerStream> {
415         std::unique_ptr<llvm::ToolOutputFile> outputFile =
416             mlir::openOutputFile(filename, &error);
417         if (!outputFile) {
418           error = "Failed to create reproducer stream: " + error;
419           return nullptr;
420         }
421         return std::make_unique<FileReproducerStream>(std::move(outputFile));
422       },
423       genLocalReproducer);
424 }
425 
426 void PassManager::enableCrashReproducerGeneration(
427     ReproducerStreamFactory factory, bool genLocalReproducer) {
428   assert(!crashReproGenerator &&
429          "crash reproducer has already been initialized");
430   if (genLocalReproducer && getContext()->isMultithreadingEnabled())
431     llvm::report_fatal_error(
432         "Local crash reproduction can't be setup on a "
433         "pass-manager without disabling multi-threading first.");
434 
435   crashReproGenerator = std::make_unique<PassCrashReproducerGenerator>(
436       factory, genLocalReproducer);
437   addInstrumentation(
438       std::make_unique<CrashReproducerInstrumentation>(*crashReproGenerator));
439 }
440