xref: /llvm-project-15.0.7/mlir/lib/Pass/Pass.cpp (revision c83cd8fe)
1 //===- Pass.cpp - Pass infrastructure implementation ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements common pass infrastructure.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Pass/Pass.h"
14 #include "PassDetail.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/Verifier.h"
18 #include "mlir/Support/FileUtilities.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/ScopeExit.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/Support/CommandLine.h"
23 #include "llvm/Support/CrashRecoveryContext.h"
24 #include "llvm/Support/Mutex.h"
25 #include "llvm/Support/Parallel.h"
26 #include "llvm/Support/Signals.h"
27 #include "llvm/Support/Threading.h"
28 #include "llvm/Support/ToolOutputFile.h"
29 
30 using namespace mlir;
31 using namespace mlir::detail;
32 
33 //===----------------------------------------------------------------------===//
34 // Pass
35 //===----------------------------------------------------------------------===//
36 
37 /// Out of line virtual method to ensure vtables and metadata are emitted to a
38 /// single .o file.
39 void Pass::anchor() {}
40 
41 /// Attempt to initialize the options of this pass from the given string.
42 LogicalResult Pass::initializeOptions(StringRef options) {
43   return passOptions.parseFromString(options);
44 }
45 
46 /// Copy the option values from 'other', which is another instance of this
47 /// pass.
48 void Pass::copyOptionValuesFrom(const Pass *other) {
49   passOptions.copyOptionValuesFrom(other->passOptions);
50 }
51 
52 /// Prints out the pass in the textual representation of pipelines. If this is
53 /// an adaptor pass, print with the op_name(sub_pass,...) format.
54 void Pass::printAsTextualPipeline(raw_ostream &os) {
55   // Special case for adaptors to use the 'op_name(sub_passes)' format.
56   if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(this)) {
57     llvm::interleaveComma(adaptor->getPassManagers(), os,
58                           [&](OpPassManager &pm) {
59                             os << pm.getOpName() << "(";
60                             pm.printAsTextualPipeline(os);
61                             os << ")";
62                           });
63     return;
64   }
65   // Otherwise, print the pass argument followed by its options. If the pass
66   // doesn't have an argument, print the name of the pass to give some indicator
67   // of what pass was run.
68   StringRef argument = getArgument();
69   if (!argument.empty())
70     os << argument;
71   else
72     os << "unknown<" << getName() << ">";
73   passOptions.print(os);
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // OpPassManagerImpl
78 //===----------------------------------------------------------------------===//
79 
80 namespace mlir {
81 namespace detail {
82 struct OpPassManagerImpl {
83   OpPassManagerImpl(Identifier identifier, OpPassManager::Nesting nesting)
84       : name(identifier.str()), identifier(identifier),
85         initializationGeneration(0), nesting(nesting) {}
86   OpPassManagerImpl(StringRef name, OpPassManager::Nesting nesting)
87       : name(name), initializationGeneration(0), nesting(nesting) {}
88 
89   /// Merge the passes of this pass manager into the one provided.
90   void mergeInto(OpPassManagerImpl &rhs);
91 
92   /// Nest a new operation pass manager for the given operation kind under this
93   /// pass manager.
94   OpPassManager &nest(Identifier nestedName);
95   OpPassManager &nest(StringRef nestedName);
96 
97   /// Add the given pass to this pass manager. If this pass has a concrete
98   /// operation type, it must be the same type as this pass manager.
99   void addPass(std::unique_ptr<Pass> pass);
100 
101   /// Coalesce adjacent AdaptorPasses into one large adaptor. This runs
102   /// recursively through the pipeline graph.
103   void coalesceAdjacentAdaptorPasses();
104 
105   /// Split all of AdaptorPasses such that each adaptor only contains one leaf
106   /// pass.
107   void splitAdaptorPasses();
108 
109   /// Return the operation name of this pass manager as an identifier.
110   Identifier getOpName(MLIRContext &context) {
111     if (!identifier)
112       identifier = Identifier::get(name, &context);
113     return *identifier;
114   }
115 
116   /// The name of the operation that passes of this pass manager operate on.
117   std::string name;
118 
119   /// The cached identifier (internalized in the context) for the name of the
120   /// operation that passes of this pass manager operate on.
121   Optional<Identifier> identifier;
122 
123   /// The set of passes to run as part of this pass manager.
124   std::vector<std::unique_ptr<Pass>> passes;
125 
126   /// The current initialization generation of this pass manager. This is used
127   /// to indicate when a pass manager should be reinitialized.
128   unsigned initializationGeneration;
129 
130   /// Control the implicit nesting of passes that mismatch the name set for this
131   /// OpPassManager.
132   OpPassManager::Nesting nesting;
133 };
134 } // end namespace detail
135 } // end namespace mlir
136 
137 void OpPassManagerImpl::mergeInto(OpPassManagerImpl &rhs) {
138   assert(name == rhs.name && "merging unrelated pass managers");
139   for (auto &pass : passes)
140     rhs.passes.push_back(std::move(pass));
141   passes.clear();
142 }
143 
144 OpPassManager &OpPassManagerImpl::nest(Identifier nestedName) {
145   OpPassManager nested(nestedName, nesting);
146   auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
147   addPass(std::unique_ptr<Pass>(adaptor));
148   return adaptor->getPassManagers().front();
149 }
150 
151 OpPassManager &OpPassManagerImpl::nest(StringRef nestedName) {
152   OpPassManager nested(nestedName, nesting);
153   auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
154   addPass(std::unique_ptr<Pass>(adaptor));
155   return adaptor->getPassManagers().front();
156 }
157 
158 void OpPassManagerImpl::addPass(std::unique_ptr<Pass> pass) {
159   // If this pass runs on a different operation than this pass manager, then
160   // implicitly nest a pass manager for this operation if enabled.
161   auto passOpName = pass->getOpName();
162   if (passOpName && passOpName->str() != name) {
163     if (nesting == OpPassManager::Nesting::Implicit)
164       return nest(*passOpName).addPass(std::move(pass));
165     llvm::report_fatal_error(llvm::Twine("Can't add pass '") + pass->getName() +
166                              "' restricted to '" + *passOpName +
167                              "' on a PassManager intended to run on '" + name +
168                              "', did you intend to nest?");
169   }
170 
171   passes.emplace_back(std::move(pass));
172 }
173 
174 void OpPassManagerImpl::coalesceAdjacentAdaptorPasses() {
175   // Bail out early if there are no adaptor passes.
176   if (llvm::none_of(passes, [](std::unique_ptr<Pass> &pass) {
177         return isa<OpToOpPassAdaptor>(pass.get());
178       }))
179     return;
180 
181   // Walk the pass list and merge adjacent adaptors.
182   OpToOpPassAdaptor *lastAdaptor = nullptr;
183   for (auto it = passes.begin(), e = passes.end(); it != e; ++it) {
184     // Check to see if this pass is an adaptor.
185     if (auto *currentAdaptor = dyn_cast<OpToOpPassAdaptor>(it->get())) {
186       // If it is the first adaptor in a possible chain, remember it and
187       // continue.
188       if (!lastAdaptor) {
189         lastAdaptor = currentAdaptor;
190         continue;
191       }
192 
193       // Otherwise, merge into the existing adaptor and delete the current one.
194       currentAdaptor->mergeInto(*lastAdaptor);
195       it->reset();
196     } else if (lastAdaptor) {
197       // If this pass is not an adaptor, then coalesce and forget any existing
198       // adaptor.
199       for (auto &pm : lastAdaptor->getPassManagers())
200         pm.getImpl().coalesceAdjacentAdaptorPasses();
201       lastAdaptor = nullptr;
202     }
203   }
204 
205   // If there was an adaptor at the end of the manager, coalesce it as well.
206   if (lastAdaptor) {
207     for (auto &pm : lastAdaptor->getPassManagers())
208       pm.getImpl().coalesceAdjacentAdaptorPasses();
209   }
210 
211   // Now that the adaptors have been merged, erase the empty slot corresponding
212   // to the merged adaptors that were nulled-out in the loop above.
213   llvm::erase_if(passes, std::logical_not<std::unique_ptr<Pass>>());
214 }
215 
216 void OpPassManagerImpl::splitAdaptorPasses() {
217   std::vector<std::unique_ptr<Pass>> oldPasses;
218   std::swap(passes, oldPasses);
219 
220   for (std::unique_ptr<Pass> &pass : oldPasses) {
221     // If this pass isn't an adaptor, move it directly to the new pass list.
222     auto *currentAdaptor = dyn_cast<OpToOpPassAdaptor>(pass.get());
223     if (!currentAdaptor) {
224       addPass(std::move(pass));
225       continue;
226     }
227 
228     // Otherwise, split the adaptors of each manager within the adaptor.
229     for (OpPassManager &adaptorPM : currentAdaptor->getPassManagers()) {
230       adaptorPM.getImpl().splitAdaptorPasses();
231       for (std::unique_ptr<Pass> &nestedPass : adaptorPM.getImpl().passes)
232         nest(adaptorPM.getOpName()).addPass(std::move(nestedPass));
233     }
234   }
235 }
236 
237 //===----------------------------------------------------------------------===//
238 // OpPassManager
239 //===----------------------------------------------------------------------===//
240 
241 OpPassManager::OpPassManager(Identifier name, Nesting nesting)
242     : impl(new OpPassManagerImpl(name, nesting)) {}
243 OpPassManager::OpPassManager(StringRef name, Nesting nesting)
244     : impl(new OpPassManagerImpl(name, nesting)) {}
245 OpPassManager::OpPassManager(OpPassManager &&rhs) : impl(std::move(rhs.impl)) {}
246 OpPassManager::OpPassManager(const OpPassManager &rhs) { *this = rhs; }
247 OpPassManager &OpPassManager::operator=(const OpPassManager &rhs) {
248   impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->nesting));
249   impl->initializationGeneration = rhs.impl->initializationGeneration;
250   for (auto &pass : rhs.impl->passes)
251     impl->passes.emplace_back(pass->clone());
252   return *this;
253 }
254 
255 OpPassManager::~OpPassManager() {}
256 
257 OpPassManager::pass_iterator OpPassManager::begin() {
258   return MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.begin();
259 }
260 OpPassManager::pass_iterator OpPassManager::end() {
261   return MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.end();
262 }
263 
264 OpPassManager::const_pass_iterator OpPassManager::begin() const {
265   return ArrayRef<std::unique_ptr<Pass>>{impl->passes}.begin();
266 }
267 OpPassManager::const_pass_iterator OpPassManager::end() const {
268   return ArrayRef<std::unique_ptr<Pass>>{impl->passes}.end();
269 }
270 
271 /// Nest a new operation pass manager for the given operation kind under this
272 /// pass manager.
273 OpPassManager &OpPassManager::nest(Identifier nestedName) {
274   return impl->nest(nestedName);
275 }
276 OpPassManager &OpPassManager::nest(StringRef nestedName) {
277   return impl->nest(nestedName);
278 }
279 
280 /// Add the given pass to this pass manager. If this pass has a concrete
281 /// operation type, it must be the same type as this pass manager.
282 void OpPassManager::addPass(std::unique_ptr<Pass> pass) {
283   impl->addPass(std::move(pass));
284 }
285 
286 /// Returns the number of passes held by this manager.
287 size_t OpPassManager::size() const { return impl->passes.size(); }
288 
289 /// Returns the internal implementation instance.
290 OpPassManagerImpl &OpPassManager::getImpl() { return *impl; }
291 
292 /// Return the operation name that this pass manager operates on.
293 StringRef OpPassManager::getOpName() const { return impl->name; }
294 
295 /// Return the operation name that this pass manager operates on.
296 Identifier OpPassManager::getOpName(MLIRContext &context) const {
297   return impl->getOpName(context);
298 }
299 
300 /// Prints out the given passes as the textual representation of a pipeline.
301 static void printAsTextualPipeline(ArrayRef<std::unique_ptr<Pass>> passes,
302                                    raw_ostream &os) {
303   llvm::interleaveComma(passes, os, [&](const std::unique_ptr<Pass> &pass) {
304     pass->printAsTextualPipeline(os);
305   });
306 }
307 
308 /// Prints out the passes of the pass manager as the textual representation
309 /// of pipelines.
310 void OpPassManager::printAsTextualPipeline(raw_ostream &os) {
311   ::printAsTextualPipeline(impl->passes, os);
312 }
313 
314 void OpPassManager::dump() {
315   llvm::errs() << "Pass Manager with " << impl->passes.size() << " passes: ";
316   ::printAsTextualPipeline(impl->passes, llvm::errs());
317   llvm::errs() << "\n";
318 }
319 
320 static void registerDialectsForPipeline(const OpPassManager &pm,
321                                         DialectRegistry &dialects) {
322   for (const Pass &pass : pm.getPasses())
323     pass.getDependentDialects(dialects);
324 }
325 
326 void OpPassManager::getDependentDialects(DialectRegistry &dialects) const {
327   registerDialectsForPipeline(*this, dialects);
328 }
329 
330 void OpPassManager::setNesting(Nesting nesting) { impl->nesting = nesting; }
331 
332 OpPassManager::Nesting OpPassManager::getNesting() { return impl->nesting; }
333 
334 LogicalResult OpPassManager::initialize(MLIRContext *context,
335                                         unsigned newInitGeneration) {
336   if (impl->initializationGeneration == newInitGeneration)
337     return success();
338   impl->initializationGeneration = newInitGeneration;
339   for (Pass &pass : getPasses()) {
340     // If this pass isn't an adaptor, directly initialize it.
341     auto *adaptor = dyn_cast<OpToOpPassAdaptor>(&pass);
342     if (!adaptor) {
343       if (failed(pass.initialize(context)))
344         return failure();
345       continue;
346     }
347 
348     // Otherwise, initialize each of the adaptors pass managers.
349     for (OpPassManager &adaptorPM : adaptor->getPassManagers())
350       if (failed(adaptorPM.initialize(context, newInitGeneration)))
351         return failure();
352   }
353   return success();
354 }
355 
356 //===----------------------------------------------------------------------===//
357 // OpToOpPassAdaptor
358 //===----------------------------------------------------------------------===//
359 
/// Run `pass` on `op`.
///
/// `op` must be a registered operation with the `IsolatedFromAbove` trait.
/// Otherwise an error is emitted on `op` and failure is returned. The pass is
/// given an execution state that includes a callback for running a dynamic
/// pipeline on `op` or on an operation nested under it. The rest of the
/// sequence is:
///  - the instrumentor (if any) is notified before and after the pass;
///  - analyses not preserved by the pass are invalidated;
///  - `op` is verified when `verifyPasses` is set and the pass did not fail.
///
/// \param pass                 the pass to execute.
/// \param op                   the operation the pass runs on.
/// \param am                   the analysis manager for `op`.
/// \param verifyPasses         whether to run the verifier after the pass.
/// \param parentInitGeneration initialization generation used for pipelines
///                             scheduled dynamically by the pass.
/// \returns failure if a precondition fails, the pass signals failure, or
///          verification fails.
LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
                                     AnalysisManager am, bool verifyPasses,
                                     unsigned parentInitGeneration) {
  if (!op->isRegistered())
    return op->emitOpError()
           << "trying to schedule a pass on an unregistered operation";
  if (!op->hasTrait<OpTrait::IsIsolatedFromAbove>())
    return op->emitOpError() << "trying to schedule a pass on an operation not "
                                "marked as 'IsolatedFromAbove'";

  // Initialize the pass state with a callback for the pass to dynamically
  // execute a pipeline on the currently visited operation.
  PassInstrumentor *pi = am.getPassInstrumentor();
  PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
                                                        pass};
  // NOTE: this lambda captures the locals and parameters of this call by
  // reference, so it must not be invoked after `run` returns.
  auto dynamic_pipeline_callback = [&](OpPassManager &pipeline,
                                       Operation *root) -> LogicalResult {
    // Dynamic pipelines may only target `op` itself or IR nested under it.
    if (!op->isAncestor(root))
      return root->emitOpError()
             << "Trying to schedule a dynamic pipeline on an "
                "operation that isn't "
                "nested under the current operation the pass is processing";
    assert(pipeline.getOpName() == root->getName().getStringRef());

    // Before running, make sure to coalesce any adjacent pass adaptors in the
    // pipeline.
    pipeline.getImpl().coalesceAdjacentAdaptorPasses();

    // Initialize the user provided pipeline and execute the pipeline.
    if (failed(pipeline.initialize(root->getContext(), parentInitGeneration)))
      return failure();
    // Reuse the current analysis manager when running on `op` itself;
    // otherwise use the analysis manager nested for `root`.
    AnalysisManager nestedAm = root == op ? am : am.nest(root);
    return OpToOpPassAdaptor::runPipeline(pipeline.getPasses(), root, nestedAm,
                                          verifyPasses, parentInitGeneration,
                                          pi, &parentInfo);
  };
  pass->passState.emplace(op, am, dynamic_pipeline_callback);

  // Instrument before the pass has run.
  if (pi)
    pi->runBeforePass(pass, op);

  // Invoke the virtual runOnOperation method. Adaptors are dispatched to the
  // overload that receives the verification flag.
  if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
    adaptor->runOnOperation(verifyPasses);
  else
    pass->runOnOperation();
  bool passFailed = pass->passState->irAndPassFailed.getInt();

  // Invalidate any non preserved analyses.
  am.invalidate(pass->passState->preservedAnalyses);

  // Run the verifier if this pass didn't fail already.
  if (!passFailed && verifyPasses)
    passFailed = failed(verify(op));

  // Instrument after the pass has run.
  if (pi) {
    if (passFailed)
      pi->runAfterPassFailed(pass, op);
    else
      pi->runAfterPass(pass, op);
  }

  // Return if the pass signaled a failure.
  return failure(passFailed);
}
427 
428 /// Run the given operation and analysis manager on a provided op pass manager.
429 LogicalResult OpToOpPassAdaptor::runPipeline(
430     iterator_range<OpPassManager::pass_iterator> passes, Operation *op,
431     AnalysisManager am, bool verifyPasses, unsigned parentInitGeneration,
432     PassInstrumentor *instrumentor,
433     const PassInstrumentation::PipelineParentInfo *parentInfo) {
434   assert((!instrumentor || parentInfo) &&
435          "expected parent info if instrumentor is provided");
436   auto scope_exit = llvm::make_scope_exit([&] {
437     // Clear out any computed operation analyses. These analyses won't be used
438     // any more in this pipeline, and this helps reduce the current working set
439     // of memory. If preserving these analyses becomes important in the future
440     // we can re-evaluate this.
441     am.clear();
442   });
443 
444   // Run the pipeline over the provided operation.
445   if (instrumentor)
446     instrumentor->runBeforePipeline(op->getName().getIdentifier(), *parentInfo);
447   for (Pass &pass : passes)
448     if (failed(run(&pass, op, am, verifyPasses, parentInitGeneration)))
449       return failure();
450   if (instrumentor)
451     instrumentor->runAfterPipeline(op->getName().getIdentifier(), *parentInfo);
452   return success();
453 }
454 
455 /// Find an operation pass manager that can operate on an operation of the given
456 /// type, or nullptr if one does not exist.
457 static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
458                                          StringRef name) {
459   auto it = llvm::find_if(
460       mgrs, [&](OpPassManager &mgr) { return mgr.getOpName() == name; });
461   return it == mgrs.end() ? nullptr : &*it;
462 }
463 
464 /// Find an operation pass manager that can operate on an operation of the given
465 /// type, or nullptr if one does not exist.
466 static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
467                                          Identifier name,
468                                          MLIRContext &context) {
469   auto it = llvm::find_if(
470       mgrs, [&](OpPassManager &mgr) { return mgr.getOpName(context) == name; });
471   return it == mgrs.end() ? nullptr : &*it;
472 }
473 
474 OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr) {
475   mgrs.emplace_back(std::move(mgr));
476 }
477 
478 void OpToOpPassAdaptor::getDependentDialects(DialectRegistry &dialects) const {
479   for (auto &pm : mgrs)
480     pm.getDependentDialects(dialects);
481 }
482 
483 /// Merge the current pass adaptor into given 'rhs'.
484 void OpToOpPassAdaptor::mergeInto(OpToOpPassAdaptor &rhs) {
485   for (auto &pm : mgrs) {
486     // If an existing pass manager exists, then merge the given pass manager
487     // into it.
488     if (auto *existingPM = findPassManagerFor(rhs.mgrs, pm.getOpName())) {
489       pm.getImpl().mergeInto(existingPM->getImpl());
490     } else {
491       // Otherwise, add the given pass manager to the list.
492       rhs.mgrs.emplace_back(std::move(pm));
493     }
494   }
495   mgrs.clear();
496 
497   // After coalescing, sort the pass managers within rhs by name.
498   llvm::array_pod_sort(rhs.mgrs.begin(), rhs.mgrs.end(),
499                        [](const OpPassManager *lhs, const OpPassManager *rhs) {
500                          return lhs->getOpName().compare(rhs->getOpName());
501                        });
502 }
503 
504 /// Returns the adaptor pass name.
505 std::string OpToOpPassAdaptor::getAdaptorName() {
506   std::string name = "Pipeline Collection : [";
507   llvm::raw_string_ostream os(name);
508   llvm::interleaveComma(getPassManagers(), os, [&](OpPassManager &pm) {
509     os << '\'' << pm.getOpName() << '\'';
510   });
511   os << ']';
512   return os.str();
513 }
514 
515 void OpToOpPassAdaptor::runOnOperation() {
516   llvm_unreachable(
517       "Unexpected call to Pass::runOnOperation() on OpToOpPassAdaptor");
518 }
519 
520 /// Run the held pipeline over all nested operations.
521 void OpToOpPassAdaptor::runOnOperation(bool verifyPasses) {
522   if (getContext().isMultithreadingEnabled())
523     runOnOperationAsyncImpl(verifyPasses);
524   else
525     runOnOperationImpl(verifyPasses);
526 }
527 
528 /// Run this pass adaptor synchronously.
529 void OpToOpPassAdaptor::runOnOperationImpl(bool verifyPasses) {
530   auto am = getAnalysisManager();
531   PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
532                                                         this};
533   auto *instrumentor = am.getPassInstrumentor();
534   for (auto &region : getOperation()->getRegions()) {
535     for (auto &block : region) {
536       for (auto &op : block) {
537         auto *mgr = findPassManagerFor(mgrs, op.getName().getIdentifier(),
538                                        *op.getContext());
539         if (!mgr)
540           continue;
541 
542         // Run the held pipeline over the current operation.
543         unsigned initGeneration = mgr->impl->initializationGeneration;
544         if (failed(runPipeline(mgr->getPasses(), &op, am.nest(&op),
545                                verifyPasses, initGeneration, instrumentor,
546                                &parentInfo)))
547           return signalPassFailure();
548       }
549     }
550   }
551 }
552 
553 /// Utility functor that checks if the two ranges of pass managers have a size
554 /// mismatch.
555 static bool hasSizeMismatch(ArrayRef<OpPassManager> lhs,
556                             ArrayRef<OpPassManager> rhs) {
557   return lhs.size() != rhs.size() ||
558          llvm::any_of(llvm::seq<size_t>(0, lhs.size()),
559                       [&](size_t i) { return lhs[i].size() != rhs[i].size(); });
560 }
561 
/// Run this pass adaptor asynchronously.
///
/// The steps are:
///  - Collect up front every operation directly nested in the regions of the
///    current operation whose name matches one of the held pass managers.
///  - Run the matching pipelines in parallel via llvm::parallelForEach. Each
///    task uses its own executor copy of the pass managers and repeatedly
///    claims the next operation through a shared atomic index.
///  - Route diagnostics through a ParallelDiagnosticHandler keyed on each
///    operation's index, so they are reported in a deterministic order.
///
/// When a pipeline fails, tasks stop claiming new operations and a pass
/// failure is signaled.
void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
  AnalysisManager am = getAnalysisManager();

  // Create the async executors if they haven't been created, or if the main
  // pipeline has changed. A change is only detected here as a difference in
  // the number of pass managers or in the number of passes they hold.
  if (asyncExecutors.empty() || hasSizeMismatch(asyncExecutors.front(), mgrs))
    asyncExecutors.assign(llvm::hardware_concurrency().compute_thread_count(),
                          mgrs);

  // Run a prepass over the operation to collect the nested operations to
  // execute over. This ensures that an analysis manager exists for each
  // operation, as well as providing a queue of operations to execute over.
  std::vector<std::pair<Operation *, AnalysisManager>> opAMPairs;
  for (auto &region : getOperation()->getRegions()) {
    for (auto &block : region) {
      for (auto &op : block) {
        // Add this operation iff the name matches any of the pass managers.
        if (findPassManagerFor(mgrs, op.getName().getIdentifier(),
                               getContext()))
          opAMPairs.emplace_back(&op, am.nest(&op));
      }
    }
  }

  // A parallel diagnostic handler that provides deterministic diagnostic
  // ordering.
  ParallelDiagnosticHandler diagHandler(&getContext());

  // An index for the current operation/analysis manager pair. Tasks claim
  // the next pair by atomically incrementing it.
  std::atomic<unsigned> opIt(0);

  // Get the current thread for this adaptor.
  PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
                                                        this};
  auto *instrumentor = am.getPassInstrumentor();

  // An atomic failure variable for the async executors.
  std::atomic<bool> passFailed(false);
  // One task per executor copy, but never more tasks than there are
  // operations to process.
  llvm::parallelForEach(
      asyncExecutors.begin(),
      std::next(asyncExecutors.begin(),
                std::min(asyncExecutors.size(), opAMPairs.size())),
      [&](MutableArrayRef<OpPassManager> pms) {
        // Keep claiming operations until all are taken or a pipeline failed.
        for (auto e = opAMPairs.size(); !passFailed && opIt < e;) {
          // Get the next available operation index.
          unsigned nextID = opIt++;
          if (nextID >= e)
            break;

          // Set the order id for this thread in the diagnostic handler.
          diagHandler.setOrderIDForThread(nextID);

          // Get the pass manager for this operation and execute it.
          auto &it = opAMPairs[nextID];
          auto *pm = findPassManagerFor(
              pms, it.first->getName().getIdentifier(), getContext());
          assert(pm && "expected valid pass manager for operation");

          unsigned initGeneration = pm->impl->initializationGeneration;
          LogicalResult pipelineResult =
              runPipeline(pm->getPasses(), it.first, it.second, verifyPasses,
                          initGeneration, instrumentor, &parentInfo);

          // Drop this thread from being tracked by the diagnostic handler.
          // After this task has finished, the thread may be used outside of
          // this pass manager context meaning that we don't want to track
          // diagnostics from it anymore.
          diagHandler.eraseOrderIDForThread();

          // Handle a failed pipeline result.
          if (failed(pipelineResult)) {
            passFailed = true;
            break;
          }
        }
      });

  // Signal a failure if any of the executors failed.
  if (passFailed)
    signalPassFailure();
}
644 
645 //===----------------------------------------------------------------------===//
646 // PassCrashReproducer
647 //===----------------------------------------------------------------------===//
648 
namespace {
/// This class contains all of the context for generating a recovery reproducer.
/// Each recovery context is registered globally to allow for generating
/// reproducers when a signal is raised, such as a segfault.
struct RecoveryReproducerContext {
  /// Snapshot `op` and the textual form of `passes`, then register this
  /// context in the global set of active contexts.
  RecoveryReproducerContext(MutableArrayRef<std::unique_ptr<Pass>> passes,
                            Operation *op,
                            PassManager::ReproducerStreamFactory &crashStream,
                            bool disableThreads, bool verifyPasses);
  /// Release the IR snapshot and unregister this context.
  ~RecoveryReproducerContext();

  /// Generate a reproducer with the current context. Returns failure if the
  /// stream factory does not produce an output stream.
  LogicalResult generate(std::string &error);

private:
  /// This function is invoked in the event of a crash.
  static void crashHandler(void *);

  /// Register a signal handler to run in the event of a crash.
  static void registerSignalHandler();

  /// The textual description of the currently executing pipeline.
  std::string pipeline;

  /// A clone of the operation taken when this context was created, i.e. the
  /// IR before the pipeline ran.
  Operation *preCrashOperation;

  /// The factory for the reproducer output stream to use when generating the
  /// reproducer. Held by reference, so it must outlive this context.
  PassManager::ReproducerStreamFactory &crashStreamFactory;

  /// Various pass manager and context flags.
  bool disableThreads;
  bool verifyPasses;

  /// `reproducerMutex` is held while contexts are added to or removed from
  /// `reproducerSet`. The set holds the active reproducer contexts and is
  /// used in the event of a crash. It is not thread_local because the pass
  /// manager may produce any number of child threads. A set is used so that
  /// multiple MLIR pass managers can run at the same time.
  static llvm::ManagedStatic<llvm::sys::SmartMutex<true>> reproducerMutex;
  static llvm::ManagedStatic<
      llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
      reproducerSet;
};

/// ReproducerStream implementation backed by a file.
struct FileReproducerStream : public PassManager::ReproducerStream {
  FileReproducerStream(std::unique_ptr<llvm::ToolOutputFile> outputFile)
      : outputFile(std::move(outputFile)) {}
  ~FileReproducerStream() override;

  /// Description of the reproducer stream (the output file name).
  StringRef description() override;

  /// Stream on which to output the reproducer.
  raw_ostream &os() override;

private:
  /// ToolOutputFile corresponding to the opened reproducer file.
  std::unique_ptr<llvm::ToolOutputFile> outputFile = nullptr;
};

} // end anonymous namespace
712 
// Storage for RecoveryReproducerContext's global mutex and for its set of
// active reproducer contexts.
llvm::ManagedStatic<llvm::sys::SmartMutex<true>>
    RecoveryReproducerContext::reproducerMutex;
llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
    RecoveryReproducerContext::reproducerSet;
717 
718 RecoveryReproducerContext::RecoveryReproducerContext(
719     MutableArrayRef<std::unique_ptr<Pass>> passes, Operation *op,
720     PassManager::ReproducerStreamFactory &crashStreamFactory,
721     bool disableThreads, bool verifyPasses)
722     : preCrashOperation(op->clone()), crashStreamFactory(crashStreamFactory),
723       disableThreads(disableThreads), verifyPasses(verifyPasses) {
724   // Grab the textual pipeline being executed..
725   {
726     llvm::raw_string_ostream pipelineOS(pipeline);
727     ::printAsTextualPipeline(passes, pipelineOS);
728   }
729 
730   // Make sure that the handler is registered, and update the current context.
731   llvm::sys::SmartScopedLock<true> producerLock(*reproducerMutex);
732   if (reproducerSet->empty())
733     llvm::CrashRecoveryContext::Enable();
734   registerSignalHandler();
735   reproducerSet->insert(this);
736 }
737 
738 RecoveryReproducerContext::~RecoveryReproducerContext() {
739   // Erase the cloned preCrash IR that we cached.
740   preCrashOperation->erase();
741 
742   llvm::sys::SmartScopedLock<true> producerLock(*reproducerMutex);
743   reproducerSet->remove(this);
744   if (reproducerSet->empty())
745     llvm::CrashRecoveryContext::Disable();
746 }
747 
748 /// Description of the reproducer stream.
749 StringRef FileReproducerStream::description() {
750   return outputFile->getFilename();
751 }
752 
753 /// Stream on which to output reproducer.
754 raw_ostream &FileReproducerStream::os() { return outputFile->os(); }
755 
756 FileReproducerStream::~FileReproducerStream() { outputFile->keep(); }
757 
758 LogicalResult RecoveryReproducerContext::generate(std::string &error) {
759   std::unique_ptr<PassManager::ReproducerStream> crashStream =
760       crashStreamFactory(error);
761   if (!crashStream)
762     return failure();
763 
764   // Output the current pass manager configuration.
765   auto &os = crashStream->os();
766   os << "// configuration: -pass-pipeline='" << pipeline << "'";
767   if (disableThreads)
768     os << " -mlir-disable-threading";
769   if (verifyPasses)
770     os << " -verify-each";
771   os << '\n';
772 
773   // Output the .mlir module.
774   preCrashOperation->print(os);
775 
776   bool shouldPrintOnOp =
777       preCrashOperation->getContext()->shouldPrintOpOnDiagnostic();
778   preCrashOperation->getContext()->printOpOnDiagnostic(false);
779   preCrashOperation->emitError()
780       << "A failure has been detected while processing the MLIR module, a "
781          "reproducer has been generated in '"
782       << crashStream->description() << "'";
783   preCrashOperation->getContext()->printOpOnDiagnostic(shouldPrintOnOp);
784   return success();
785 }
786 
787 void RecoveryReproducerContext::crashHandler(void *) {
788   // Walk the current stack of contexts and generate a reproducer for each one.
789   // We can't know for certain which one was the cause, so we need to generate
790   // a reproducer for all of them.
791   std::string ignored;
792   for (RecoveryReproducerContext *context : *reproducerSet)
793     (void)context->generate(ignored);
794 }
795 
796 void RecoveryReproducerContext::registerSignalHandler() {
797   // Ensure that the handler is only registered once.
798   static bool registered =
799       (llvm::sys::AddSignalHandler(crashHandler, nullptr), false);
800   (void)registered;
801 }
802 
803 /// Run the pass manager with crash recover enabled.
804 LogicalResult PassManager::runWithCrashRecovery(Operation *op,
805                                                 AnalysisManager am) {
806   // If this isn't a local producer, run all of the passes in recovery mode.
807   if (!localReproducer)
808     return runWithCrashRecovery(impl->passes, op, am);
809 
810   // Split the passes within adaptors to ensure that each pass can be run in
811   // isolation.
812   impl->splitAdaptorPasses();
813 
814   // If this is a local producer, run each of the passes individually.
815   MutableArrayRef<std::unique_ptr<Pass>> passes = impl->passes;
816   for (std::unique_ptr<Pass> &pass : passes)
817     if (failed(runWithCrashRecovery(pass, op, am)))
818       return failure();
819   return success();
820 }
821 
822 /// Run the given passes with crash recover enabled.
823 LogicalResult
824 PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
825                                   Operation *op, AnalysisManager am) {
826   RecoveryReproducerContext context(passes, op, crashReproducerStreamFactory,
827                                     !getContext()->isMultithreadingEnabled(),
828                                     verifyPasses);
829 
830   // Safely invoke the passes within a recovery context.
831   LogicalResult passManagerResult = failure();
832   llvm::CrashRecoveryContext recoveryContext;
833   recoveryContext.RunSafelyOnThread([&] {
834     for (std::unique_ptr<Pass> &pass : passes)
835       if (failed(OpToOpPassAdaptor::run(pass.get(), op, am, verifyPasses,
836                                         impl->initializationGeneration)))
837         return;
838     passManagerResult = success();
839   });
840   if (succeeded(passManagerResult))
841     return success();
842 
843   std::string error;
844   if (failed(context.generate(error)))
845     return op->emitError("<MLIR-PassManager-Crash-Reproducer>: ") << error;
846   return failure();
847 }
848 
849 //===----------------------------------------------------------------------===//
850 // PassManager
851 //===----------------------------------------------------------------------===//
852 
853 PassManager::PassManager(MLIRContext *ctx, Nesting nesting,
854                          StringRef operationName)
855     : OpPassManager(Identifier::get(operationName, ctx), nesting), context(ctx),
856       initializationKey(DenseMapInfo<llvm::hash_code>::getTombstoneKey()),
857       passTiming(false), localReproducer(false), verifyPasses(true) {}
858 
859 PassManager::~PassManager() {}
860 
861 void PassManager::enableVerifier(bool enabled) { verifyPasses = enabled; }
862 
863 /// Run the passes within this manager on the provided operation.
864 LogicalResult PassManager::run(Operation *op) {
865   MLIRContext *context = getContext();
866   assert(op->getName().getIdentifier() == getOpName(*context) &&
867          "operation has a different name than the PassManager or is from a "
868          "different context");
869 
870   // Before running, make sure to coalesce any adjacent pass adaptors in the
871   // pipeline.
872   getImpl().coalesceAdjacentAdaptorPasses();
873 
874   // Register all dialects for the current pipeline.
875   DialectRegistry dependentDialects;
876   getDependentDialects(dependentDialects);
877   context->appendDialectRegistry(dependentDialects);
878   for (StringRef name : dependentDialects.getDialectNames())
879     context->getOrLoadDialect(name);
880 
881   // Initialize all of the passes within the pass manager with a new generation.
882   llvm::hash_code newInitKey = context->getRegistryHash();
883   if (newInitKey != initializationKey) {
884     if (failed(initialize(context, impl->initializationGeneration + 1)))
885       return failure();
886     initializationKey = newInitKey;
887   }
888 
889   // Construct a top level analysis manager for the pipeline.
890   ModuleAnalysisManager am(op, instrumentor.get());
891 
892   // Notify the context that we start running a pipeline for book keeping.
893   context->enterMultiThreadedExecution();
894 
895   // If reproducer generation is enabled, run the pass manager with crash
896   // handling enabled.
897   LogicalResult result =
898       crashReproducerStreamFactory
899           ? runWithCrashRecovery(op, am)
900           : OpToOpPassAdaptor::runPipeline(getPasses(), op, am, verifyPasses,
901                                            impl->initializationGeneration);
902 
903   // Notify the context that the run is done.
904   context->exitMultiThreadedExecution();
905 
906   // Dump all of the pass statistics if necessary.
907   if (passStatisticsMode)
908     dumpStatistics();
909   return result;
910 }
911 
912 /// Enable support for the pass manager to generate a reproducer on the event
913 /// of a crash or a pass failure. `outputFile` is a .mlir filename used to write
914 /// the generated reproducer. If `genLocalReproducer` is true, the pass manager
915 /// will attempt to generate a local reproducer that contains the smallest
916 /// pipeline.
917 void PassManager::enableCrashReproducerGeneration(StringRef outputFile,
918                                                   bool genLocalReproducer) {
919   // Capture the filename by value in case outputFile is out of scope when
920   // invoked.
921   std::string filename = outputFile.str();
922   enableCrashReproducerGeneration(
923       [filename](std::string &error) -> std::unique_ptr<ReproducerStream> {
924         std::unique_ptr<llvm::ToolOutputFile> outputFile =
925             mlir::openOutputFile(filename, &error);
926         if (!outputFile) {
927           error = "Failed to create reproducer stream: " + error;
928           return nullptr;
929         }
930         return std::make_unique<FileReproducerStream>(std::move(outputFile));
931       },
932       genLocalReproducer);
933 }
934 
935 /// Enable support for the pass manager to generate a reproducer on the event
936 /// of a crash or a pass failure. `factory` is used to construct the streams
937 /// to write the generated reproducer to. If `genLocalReproducer` is true, the
938 /// pass manager will attempt to generate a local reproducer that contains the
939 /// smallest pipeline.
940 void PassManager::enableCrashReproducerGeneration(
941     ReproducerStreamFactory factory, bool genLocalReproducer) {
942   crashReproducerStreamFactory = factory;
943   localReproducer = genLocalReproducer;
944 }
945 
946 /// Add the provided instrumentation to the pass manager.
947 void PassManager::addInstrumentation(std::unique_ptr<PassInstrumentation> pi) {
948   if (!instrumentor)
949     instrumentor = std::make_unique<PassInstrumentor>();
950 
951   instrumentor->addInstrumentation(std::move(pi));
952 }
953 
954 //===----------------------------------------------------------------------===//
955 // AnalysisManager
956 //===----------------------------------------------------------------------===//
957 
958 /// Get an analysis manager for the given operation, which must be a proper
959 /// descendant of the current operation represented by this analysis manager.
960 AnalysisManager AnalysisManager::nest(Operation *op) {
961   Operation *currentOp = impl->getOperation();
962   assert(currentOp->isProperAncestor(op) &&
963          "expected valid descendant operation");
964 
965   // Check for the base case where the provided operation is immediately nested.
966   if (currentOp == op->getParentOp())
967     return nestImmediate(op);
968 
969   // Otherwise, we need to collect all ancestors up to the current operation.
970   SmallVector<Operation *, 4> opAncestors;
971   do {
972     opAncestors.push_back(op);
973     op = op->getParentOp();
974   } while (op != currentOp);
975 
976   AnalysisManager result = *this;
977   for (Operation *op : llvm::reverse(opAncestors))
978     result = result.nestImmediate(op);
979   return result;
980 }
981 
982 /// Get an analysis manager for the given immediately nested child operation.
983 AnalysisManager AnalysisManager::nestImmediate(Operation *op) {
984   assert(impl->getOperation() == op->getParentOp() &&
985          "expected immediate child operation");
986 
987   auto it = impl->childAnalyses.find(op);
988   if (it == impl->childAnalyses.end())
989     it = impl->childAnalyses
990              .try_emplace(op, std::make_unique<NestedAnalysisMap>(op, impl))
991              .first;
992   return {it->second.get()};
993 }
994 
995 /// Invalidate any non preserved analyses.
996 void detail::NestedAnalysisMap::invalidate(
997     const detail::PreservedAnalyses &pa) {
998   // If all analyses were preserved, then there is nothing to do here.
999   if (pa.isAll())
1000     return;
1001 
1002   // Invalidate the analyses for the current operation directly.
1003   analyses.invalidate(pa);
1004 
1005   // If no analyses were preserved, then just simply clear out the child
1006   // analysis results.
1007   if (pa.isNone()) {
1008     childAnalyses.clear();
1009     return;
1010   }
1011 
1012   // Otherwise, invalidate each child analysis map.
1013   SmallVector<NestedAnalysisMap *, 8> mapsToInvalidate(1, this);
1014   while (!mapsToInvalidate.empty()) {
1015     auto *map = mapsToInvalidate.pop_back_val();
1016     for (auto &analysisPair : map->childAnalyses) {
1017       analysisPair.second->invalidate(pa);
1018       if (!analysisPair.second->childAnalyses.empty())
1019         mapsToInvalidate.push_back(analysisPair.second.get());
1020     }
1021   }
1022 }
1023 
1024 //===----------------------------------------------------------------------===//
1025 // PassInstrumentation
1026 //===----------------------------------------------------------------------===//
1027 
1028 PassInstrumentation::~PassInstrumentation() {}
1029 
1030 //===----------------------------------------------------------------------===//
1031 // PassInstrumentor
1032 //===----------------------------------------------------------------------===//
1033 
1034 namespace mlir {
1035 namespace detail {
1036 struct PassInstrumentorImpl {
1037   /// Mutex to keep instrumentation access thread-safe.
1038   llvm::sys::SmartMutex<true> mutex;
1039 
1040   /// Set of registered instrumentations.
1041   std::vector<std::unique_ptr<PassInstrumentation>> instrumentations;
1042 };
1043 } // end namespace detail
1044 } // end namespace mlir
1045 
1046 PassInstrumentor::PassInstrumentor() : impl(new PassInstrumentorImpl()) {}
1047 PassInstrumentor::~PassInstrumentor() {}
1048 
1049 /// See PassInstrumentation::runBeforePipeline for details.
1050 void PassInstrumentor::runBeforePipeline(
1051     Identifier name,
1052     const PassInstrumentation::PipelineParentInfo &parentInfo) {
1053   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1054   for (auto &instr : impl->instrumentations)
1055     instr->runBeforePipeline(name, parentInfo);
1056 }
1057 
1058 /// See PassInstrumentation::runAfterPipeline for details.
1059 void PassInstrumentor::runAfterPipeline(
1060     Identifier name,
1061     const PassInstrumentation::PipelineParentInfo &parentInfo) {
1062   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1063   for (auto &instr : llvm::reverse(impl->instrumentations))
1064     instr->runAfterPipeline(name, parentInfo);
1065 }
1066 
1067 /// See PassInstrumentation::runBeforePass for details.
1068 void PassInstrumentor::runBeforePass(Pass *pass, Operation *op) {
1069   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1070   for (auto &instr : impl->instrumentations)
1071     instr->runBeforePass(pass, op);
1072 }
1073 
1074 /// See PassInstrumentation::runAfterPass for details.
1075 void PassInstrumentor::runAfterPass(Pass *pass, Operation *op) {
1076   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1077   for (auto &instr : llvm::reverse(impl->instrumentations))
1078     instr->runAfterPass(pass, op);
1079 }
1080 
1081 /// See PassInstrumentation::runAfterPassFailed for details.
1082 void PassInstrumentor::runAfterPassFailed(Pass *pass, Operation *op) {
1083   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1084   for (auto &instr : llvm::reverse(impl->instrumentations))
1085     instr->runAfterPassFailed(pass, op);
1086 }
1087 
1088 /// See PassInstrumentation::runBeforeAnalysis for details.
1089 void PassInstrumentor::runBeforeAnalysis(StringRef name, TypeID id,
1090                                          Operation *op) {
1091   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1092   for (auto &instr : impl->instrumentations)
1093     instr->runBeforeAnalysis(name, id, op);
1094 }
1095 
1096 /// See PassInstrumentation::runAfterAnalysis for details.
1097 void PassInstrumentor::runAfterAnalysis(StringRef name, TypeID id,
1098                                         Operation *op) {
1099   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1100   for (auto &instr : llvm::reverse(impl->instrumentations))
1101     instr->runAfterAnalysis(name, id, op);
1102 }
1103 
1104 /// Add the given instrumentation to the collection.
1105 void PassInstrumentor::addInstrumentation(
1106     std::unique_ptr<PassInstrumentation> pi) {
1107   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1108   impl->instrumentations.emplace_back(std::move(pi));
1109 }
1110