1 //===- PassCrashRecovery.cpp - Pass Crash Recovery Implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/IR/Diagnostics.h"
11 #include "mlir/IR/Dialect.h"
12 #include "mlir/IR/SymbolTable.h"
13 #include "mlir/IR/Verifier.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Support/FileUtilities.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/ScopeExit.h"
18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/Support/CommandLine.h"
20 #include "llvm/Support/CrashRecoveryContext.h"
21 #include "llvm/Support/Mutex.h"
22 #include "llvm/Support/Signals.h"
23 #include "llvm/Support/Threading.h"
24 #include "llvm/Support/ToolOutputFile.h"
25 
26 using namespace mlir;
27 using namespace mlir::detail;
28 
29 //===----------------------------------------------------------------------===//
30 // RecoveryReproducerContext
31 //===----------------------------------------------------------------------===//
32 
33 namespace mlir {
34 namespace detail {
35 /// This class contains all of the context for generating a recovery reproducer.
36 /// Each recovery context is registered globally to allow for generating
37 /// reproducers when a signal is raised, such as a segfault.
38 struct RecoveryReproducerContext {
39   RecoveryReproducerContext(std::string passPipelineStr, Operation *op,
40                             PassManager::ReproducerStreamFactory &streamFactory,
41                             bool verifyPasses);
42   ~RecoveryReproducerContext();
43 
44   /// Generate a reproducer with the current context.
45   void generate(std::string &description);
46 
47   /// Disable this reproducer context. This prevents the context from generating
48   /// a reproducer in the result of a crash.
49   void disable();
50 
51   /// Enable a previously disabled reproducer context.
52   void enable();
53 
54 private:
55   /// This function is invoked in the event of a crash.
56   static void crashHandler(void *);
57 
58   /// Register a signal handler to run in the event of a crash.
59   static void registerSignalHandler();
60 
61   /// The textual description of the currently executing pipeline.
62   std::string pipeline;
63 
64   /// The MLIR operation representing the IR before the crash.
65   Operation *preCrashOperation;
66 
67   /// The factory for the reproducer output stream to use when generating the
68   /// reproducer.
69   PassManager::ReproducerStreamFactory &streamFactory;
70 
71   /// Various pass manager and context flags.
72   bool disableThreads;
73   bool verifyPasses;
74 
75   /// The current set of active reproducer contexts. This is used in the event
76   /// of a crash. This is not thread_local as the pass manager may produce any
77   /// number of child threads. This uses a set to allow for multiple MLIR pass
78   /// managers to be running at the same time.
79   static llvm::ManagedStatic<llvm::sys::SmartMutex<true>> reproducerMutex;
80   static llvm::ManagedStatic<
81       llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
82       reproducerSet;
83 };
84 } // namespace detail
85 } // namespace mlir
86 
// Storage for the static members declared in RecoveryReproducerContext: the
// recursive mutex guarding the registry, and the registry of live contexts.
// Both are ManagedStatic, so they are constructed lazily on first access.
llvm::ManagedStatic<llvm::sys::SmartMutex<true>>
    RecoveryReproducerContext::reproducerMutex;
llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
    RecoveryReproducerContext::reproducerSet;
91 
92 RecoveryReproducerContext::RecoveryReproducerContext(
93     std::string passPipelineStr, Operation *op,
94     PassManager::ReproducerStreamFactory &streamFactory, bool verifyPasses)
95     : pipeline(std::move(passPipelineStr)), preCrashOperation(op->clone()),
96       streamFactory(streamFactory),
97       disableThreads(!op->getContext()->isMultithreadingEnabled()),
98       verifyPasses(verifyPasses) {
99   enable();
100 }
101 
102 RecoveryReproducerContext::~RecoveryReproducerContext() {
103   // Erase the cloned preCrash IR that we cached.
104   preCrashOperation->erase();
105   disable();
106 }
107 
108 void RecoveryReproducerContext::generate(std::string &description) {
109   llvm::raw_string_ostream descOS(description);
110 
111   // Try to create a new output stream for this crash reproducer.
112   std::string error;
113   std::unique_ptr<PassManager::ReproducerStream> stream = streamFactory(error);
114   if (!stream) {
115     descOS << "failed to create output stream: " << error;
116     return;
117   }
118   descOS << "reproducer generated at `" << stream->description() << "`";
119 
120   // Output the current pass manager configuration to the crash stream.
121   auto &os = stream->os();
122   os << "// configuration: -pass-pipeline='" << pipeline << "'";
123   if (disableThreads)
124     os << " -mlir-disable-threading";
125   if (verifyPasses)
126     os << " -verify-each";
127   os << '\n';
128 
129   // Output the .mlir module.
130   preCrashOperation->print(os);
131 }
132 
133 void RecoveryReproducerContext::disable() {
134   llvm::sys::SmartScopedLock<true> lock(*reproducerMutex);
135   reproducerSet->remove(this);
136   if (reproducerSet->empty())
137     llvm::CrashRecoveryContext::Disable();
138 }
139 
140 void RecoveryReproducerContext::enable() {
141   llvm::sys::SmartScopedLock<true> lock(*reproducerMutex);
142   if (reproducerSet->empty())
143     llvm::CrashRecoveryContext::Enable();
144   registerSignalHandler();
145   reproducerSet->insert(this);
146 }
147 
148 void RecoveryReproducerContext::crashHandler(void *) {
149   // Walk the current stack of contexts and generate a reproducer for each one.
150   // We can't know for certain which one was the cause, so we need to generate
151   // a reproducer for all of them.
152   for (RecoveryReproducerContext *context : *reproducerSet) {
153     std::string description;
154     context->generate(description);
155 
156     // Emit an error using information only available within the context.
157     emitError(context->preCrashOperation->getLoc())
158         << "A failure has been detected while processing the MLIR module:"
159         << description;
160   }
161 }
162 
163 void RecoveryReproducerContext::registerSignalHandler() {
164   // Ensure that the handler is only registered once.
165   static bool registered =
166       (llvm::sys::AddSignalHandler(crashHandler, nullptr), false);
167   (void)registered;
168 }
169 
170 //===----------------------------------------------------------------------===//
171 // PassCrashReproducerGenerator
172 //===----------------------------------------------------------------------===//
173 
/// Private state of a PassCrashReproducerGenerator.
struct PassCrashReproducerGenerator::Impl {
  Impl(PassManager::ReproducerStreamFactory &streamFactory,
       bool localReproducer)
      : streamFactory(streamFactory), localReproducer(localReproducer) {}

  /// The factory to use when generating a crash reproducer. This is a copy, so
  /// the contexts below may safely hold a reference to it.
  PassManager::ReproducerStreamFactory streamFactory;

  /// Flag indicating if reproducer generation should be localized to the
  /// failing pass.
  bool localReproducer = false;

  /// A record of all of the currently active reproducer contexts. In global
  /// mode this holds the single context for the whole pipeline. In local mode
  /// it is a stack with one context per running pass.
  SmallVector<std::unique_ptr<RecoveryReproducerContext>> activeContexts;

  /// The set of all currently running (pass, operation) pairs, in the order
  /// they started. This is populated in both modes. The global reproducer lists
  /// every entry in its diagnostic. The local reproducer expects it to mirror
  /// `activeContexts` one-to-one and reports the most recent entry.
  SetVector<std::pair<Pass *, Operation *>> runningPasses;

  /// Various pass manager flags that get emitted when generating a reproducer.
  bool pmFlagVerifyPasses = false;
};
196 
197 PassCrashReproducerGenerator::PassCrashReproducerGenerator(
198     PassManager::ReproducerStreamFactory &streamFactory, bool localReproducer)
199     : impl(std::make_unique<Impl>(streamFactory, localReproducer)) {}
200 PassCrashReproducerGenerator::~PassCrashReproducerGenerator() = default;
201 
202 void PassCrashReproducerGenerator::initialize(
203     iterator_range<PassManager::pass_iterator> passes, Operation *op,
204     bool pmFlagVerifyPasses) {
205   assert((!impl->localReproducer ||
206           !op->getContext()->isMultithreadingEnabled()) &&
207          "expected multi-threading to be disabled when generating a local "
208          "reproducer");
209 
210   llvm::CrashRecoveryContext::Enable();
211   impl->pmFlagVerifyPasses = pmFlagVerifyPasses;
212 
213   // If we aren't generating a local reproducer, prepare a reproducer for the
214   // given top-level operation.
215   if (!impl->localReproducer)
216     prepareReproducerFor(passes, op);
217 }
218 
219 static void
220 formatPassOpReproducerMessage(Diagnostic &os,
221                               std::pair<Pass *, Operation *> passOpPair) {
222   os << "`" << passOpPair.first->getName() << "` on "
223      << "'" << passOpPair.second->getName() << "' operation";
224   if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(passOpPair.second))
225     os << ": @" << symbol.getName();
226 }
227 
228 void PassCrashReproducerGenerator::finalize(Operation *rootOp,
229                                             LogicalResult executionResult) {
230   // Don't generate a reproducer if we have no active contexts.
231   if (impl->activeContexts.empty())
232     return;
233 
234   // If the pass manager execution succeeded, we don't generate any reproducers.
235   if (succeeded(executionResult))
236     return impl->activeContexts.clear();
237 
238   InFlightDiagnostic diag = emitError(rootOp->getLoc())
239                             << "Failures have been detected while "
240                                "processing an MLIR pass pipeline";
241 
242   // If we are generating a global reproducer, we include all of the running
243   // passes in the error message for the only active context.
244   if (!impl->localReproducer) {
245     assert(impl->activeContexts.size() == 1 && "expected one active context");
246 
247     // Generate the reproducer.
248     std::string description;
249     impl->activeContexts.front()->generate(description);
250 
251     // Emit an error to the user.
252     Diagnostic &note = diag.attachNote() << "Pipeline failed while executing [";
253     llvm::interleaveComma(impl->runningPasses, note,
254                           [&](const std::pair<Pass *, Operation *> &value) {
255                             formatPassOpReproducerMessage(note, value);
256                           });
257     note << "]: " << description;
258     return;
259   }
260 
261   // If we were generating a local reproducer, we generate a reproducer for the
262   // most recently executing pass using the matching entry from  `runningPasses`
263   // to generate a localized diagnostic message.
264   assert(impl->activeContexts.size() == impl->runningPasses.size() &&
265          "expected running passes to match active contexts");
266 
267   // Generate the reproducer.
268   RecoveryReproducerContext &reproducerContext = *impl->activeContexts.back();
269   std::string description;
270   reproducerContext.generate(description);
271 
272   // Emit an error to the user.
273   Diagnostic &note = diag.attachNote() << "Pipeline failed while executing ";
274   formatPassOpReproducerMessage(note, impl->runningPasses.back());
275   note << ": " << description;
276 
277   impl->activeContexts.clear();
278 }
279 
280 void PassCrashReproducerGenerator::prepareReproducerFor(Pass *pass,
281                                                         Operation *op) {
282   // If not tracking local reproducers, we simply remember that this pass is
283   // running.
284   impl->runningPasses.insert(std::make_pair(pass, op));
285   if (!impl->localReproducer)
286     return;
287 
288   // Disable the current pass recovery context, if there is one. This may happen
289   // in the case of dynamic pass pipelines.
290   if (!impl->activeContexts.empty())
291     impl->activeContexts.back()->disable();
292 
293   // Collect all of the parent scopes of this operation.
294   SmallVector<OperationName> scopes;
295   while (Operation *parentOp = op->getParentOp()) {
296     scopes.push_back(op->getName());
297     op = parentOp;
298   }
299 
300   // Emit a pass pipeline string for the current pass running on the current
301   // operation type.
302   std::string passStr;
303   llvm::raw_string_ostream passOS(passStr);
304   for (OperationName scope : llvm::reverse(scopes))
305     passOS << scope << "(";
306   pass->printAsTextualPipeline(passOS);
307   for (unsigned i = 0, e = scopes.size(); i < e; ++i)
308     passOS << ")";
309 
310   impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>(
311       passOS.str(), op, impl->streamFactory, impl->pmFlagVerifyPasses));
312 }
313 void PassCrashReproducerGenerator::prepareReproducerFor(
314     iterator_range<PassManager::pass_iterator> passes, Operation *op) {
315   std::string passStr;
316   llvm::raw_string_ostream passOS(passStr);
317   llvm::interleaveComma(
318       passes, passOS, [&](Pass &pass) { pass.printAsTextualPipeline(passOS); });
319 
320   impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>(
321       passOS.str(), op, impl->streamFactory, impl->pmFlagVerifyPasses));
322 }
323 
324 void PassCrashReproducerGenerator::removeLastReproducerFor(Pass *pass,
325                                                            Operation *op) {
326   // We only pop the active context if we are tracking local reproducers.
327   impl->runningPasses.remove(std::make_pair(pass, op));
328   if (impl->localReproducer) {
329     impl->activeContexts.pop_back();
330 
331     // Re-enable the previous pass recovery context, if there was one. This may
332     // happen in the case of dynamic pass pipelines.
333     if (!impl->activeContexts.empty())
334       impl->activeContexts.back()->enable();
335   }
336 }
337 
338 //===----------------------------------------------------------------------===//
339 // CrashReproducerInstrumentation
340 //===----------------------------------------------------------------------===//
341 
342 namespace {
343 struct CrashReproducerInstrumentation : public PassInstrumentation {
344   CrashReproducerInstrumentation(PassCrashReproducerGenerator &generator)
345       : generator(generator) {}
346   ~CrashReproducerInstrumentation() override = default;
347 
348   void runBeforePass(Pass *pass, Operation *op) override {
349     if (!isa<OpToOpPassAdaptor>(pass))
350       generator.prepareReproducerFor(pass, op);
351   }
352 
353   void runAfterPass(Pass *pass, Operation *op) override {
354     if (!isa<OpToOpPassAdaptor>(pass))
355       generator.removeLastReproducerFor(pass, op);
356   }
357 
358   void runAfterPassFailed(Pass *pass, Operation *op) override {
359     generator.finalize(op, /*executionResult=*/failure());
360   }
361 
362 private:
363   /// The generator used to create crash reproducers.
364   PassCrashReproducerGenerator &generator;
365 };
366 } // namespace
367 
368 //===----------------------------------------------------------------------===//
369 // FileReproducerStream
370 //===----------------------------------------------------------------------===//
371 
372 namespace {
373 /// This class represents a default instance of PassManager::ReproducerStream
374 /// that is backed by a file.
375 struct FileReproducerStream : public PassManager::ReproducerStream {
376   FileReproducerStream(std::unique_ptr<llvm::ToolOutputFile> outputFile)
377       : outputFile(std::move(outputFile)) {}
378   ~FileReproducerStream() override { outputFile->keep(); }
379 
380   /// Returns a description of the reproducer stream.
381   StringRef description() override { return outputFile->getFilename(); }
382 
383   /// Returns the stream on which to output the reproducer.
384   raw_ostream &os() override { return outputFile->os(); }
385 
386 private:
387   /// ToolOutputFile corresponding to opened `filename`.
388   std::unique_ptr<llvm::ToolOutputFile> outputFile = nullptr;
389 };
390 } // namespace
391 
392 //===----------------------------------------------------------------------===//
393 // PassManager
394 //===----------------------------------------------------------------------===//
395 
396 LogicalResult PassManager::runWithCrashRecovery(Operation *op,
397                                                 AnalysisManager am) {
398   crashReproGenerator->initialize(getPasses(), op, verifyPasses);
399 
400   // Safely invoke the passes within a recovery context.
401   LogicalResult passManagerResult = failure();
402   llvm::CrashRecoveryContext recoveryContext;
403   recoveryContext.RunSafelyOnThread(
404       [&] { passManagerResult = runPasses(op, am); });
405   crashReproGenerator->finalize(op, passManagerResult);
406   return passManagerResult;
407 }
408 
409 void PassManager::enableCrashReproducerGeneration(StringRef outputFile,
410                                                   bool genLocalReproducer) {
411   // Capture the filename by value in case outputFile is out of scope when
412   // invoked.
413   std::string filename = outputFile.str();
414   enableCrashReproducerGeneration(
415       [filename](std::string &error) -> std::unique_ptr<ReproducerStream> {
416         std::unique_ptr<llvm::ToolOutputFile> outputFile =
417             mlir::openOutputFile(filename, &error);
418         if (!outputFile) {
419           error = "Failed to create reproducer stream: " + error;
420           return nullptr;
421         }
422         return std::make_unique<FileReproducerStream>(std::move(outputFile));
423       },
424       genLocalReproducer);
425 }
426 
427 void PassManager::enableCrashReproducerGeneration(
428     ReproducerStreamFactory factory, bool genLocalReproducer) {
429   assert(!crashReproGenerator &&
430          "crash reproducer has already been initialized");
431   if (genLocalReproducer && getContext()->isMultithreadingEnabled())
432     llvm::report_fatal_error(
433         "Local crash reproduction can't be setup on a "
434         "pass-manager without disabling multi-threading first.");
435 
436   crashReproGenerator = std::make_unique<PassCrashReproducerGenerator>(
437       factory, genLocalReproducer);
438   addInstrumentation(
439       std::make_unique<CrashReproducerInstrumentation>(*crashReproGenerator));
440 }
441