xref: /llvm-project-15.0.7/mlir/lib/Pass/Pass.cpp (revision f7d033f4)
1 //===- Pass.cpp - Pass infrastructure implementation ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements common pass infrastructure.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Pass/Pass.h"
14 #include "PassDetail.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/Diagnostics.h"
17 #include "mlir/IR/Dialect.h"
18 #include "mlir/IR/Verifier.h"
19 #include "mlir/Support/FileUtilities.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/ScopeExit.h"
22 #include "llvm/ADT/SetVector.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "llvm/Support/CrashRecoveryContext.h"
25 #include "llvm/Support/Mutex.h"
26 #include "llvm/Support/Parallel.h"
27 #include "llvm/Support/Signals.h"
28 #include "llvm/Support/Threading.h"
29 #include "llvm/Support/ToolOutputFile.h"
30 
31 using namespace mlir;
32 using namespace mlir::detail;
33 
34 //===----------------------------------------------------------------------===//
35 // Pass
36 //===----------------------------------------------------------------------===//
37 
38 /// Out of line virtual method to ensure vtables and metadata are emitted to a
39 /// single .o file.
40 void Pass::anchor() {}
41 
42 /// Attempt to initialize the options of this pass from the given string.
43 LogicalResult Pass::initializeOptions(StringRef options) {
44   return passOptions.parseFromString(options);
45 }
46 
47 /// Copy the option values from 'other', which is another instance of this
48 /// pass.
49 void Pass::copyOptionValuesFrom(const Pass *other) {
50   passOptions.copyOptionValuesFrom(other->passOptions);
51 }
52 
53 /// Prints out the pass in the textual representation of pipelines. If this is
54 /// an adaptor pass, print with the op_name(sub_pass,...) format.
55 void Pass::printAsTextualPipeline(raw_ostream &os) {
56   // Special case for adaptors to use the 'op_name(sub_passes)' format.
57   if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(this)) {
58     llvm::interleaveComma(adaptor->getPassManagers(), os,
59                           [&](OpPassManager &pm) {
60                             os << pm.getOpName() << "(";
61                             pm.printAsTextualPipeline(os);
62                             os << ")";
63                           });
64     return;
65   }
66   // Otherwise, print the pass argument followed by its options. If the pass
67   // doesn't have an argument, print the name of the pass to give some indicator
68   // of what pass was run.
69   StringRef argument = getArgument();
70   if (!argument.empty())
71     os << argument;
72   else
73     os << "unknown<" << getName() << ">";
74   passOptions.print(os);
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // OpPassManagerImpl
79 //===----------------------------------------------------------------------===//
80 
81 namespace mlir {
82 namespace detail {
83 struct OpPassManagerImpl {
84   OpPassManagerImpl(Identifier identifier, OpPassManager::Nesting nesting)
85       : name(identifier.str()), identifier(identifier), nesting(nesting) {}
86   OpPassManagerImpl(StringRef name, OpPassManager::Nesting nesting)
87       : name(name), nesting(nesting) {}
88 
89   /// Merge the passes of this pass manager into the one provided.
90   void mergeInto(OpPassManagerImpl &rhs);
91 
92   /// Nest a new operation pass manager for the given operation kind under this
93   /// pass manager.
94   OpPassManager &nest(Identifier nestedName);
95   OpPassManager &nest(StringRef nestedName);
96 
97   /// Add the given pass to this pass manager. If this pass has a concrete
98   /// operation type, it must be the same type as this pass manager.
99   void addPass(std::unique_ptr<Pass> pass);
100 
101   /// Coalesce adjacent AdaptorPasses into one large adaptor. This runs
102   /// recursively through the pipeline graph.
103   void coalesceAdjacentAdaptorPasses();
104 
105   /// Split all of AdaptorPasses such that each adaptor only contains one leaf
106   /// pass.
107   void splitAdaptorPasses();
108 
109   Identifier getOpName(MLIRContext &context) {
110     if (!identifier)
111       identifier = Identifier::get(name, &context);
112     return *identifier;
113   }
114 
115   /// The name of the operation that passes of this pass manager operate on.
116   std::string name;
117 
118   /// The cached identifier (internalized in the context) for the name of the
119   /// operation that passes of this pass manager operate on.
120   Optional<Identifier> identifier;
121 
122   /// The set of passes to run as part of this pass manager.
123   std::vector<std::unique_ptr<Pass>> passes;
124 
125   /// Control the implicit nesting of passes that mismatch the name set for this
126   /// OpPassManager.
127   OpPassManager::Nesting nesting;
128 };
129 } // end namespace detail
130 } // end namespace mlir
131 
132 void OpPassManagerImpl::mergeInto(OpPassManagerImpl &rhs) {
133   assert(name == rhs.name && "merging unrelated pass managers");
134   for (auto &pass : passes)
135     rhs.passes.push_back(std::move(pass));
136   passes.clear();
137 }
138 
139 OpPassManager &OpPassManagerImpl::nest(Identifier nestedName) {
140   OpPassManager nested(nestedName, nesting);
141   auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
142   addPass(std::unique_ptr<Pass>(adaptor));
143   return adaptor->getPassManagers().front();
144 }
145 
146 OpPassManager &OpPassManagerImpl::nest(StringRef nestedName) {
147   OpPassManager nested(nestedName, nesting);
148   auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
149   addPass(std::unique_ptr<Pass>(adaptor));
150   return adaptor->getPassManagers().front();
151 }
152 
153 void OpPassManagerImpl::addPass(std::unique_ptr<Pass> pass) {
154   // If this pass runs on a different operation than this pass manager, then
155   // implicitly nest a pass manager for this operation if enabled.
156   auto passOpName = pass->getOpName();
157   if (passOpName && passOpName->str() != name) {
158     if (nesting == OpPassManager::Nesting::Implicit)
159       return nest(*passOpName).addPass(std::move(pass));
160     llvm::report_fatal_error(llvm::Twine("Can't add pass '") + pass->getName() +
161                              "' restricted to '" + *passOpName +
162                              "' on a PassManager intended to run on '" + name +
163                              "', did you intend to nest?");
164   }
165 
166   passes.emplace_back(std::move(pass));
167 }
168 
/// Merge each run of adjacent OpToOpPassAdaptor passes in this pass manager
/// into the first adaptor of the run. Then recurse into the nested pass
/// managers of every surviving adaptor, so the whole pipeline tree is
/// coalesced.
void OpPassManagerImpl::coalesceAdjacentAdaptorPasses() {
  // Bail out early if there are no adaptor passes.
  if (llvm::none_of(passes, [](std::unique_ptr<Pass> &pass) {
        return isa<OpToOpPassAdaptor>(pass.get());
      }))
    return;

  // Walk the pass list and merge adjacent adaptors.
  OpToOpPassAdaptor *lastAdaptor = nullptr;
  for (auto it = passes.begin(), e = passes.end(); it != e; ++it) {
    // Check to see if this pass is an adaptor.
    if (auto *currentAdaptor = dyn_cast<OpToOpPassAdaptor>(it->get())) {
      // If it is the first adaptor in a possible chain, remember it and
      // continue.
      if (!lastAdaptor) {
        lastAdaptor = currentAdaptor;
        continue;
      }

      // Otherwise, merge into the existing adaptor and delete the current one.
      // The slot is only nulled here, not erased, so `it` and `e` stay valid.
      // The null slots are removed after the loop.
      currentAdaptor->mergeInto(*lastAdaptor);
      it->reset();
    } else if (lastAdaptor) {
      // If this pass is not an adaptor, then coalesce and forget any existing
      // adaptor. The run of adaptors has ended, so the merged adaptor's nested
      // managers are complete and can be coalesced recursively.
      for (auto &pm : lastAdaptor->getPassManagers())
        pm.getImpl().coalesceAdjacentAdaptorPasses();
      lastAdaptor = nullptr;
    }
  }

  // If there was an adaptor at the end of the manager, coalesce it as well.
  if (lastAdaptor) {
    for (auto &pm : lastAdaptor->getPassManagers())
      pm.getImpl().coalesceAdjacentAdaptorPasses();
  }

  // Now that the adaptors have been merged, erase the empty slot corresponding
  // to the merged adaptors that were nulled-out in the loop above.
  llvm::erase_if(passes, std::logical_not<std::unique_ptr<Pass>>());
}
210 
211 void OpPassManagerImpl::splitAdaptorPasses() {
212   std::vector<std::unique_ptr<Pass>> oldPasses;
213   std::swap(passes, oldPasses);
214 
215   for (std::unique_ptr<Pass> &pass : oldPasses) {
216     // If this pass isn't an adaptor, move it directly to the new pass list.
217     auto *currentAdaptor = dyn_cast<OpToOpPassAdaptor>(pass.get());
218     if (!currentAdaptor) {
219       addPass(std::move(pass));
220       continue;
221     }
222 
223     // Otherwise, split the adaptors of each manager within the adaptor.
224     for (OpPassManager &adaptorPM : currentAdaptor->getPassManagers()) {
225       adaptorPM.getImpl().splitAdaptorPasses();
226       for (std::unique_ptr<Pass> &nestedPass : adaptorPM.getImpl().passes)
227         nest(adaptorPM.getOpName()).addPass(std::move(nestedPass));
228     }
229   }
230 }
231 
232 //===----------------------------------------------------------------------===//
233 // OpPassManager
234 //===----------------------------------------------------------------------===//
235 
236 OpPassManager::OpPassManager(Identifier name, Nesting nesting)
237     : impl(new OpPassManagerImpl(name, nesting)) {}
238 OpPassManager::OpPassManager(StringRef name, Nesting nesting)
239     : impl(new OpPassManagerImpl(name, nesting)) {}
240 OpPassManager::OpPassManager(OpPassManager &&rhs) : impl(std::move(rhs.impl)) {}
241 OpPassManager::OpPassManager(const OpPassManager &rhs) { *this = rhs; }
242 OpPassManager &OpPassManager::operator=(const OpPassManager &rhs) {
243   impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->nesting));
244   for (auto &pass : rhs.impl->passes)
245     impl->passes.emplace_back(pass->clone());
246   return *this;
247 }
248 
249 OpPassManager::~OpPassManager() {}
250 
251 OpPassManager::pass_iterator OpPassManager::begin() {
252   return MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.begin();
253 }
254 OpPassManager::pass_iterator OpPassManager::end() {
255   return MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.end();
256 }
257 
258 OpPassManager::const_pass_iterator OpPassManager::begin() const {
259   return ArrayRef<std::unique_ptr<Pass>>{impl->passes}.begin();
260 }
261 OpPassManager::const_pass_iterator OpPassManager::end() const {
262   return ArrayRef<std::unique_ptr<Pass>>{impl->passes}.end();
263 }
264 
265 /// Nest a new operation pass manager for the given operation kind under this
266 /// pass manager.
267 OpPassManager &OpPassManager::nest(Identifier nestedName) {
268   return impl->nest(nestedName);
269 }
270 OpPassManager &OpPassManager::nest(StringRef nestedName) {
271   return impl->nest(nestedName);
272 }
273 
274 /// Add the given pass to this pass manager. If this pass has a concrete
275 /// operation type, it must be the same type as this pass manager.
276 void OpPassManager::addPass(std::unique_ptr<Pass> pass) {
277   impl->addPass(std::move(pass));
278 }
279 
280 /// Returns the number of passes held by this manager.
281 size_t OpPassManager::size() const { return impl->passes.size(); }
282 
283 /// Returns the internal implementation instance.
284 OpPassManagerImpl &OpPassManager::getImpl() { return *impl; }
285 
286 /// Return the operation name that this pass manager operates on.
287 StringRef OpPassManager::getOpName() const { return impl->name; }
288 
289 /// Return the operation name that this pass manager operates on.
290 Identifier OpPassManager::getOpName(MLIRContext &context) const {
291   return impl->getOpName(context);
292 }
293 
294 /// Prints out the given passes as the textual representation of a pipeline.
295 static void printAsTextualPipeline(ArrayRef<std::unique_ptr<Pass>> passes,
296                                    raw_ostream &os) {
297   llvm::interleaveComma(passes, os, [&](const std::unique_ptr<Pass> &pass) {
298     pass->printAsTextualPipeline(os);
299   });
300 }
301 
302 /// Prints out the passes of the pass manager as the textual representation
303 /// of pipelines.
304 void OpPassManager::printAsTextualPipeline(raw_ostream &os) {
305   ::printAsTextualPipeline(impl->passes, os);
306 }
307 
308 void OpPassManager::dump() {
309   llvm::errs() << "Pass Manager with " << impl->passes.size() << " passes: ";
310   ::printAsTextualPipeline(impl->passes, llvm::errs());
311   llvm::errs() << "\n";
312 }
313 
314 static void registerDialectsForPipeline(const OpPassManager &pm,
315                                         DialectRegistry &dialects) {
316   for (const Pass &pass : pm.getPasses())
317     pass.getDependentDialects(dialects);
318 }
319 
320 void OpPassManager::getDependentDialects(DialectRegistry &dialects) const {
321   registerDialectsForPipeline(*this, dialects);
322 }
323 
324 OpPassManager::Nesting OpPassManager::getNesting() { return impl->nesting; }
325 
326 void OpPassManager::setNesting(Nesting nesting) { impl->nesting = nesting; }
327 
328 //===----------------------------------------------------------------------===//
329 // OpToOpPassAdaptor
330 //===----------------------------------------------------------------------===//
331 
/// Run `pass` on `op`, using `am` as the analysis manager for `op`.
///
/// The pass is only scheduled on registered operations that have the
/// IsolatedFromAbove property. Otherwise an error is emitted on `op`.
/// Adaptor passes are dispatched to runOnOperation(bool) so that
/// `verifyPasses` reaches nested pipelines. After the pass runs:
/// - analyses it did not preserve are invalidated;
/// - if `verifyPasses` is set and the pass did not fail, `op` is verified;
/// - the instrumentor, if any, is notified of success or failure.
///
/// Returns failure if the pass signaled failure or verification failed.
LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
                                     AnalysisManager am, bool verifyPasses) {
  if (!op->getName().getAbstractOperation())
    return op->emitOpError()
           << "trying to schedule a pass on an unregistered operation";
  if (!op->getName().getAbstractOperation()->hasProperty(
          OperationProperty::IsolatedFromAbove))
    return op->emitOpError() << "trying to schedule a pass on an operation not "
                                "marked as 'IsolatedFromAbove'";

  // Initialize the pass state with a callback for the pass to dynamically
  // execute a pipeline on the currently visited operation.
  // NOTE(review): the callback captures `am` by reference, and `am` is this
  // function's by-value parameter. The callback is therefore only safe to
  // invoke while this call is still active, i.e. from within the pass's
  // runOnOperation below.
  auto dynamic_pipeline_callback =
      [op, &am, verifyPasses](OpPassManager &pipeline,
                              Operation *root) -> LogicalResult {
    // Reject any root that `op` is not an ancestor of.
    if (!op->isAncestor(root))
      return root->emitOpError()
             << "Trying to schedule a dynamic pipeline on an "
                "operation that isn't "
                "nested under the current operation the pass is processing";

    AnalysisManager nestedAm = am.nest(root);
    return OpToOpPassAdaptor::runPipeline(pipeline.getPasses(), root, nestedAm,
                                          verifyPasses);
  };
  pass->passState.emplace(op, am, dynamic_pipeline_callback);
  // Instrument before the pass has run.
  PassInstrumentor *pi = am.getPassInstrumentor();
  if (pi)
    pi->runBeforePass(pass, op);

  // Invoke the virtual runOnOperation method.
  // Adaptors take the verification flag explicitly. Their parameterless
  // runOnOperation is unreachable.
  if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
    adaptor->runOnOperation(verifyPasses);
  else
    pass->runOnOperation();
  bool passFailed = pass->passState->irAndPassFailed.getInt();

  // Invalidate any non preserved analyses.
  am.invalidate(pass->passState->preservedAnalyses);

  // Run the verifier if this pass didn't fail already.
  if (!passFailed && verifyPasses)
    passFailed = failed(verify(op));

  // Instrument after the pass has run.
  if (pi) {
    if (passFailed)
      pi->runAfterPassFailed(pass, op);
    else
      pi->runAfterPass(pass, op);
  }

  // Return if the pass signaled a failure.
  return failure(passFailed);
}
388 
389 /// Run the given operation and analysis manager on a provided op pass manager.
390 LogicalResult OpToOpPassAdaptor::runPipeline(
391     iterator_range<OpPassManager::pass_iterator> passes, Operation *op,
392     AnalysisManager am, bool verifyPasses) {
393   auto scope_exit = llvm::make_scope_exit([&] {
394     // Clear out any computed operation analyses. These analyses won't be used
395     // any more in this pipeline, and this helps reduce the current working set
396     // of memory. If preserving these analyses becomes important in the future
397     // we can re-evaluate this.
398     am.clear();
399   });
400 
401   // Run the pipeline over the provided operation.
402   for (Pass &pass : passes)
403     if (failed(run(&pass, op, am, verifyPasses)))
404       return failure();
405 
406   return success();
407 }
408 
409 /// Find an operation pass manager that can operate on an operation of the given
410 /// type, or nullptr if one does not exist.
411 static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
412                                          StringRef name) {
413   auto it = llvm::find_if(
414       mgrs, [&](OpPassManager &mgr) { return mgr.getOpName() == name; });
415   return it == mgrs.end() ? nullptr : &*it;
416 }
417 
418 /// Find an operation pass manager that can operate on an operation of the given
419 /// type, or nullptr if one does not exist.
420 static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
421                                          Identifier name,
422                                          MLIRContext &context) {
423   auto it = llvm::find_if(
424       mgrs, [&](OpPassManager &mgr) { return mgr.getOpName(context) == name; });
425   return it == mgrs.end() ? nullptr : &*it;
426 }
427 
428 OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr) {
429   mgrs.emplace_back(std::move(mgr));
430 }
431 
432 void OpToOpPassAdaptor::getDependentDialects(DialectRegistry &dialects) const {
433   for (auto &pm : mgrs)
434     pm.getDependentDialects(dialects);
435 }
436 
437 /// Merge the current pass adaptor into given 'rhs'.
438 void OpToOpPassAdaptor::mergeInto(OpToOpPassAdaptor &rhs) {
439   for (auto &pm : mgrs) {
440     // If an existing pass manager exists, then merge the given pass manager
441     // into it.
442     if (auto *existingPM = findPassManagerFor(rhs.mgrs, pm.getOpName())) {
443       pm.getImpl().mergeInto(existingPM->getImpl());
444     } else {
445       // Otherwise, add the given pass manager to the list.
446       rhs.mgrs.emplace_back(std::move(pm));
447     }
448   }
449   mgrs.clear();
450 
451   // After coalescing, sort the pass managers within rhs by name.
452   llvm::array_pod_sort(rhs.mgrs.begin(), rhs.mgrs.end(),
453                        [](const OpPassManager *lhs, const OpPassManager *rhs) {
454                          return lhs->getOpName().compare(rhs->getOpName());
455                        });
456 }
457 
458 /// Returns the adaptor pass name.
459 std::string OpToOpPassAdaptor::getAdaptorName() {
460   std::string name = "Pipeline Collection : [";
461   llvm::raw_string_ostream os(name);
462   llvm::interleaveComma(getPassManagers(), os, [&](OpPassManager &pm) {
463     os << '\'' << pm.getOpName() << '\'';
464   });
465   os << ']';
466   return os.str();
467 }
468 
469 void OpToOpPassAdaptor::runOnOperation() {
470   llvm_unreachable(
471       "Unexpected call to Pass::runOnOperation() on OpToOpPassAdaptor");
472 }
473 
474 /// Run the held pipeline over all nested operations.
475 void OpToOpPassAdaptor::runOnOperation(bool verifyPasses) {
476   if (getContext().isMultithreadingEnabled())
477     runOnOperationAsyncImpl(verifyPasses);
478   else
479     runOnOperationImpl(verifyPasses);
480 }
481 
482 /// Run this pass adaptor synchronously.
483 void OpToOpPassAdaptor::runOnOperationImpl(bool verifyPasses) {
484   auto am = getAnalysisManager();
485   PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
486                                                         this};
487   auto *instrumentor = am.getPassInstrumentor();
488   for (auto &region : getOperation()->getRegions()) {
489     for (auto &block : region) {
490       for (auto &op : block) {
491         auto *mgr = findPassManagerFor(mgrs, op.getName().getIdentifier(),
492                                        *op.getContext());
493         if (!mgr)
494           continue;
495         Identifier opName = mgr->getOpName(*getOperation()->getContext());
496 
497         // Run the held pipeline over the current operation.
498         if (instrumentor)
499           instrumentor->runBeforePipeline(opName, parentInfo);
500         LogicalResult result =
501             runPipeline(mgr->getPasses(), &op, am.nest(&op), verifyPasses);
502         if (instrumentor)
503           instrumentor->runAfterPipeline(opName, parentInfo);
504 
505         if (failed(result))
506           return signalPassFailure();
507       }
508     }
509   }
510 }
511 
512 /// Utility functor that checks if the two ranges of pass managers have a size
513 /// mismatch.
514 static bool hasSizeMismatch(ArrayRef<OpPassManager> lhs,
515                             ArrayRef<OpPassManager> rhs) {
516   return lhs.size() != rhs.size() ||
517          llvm::any_of(llvm::seq<size_t>(0, lhs.size()),
518                       [&](size_t i) { return lhs[i].size() != rhs[i].size(); });
519 }
520 
/// Run this pass adaptor asynchronously.
///
/// Matching nested operations are first collected into a work list, each with
/// its nested analysis manager. Worker tasks, one per async executor (capped
/// at the number of operations), then repeatedly claim the next unprocessed
/// operation through an atomic counter. Each worker runs the matching pipeline
/// from its own executor's copy of the pass managers. Once any pipeline fails,
/// workers stop claiming new operations and the adaptor signals failure.
void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
  AnalysisManager am = getAnalysisManager();

  // Create the async executors if they haven't been created, or if the main
  // pipeline has changed.
  // Each executor is a separate copy of `mgrs` (copying an OpPassManager
  // clones its passes), one per thread reported by hardware_concurrency.
  if (asyncExecutors.empty() || hasSizeMismatch(asyncExecutors.front(), mgrs))
    asyncExecutors.assign(llvm::hardware_concurrency().compute_thread_count(),
                          mgrs);

  // Run a prepass over the module to collect the operations to execute over.
  // This ensures that an analysis manager exists for each operation, as well as
  // providing a queue of operations to execute over.
  std::vector<std::pair<Operation *, AnalysisManager>> opAMPairs;
  for (auto &region : getOperation()->getRegions()) {
    for (auto &block : region) {
      for (auto &op : block) {
        // Add this operation iff the name matches any of the pass managers.
        if (findPassManagerFor(mgrs, op.getName().getIdentifier(),
                               getContext()))
          opAMPairs.emplace_back(&op, am.nest(&op));
      }
    }
  }

  // A parallel diagnostic handler that provides deterministic diagnostic
  // ordering.
  ParallelDiagnosticHandler diagHandler(&getContext());

  // An index for the current operation/analysis manager pair.
  // Shared by all workers; each post-increment claims one work-list entry.
  std::atomic<unsigned> opIt(0);

  // Get the current thread for this adaptor.
  // This thread and adaptor are reported to the instrumentor as the parent of
  // every pipeline run below.
  PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
                                                        this};
  auto *instrumentor = am.getPassInstrumentor();

  // An atomic failure variable for the async executors.
  std::atomic<bool> passFailed(false);
  llvm::parallelForEach(
      asyncExecutors.begin(),
      std::next(asyncExecutors.begin(),
                std::min(asyncExecutors.size(), opAMPairs.size())),
      [&](MutableArrayRef<OpPassManager> pms) {
        // The `opIt < e` check is only an early exit. The post-increment
        // below is what actually claims an index, so it is re-checked there.
        for (auto e = opAMPairs.size(); !passFailed && opIt < e;) {
          // Get the next available operation index.
          unsigned nextID = opIt++;
          if (nextID >= e)
            break;

          // Set the order id for this thread in the diagnostic handler.
          diagHandler.setOrderIDForThread(nextID);

          // Get the pass manager for this operation and execute it.
          auto &it = opAMPairs[nextID];
          auto *pm = findPassManagerFor(
              pms, it.first->getName().getIdentifier(), getContext());
          assert(pm && "expected valid pass manager for operation");

          Identifier opName = pm->getOpName(*getOperation()->getContext());
          if (instrumentor)
            instrumentor->runBeforePipeline(opName, parentInfo);
          auto pipelineResult =
              runPipeline(pm->getPasses(), it.first, it.second, verifyPasses);
          if (instrumentor)
            instrumentor->runAfterPipeline(opName, parentInfo);

          // Drop this thread from being tracked by the diagnostic handler.
          // After this task has finished, the thread may be used outside of
          // this pass manager context meaning that we don't want to track
          // diagnostics from it anymore.
          diagHandler.eraseOrderIDForThread();

          // Handle a failed pipeline result.
          if (failed(pipelineResult)) {
            passFailed = true;
            break;
          }
        }
      });

  // Signal a failure if any of the executors failed.
  if (passFailed)
    signalPassFailure();
}
606 
607 //===----------------------------------------------------------------------===//
608 // PassCrashReproducer
609 //===----------------------------------------------------------------------===//
610 
namespace {
/// This class contains all of the context for generating a recovery reproducer.
/// Each recovery context is registered globally to allow for generating
/// reproducers when a signal is raised, such as a segfault.
///
/// At construction a context records the textual pipeline and a clone of the
/// module. It stays registered in `reproducerSet` until it is destroyed.
struct RecoveryReproducerContext {
  /// Record `passes` as a textual pipeline, clone `module`, and register this
  /// context as active. `filename` is stored as a StringRef, so the string it
  /// refers to must outlive this context.
  RecoveryReproducerContext(MutableArrayRef<std::unique_ptr<Pass>> passes,
                            ModuleOp module, StringRef filename,
                            bool disableThreads, bool verifyPasses);
  /// Unregister this context. Crash recovery is disabled once no context
  /// remains registered.
  ~RecoveryReproducerContext();

  /// Generate a reproducer with the current context.
  /// It writes the pipeline configuration and the recorded module to
  /// `filename`. If the output file cannot be opened, returns failure and
  /// `error` holds the reason.
  LogicalResult generate(std::string &error);

private:
  /// This function is invoked in the event of a crash.
  static void crashHandler(void *);

  /// Register a signal handler to run in the event of a crash.
  static void registerSignalHandler();

  /// The textual description of the currently executing pipeline.
  std::string pipeline;

  /// The MLIR module representing the IR before the crash.
  /// This is the clone taken when this context was constructed.
  OwningModuleRef module;

  /// The filename to use when generating the reproducer.
  StringRef filename;

  /// Various pass manager and context flags.
  bool disableThreads;
  bool verifyPasses;

  /// The current set of active reproducer contexts. This is used in the event
  /// of a crash. This is not thread_local as the pass manager may produce any
  /// number of child threads. This uses a set to allow for multiple MLIR pass
  /// managers to be running at the same time.
  /// `reproducerMutex` is held by the constructor and destructor while they
  /// update `reproducerSet`.
  static llvm::ManagedStatic<llvm::sys::SmartMutex<true>> reproducerMutex;
  static llvm::ManagedStatic<
      llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
      reproducerSet;
};
} // end anonymous namespace
654 
// Storage for the static members of RecoveryReproducerContext:
// - the mutex held by its constructor and destructor while they update the
//   set;
// - the set of currently active contexts.
llvm::ManagedStatic<llvm::sys::SmartMutex<true>>
    RecoveryReproducerContext::reproducerMutex;
llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
    RecoveryReproducerContext::reproducerSet;
659 
660 RecoveryReproducerContext::RecoveryReproducerContext(
661     MutableArrayRef<std::unique_ptr<Pass>> passes, ModuleOp module,
662     StringRef filename, bool disableThreads, bool verifyPasses)
663     : module(module.clone()), filename(filename),
664       disableThreads(disableThreads), verifyPasses(verifyPasses) {
665   // Grab the textual pipeline being executed..
666   {
667     llvm::raw_string_ostream pipelineOS(pipeline);
668     ::printAsTextualPipeline(passes, pipelineOS);
669   }
670 
671   // Make sure that the handler is registered, and update the current context.
672   llvm::sys::SmartScopedLock<true> producerLock(*reproducerMutex);
673   if (reproducerSet->empty())
674     llvm::CrashRecoveryContext::Enable();
675   registerSignalHandler();
676   reproducerSet->insert(this);
677 }
678 
679 RecoveryReproducerContext::~RecoveryReproducerContext() {
680   llvm::sys::SmartScopedLock<true> producerLock(*reproducerMutex);
681   reproducerSet->remove(this);
682   if (reproducerSet->empty())
683     llvm::CrashRecoveryContext::Disable();
684 }
685 
686 LogicalResult RecoveryReproducerContext::generate(std::string &error) {
687   std::unique_ptr<llvm::ToolOutputFile> outputFile =
688       mlir::openOutputFile(filename, &error);
689   if (!outputFile)
690     return failure();
691   auto &outputOS = outputFile->os();
692 
693   // Output the current pass manager configuration.
694   outputOS << "// configuration: -pass-pipeline='" << pipeline << "'";
695   if (disableThreads)
696     outputOS << " -mlir-disable-threading";
697 
698   // TODO: Should this also be configured with a pass manager flag?
699   outputOS << "\n// note: verifyPasses=" << (verifyPasses ? "true" : "false")
700            << "\n";
701 
702   // Output the .mlir module.
703   module->print(outputOS);
704   outputFile->keep();
705   return success();
706 }
707 
708 void RecoveryReproducerContext::crashHandler(void *) {
709   // Walk the current stack of contexts and generate a reproducer for each one.
710   // We can't know for certain which one was the cause, so we need to generate
711   // a reproducer for all of them.
712   std::string ignored;
713   for (RecoveryReproducerContext *context : *reproducerSet)
714     context->generate(ignored);
715 }
716 
717 void RecoveryReproducerContext::registerSignalHandler() {
718   // Ensure that the handler is only registered once.
719   static bool registered =
720       (llvm::sys::AddSignalHandler(crashHandler, nullptr), false);
721   (void)registered;
722 }
723 
724 /// Run the pass manager with crash recover enabled.
725 LogicalResult PassManager::runWithCrashRecovery(ModuleOp module,
726                                                 AnalysisManager am) {
727   // If this isn't a local producer, run all of the passes in recovery mode.
728   if (!localReproducer)
729     return runWithCrashRecovery(impl->passes, module, am);
730 
731   // Split the passes within adaptors to ensure that each pass can be run in
732   // isolation.
733   impl->splitAdaptorPasses();
734 
735   // If this is a local producer, run each of the passes individually.
736   MutableArrayRef<std::unique_ptr<Pass>> passes = impl->passes;
737   for (std::unique_ptr<Pass> &pass : passes)
738     if (failed(runWithCrashRecovery(pass, module, am)))
739       return failure();
740   return success();
741 }
742 
743 /// Run the given passes with crash recover enabled.
744 LogicalResult
745 PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
746                                   ModuleOp module, AnalysisManager am) {
747   RecoveryReproducerContext context(passes, module, *crashReproducerFileName,
748                                     !getContext()->isMultithreadingEnabled(),
749                                     verifyPasses);
750 
751   // Safely invoke the passes within a recovery context.
752   LogicalResult passManagerResult = failure();
753   llvm::CrashRecoveryContext recoveryContext;
754   recoveryContext.RunSafelyOnThread([&] {
755     for (std::unique_ptr<Pass> &pass : passes)
756       if (failed(OpToOpPassAdaptor::run(pass.get(), module, am, verifyPasses)))
757         return;
758     passManagerResult = success();
759   });
760   if (succeeded(passManagerResult))
761     return success();
762 
763   std::string error;
764   if (failed(context.generate(error)))
765     return module.emitError("<MLIR-PassManager-Crash-Reproducer>: ") << error;
766   return module.emitError()
767          << "A failure has been detected while processing the MLIR module, a "
768             "reproducer has been generated in '"
769          << *crashReproducerFileName << "'";
770 }
771 
772 //===----------------------------------------------------------------------===//
773 // PassManager
774 //===----------------------------------------------------------------------===//
775 
776 PassManager::PassManager(MLIRContext *ctx, Nesting nesting)
777     : OpPassManager(Identifier::get(ModuleOp::getOperationName(), ctx),
778                     nesting),
779       context(ctx), passTiming(false), localReproducer(false),
780       verifyPasses(true) {}
781 
782 PassManager::~PassManager() {}
783 
784 void PassManager::enableVerifier(bool enabled) { verifyPasses = enabled; }
785 
786 /// Run the passes within this manager on the provided module.
787 LogicalResult PassManager::run(ModuleOp module) {
788   // Before running, make sure to coalesce any adjacent pass adaptors in the
789   // pipeline.
790   getImpl().coalesceAdjacentAdaptorPasses();
791 
792   // Register all dialects for the current pipeline.
793   DialectRegistry dependentDialects;
794   getDependentDialects(dependentDialects);
795   dependentDialects.loadAll(module.getContext());
796 
797   // Construct an analysis manager for the pipeline.
798   ModuleAnalysisManager am(module, instrumentor.get());
799 
800   // Notify the context that we start running a pipeline for book keeping.
801   module.getContext()->enterMultiThreadedExecution();
802 
803   // If reproducer generation is enabled, run the pass manager with crash
804   // handling enabled.
805   LogicalResult result = crashReproducerFileName
806                              ? runWithCrashRecovery(module, am)
807                              : OpToOpPassAdaptor::runPipeline(
808                                    getPasses(), module, am, verifyPasses);
809 
810   // Notify the context that the run is done.
811   module.getContext()->exitMultiThreadedExecution();
812 
813   // Dump all of the pass statistics if necessary.
814   if (passStatisticsMode)
815     dumpStatistics();
816   return result;
817 }
818 
819 /// Enable support for the pass manager to generate a reproducer on the event
820 /// of a crash or a pass failure. `outputFile` is a .mlir filename used to write
821 /// the generated reproducer. If `genLocalReproducer` is true, the pass manager
822 /// will attempt to generate a local reproducer that contains the smallest
823 /// pipeline.
824 void PassManager::enableCrashReproducerGeneration(StringRef outputFile,
825                                                   bool genLocalReproducer) {
826   crashReproducerFileName = std::string(outputFile);
827   localReproducer = genLocalReproducer;
828 }
829 
830 /// Add the provided instrumentation to the pass manager.
831 void PassManager::addInstrumentation(std::unique_ptr<PassInstrumentation> pi) {
832   if (!instrumentor)
833     instrumentor = std::make_unique<PassInstrumentor>();
834 
835   instrumentor->addInstrumentation(std::move(pi));
836 }
837 
838 //===----------------------------------------------------------------------===//
839 // AnalysisManager
840 //===----------------------------------------------------------------------===//
841 
842 /// Returns a pass instrumentation object for the current operation.
843 PassInstrumentor *AnalysisManager::getPassInstrumentor() const {
844   ParentPointerT curParent = parent;
845   while (auto *parentAM = curParent.dyn_cast<const AnalysisManager *>())
846     curParent = parentAM->parent;
847   return curParent.get<const ModuleAnalysisManager *>()->getPassInstrumentor();
848 }
849 
850 /// Get an analysis manager for the given child operation.
851 AnalysisManager AnalysisManager::nest(Operation *op) {
852   auto it = impl->childAnalyses.find(op);
853   if (it == impl->childAnalyses.end())
854     it = impl->childAnalyses
855              .try_emplace(op, std::make_unique<NestedAnalysisMap>(op))
856              .first;
857   return {this, it->second.get()};
858 }
859 
860 /// Invalidate any non preserved analyses.
861 void detail::NestedAnalysisMap::invalidate(
862     const detail::PreservedAnalyses &pa) {
863   // If all analyses were preserved, then there is nothing to do here.
864   if (pa.isAll())
865     return;
866 
867   // Invalidate the analyses for the current operation directly.
868   analyses.invalidate(pa);
869 
870   // If no analyses were preserved, then just simply clear out the child
871   // analysis results.
872   if (pa.isNone()) {
873     childAnalyses.clear();
874     return;
875   }
876 
877   // Otherwise, invalidate each child analysis map.
878   SmallVector<NestedAnalysisMap *, 8> mapsToInvalidate(1, this);
879   while (!mapsToInvalidate.empty()) {
880     auto *map = mapsToInvalidate.pop_back_val();
881     for (auto &analysisPair : map->childAnalyses) {
882       analysisPair.second->invalidate(pa);
883       if (!analysisPair.second->childAnalyses.empty())
884         mapsToInvalidate.push_back(analysisPair.second.get());
885     }
886   }
887 }
888 
889 //===----------------------------------------------------------------------===//
890 // PassInstrumentation
891 //===----------------------------------------------------------------------===//
892 
893 PassInstrumentation::~PassInstrumentation() {}
894 
895 //===----------------------------------------------------------------------===//
896 // PassInstrumentor
897 //===----------------------------------------------------------------------===//
898 
899 namespace mlir {
900 namespace detail {
901 struct PassInstrumentorImpl {
902   /// Mutex to keep instrumentation access thread-safe.
903   llvm::sys::SmartMutex<true> mutex;
904 
905   /// Set of registered instrumentations.
906   std::vector<std::unique_ptr<PassInstrumentation>> instrumentations;
907 };
908 } // end namespace detail
909 } // end namespace mlir
910 
911 PassInstrumentor::PassInstrumentor() : impl(new PassInstrumentorImpl()) {}
912 PassInstrumentor::~PassInstrumentor() {}
913 
914 /// See PassInstrumentation::runBeforePipeline for details.
915 void PassInstrumentor::runBeforePipeline(
916     Identifier name,
917     const PassInstrumentation::PipelineParentInfo &parentInfo) {
918   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
919   for (auto &instr : impl->instrumentations)
920     instr->runBeforePipeline(name, parentInfo);
921 }
922 
923 /// See PassInstrumentation::runAfterPipeline for details.
924 void PassInstrumentor::runAfterPipeline(
925     Identifier name,
926     const PassInstrumentation::PipelineParentInfo &parentInfo) {
927   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
928   for (auto &instr : llvm::reverse(impl->instrumentations))
929     instr->runAfterPipeline(name, parentInfo);
930 }
931 
932 /// See PassInstrumentation::runBeforePass for details.
933 void PassInstrumentor::runBeforePass(Pass *pass, Operation *op) {
934   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
935   for (auto &instr : impl->instrumentations)
936     instr->runBeforePass(pass, op);
937 }
938 
939 /// See PassInstrumentation::runAfterPass for details.
940 void PassInstrumentor::runAfterPass(Pass *pass, Operation *op) {
941   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
942   for (auto &instr : llvm::reverse(impl->instrumentations))
943     instr->runAfterPass(pass, op);
944 }
945 
946 /// See PassInstrumentation::runAfterPassFailed for details.
947 void PassInstrumentor::runAfterPassFailed(Pass *pass, Operation *op) {
948   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
949   for (auto &instr : llvm::reverse(impl->instrumentations))
950     instr->runAfterPassFailed(pass, op);
951 }
952 
953 /// See PassInstrumentation::runBeforeAnalysis for details.
954 void PassInstrumentor::runBeforeAnalysis(StringRef name, TypeID id,
955                                          Operation *op) {
956   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
957   for (auto &instr : impl->instrumentations)
958     instr->runBeforeAnalysis(name, id, op);
959 }
960 
961 /// See PassInstrumentation::runAfterAnalysis for details.
962 void PassInstrumentor::runAfterAnalysis(StringRef name, TypeID id,
963                                         Operation *op) {
964   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
965   for (auto &instr : llvm::reverse(impl->instrumentations))
966     instr->runAfterAnalysis(name, id, op);
967 }
968 
969 /// Add the given instrumentation to the collection.
970 void PassInstrumentor::addInstrumentation(
971     std::unique_ptr<PassInstrumentation> pi) {
972   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
973   impl->instrumentations.emplace_back(std::move(pi));
974 }
975