xref: /llvm-project-15.0.7/mlir/lib/Pass/Pass.cpp (revision 87d627b6)
1 //===- Pass.cpp - Pass infrastructure implementation ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements common pass infrastructure.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Pass/Pass.h"
14 #include "PassDetail.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/Threading.h"
19 #include "mlir/IR/Verifier.h"
20 #include "mlir/Support/FileUtilities.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/ScopeExit.h"
23 #include "llvm/ADT/SetVector.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/CrashRecoveryContext.h"
26 #include "llvm/Support/Mutex.h"
27 #include "llvm/Support/Signals.h"
28 #include "llvm/Support/Threading.h"
29 #include "llvm/Support/ToolOutputFile.h"
30 
31 using namespace mlir;
32 using namespace mlir::detail;
33 
34 //===----------------------------------------------------------------------===//
35 // Pass
36 //===----------------------------------------------------------------------===//
37 
38 /// Out of line virtual method to ensure vtables and metadata are emitted to a
39 /// single .o file.
anchor()40 void Pass::anchor() {}
41 
42 /// Attempt to initialize the options of this pass from the given string.
initializeOptions(StringRef options)43 LogicalResult Pass::initializeOptions(StringRef options) {
44   return passOptions.parseFromString(options);
45 }
46 
47 /// Copy the option values from 'other', which is another instance of this
48 /// pass.
copyOptionValuesFrom(const Pass * other)49 void Pass::copyOptionValuesFrom(const Pass *other) {
50   passOptions.copyOptionValuesFrom(other->passOptions);
51 }
52 
53 /// Prints out the pass in the textual representation of pipelines. If this is
54 /// an adaptor pass, print with the op_name(sub_pass,...) format.
printAsTextualPipeline(raw_ostream & os)55 void Pass::printAsTextualPipeline(raw_ostream &os) {
56   // Special case for adaptors to use the 'op_name(sub_passes)' format.
57   if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(this)) {
58     llvm::interleave(
59         adaptor->getPassManagers(),
60         [&](OpPassManager &pm) {
61           os << pm.getOpAnchorName() << "(";
62           pm.printAsTextualPipeline(os);
63           os << ")";
64         },
65         [&] { os << ","; });
66     return;
67   }
68   // Otherwise, print the pass argument followed by its options. If the pass
69   // doesn't have an argument, print the name of the pass to give some indicator
70   // of what pass was run.
71   StringRef argument = getArgument();
72   if (!argument.empty())
73     os << argument;
74   else
75     os << "unknown<" << getName() << ">";
76   passOptions.print(os);
77 }
78 
79 //===----------------------------------------------------------------------===//
80 // OpPassManagerImpl
81 //===----------------------------------------------------------------------===//
82 
83 namespace mlir {
84 namespace detail {
85 struct OpPassManagerImpl {
OpPassManagerImplmlir::detail::OpPassManagerImpl86   OpPassManagerImpl(OperationName opName, OpPassManager::Nesting nesting)
87       : name(opName.getStringRef().str()), opName(opName),
88         initializationGeneration(0), nesting(nesting) {}
OpPassManagerImplmlir::detail::OpPassManagerImpl89   OpPassManagerImpl(StringRef name, OpPassManager::Nesting nesting)
90       : name(name == OpPassManager::getAnyOpAnchorName() ? "" : name.str()),
91         initializationGeneration(0), nesting(nesting) {}
OpPassManagerImplmlir::detail::OpPassManagerImpl92   OpPassManagerImpl(OpPassManager::Nesting nesting)
93       : initializationGeneration(0), nesting(nesting) {}
OpPassManagerImplmlir::detail::OpPassManagerImpl94   OpPassManagerImpl(const OpPassManagerImpl &rhs)
95       : name(rhs.name), opName(rhs.opName),
96         initializationGeneration(rhs.initializationGeneration),
97         nesting(rhs.nesting) {
98     for (const std::unique_ptr<Pass> &pass : rhs.passes) {
99       std::unique_ptr<Pass> newPass = pass->clone();
100       newPass->threadingSibling = pass.get();
101       passes.push_back(std::move(newPass));
102     }
103   }
104 
105   /// Merge the passes of this pass manager into the one provided.
106   void mergeInto(OpPassManagerImpl &rhs);
107 
108   /// Nest a new operation pass manager for the given operation kind under this
109   /// pass manager.
nestmlir::detail::OpPassManagerImpl110   OpPassManager &nest(OperationName nestedName) {
111     return nest(OpPassManager(nestedName, nesting));
112   }
nestmlir::detail::OpPassManagerImpl113   OpPassManager &nest(StringRef nestedName) {
114     return nest(OpPassManager(nestedName, nesting));
115   }
nestAnymlir::detail::OpPassManagerImpl116   OpPassManager &nestAny() { return nest(OpPassManager(nesting)); }
117 
118   /// Nest the given pass manager under this pass manager.
119   OpPassManager &nest(OpPassManager &&nested);
120 
121   /// Add the given pass to this pass manager. If this pass has a concrete
122   /// operation type, it must be the same type as this pass manager.
123   void addPass(std::unique_ptr<Pass> pass);
124 
125   /// Clear the list of passes in this pass manager, other options are
126   /// preserved.
127   void clear();
128 
129   /// Finalize the pass list in preparation for execution. This includes
130   /// coalescing adjacent pass managers when possible, verifying scheduled
131   /// passes, etc.
132   LogicalResult finalizePassList(MLIRContext *ctx);
133 
134   /// Return the operation name of this pass manager.
getOpNamemlir::detail::OpPassManagerImpl135   Optional<OperationName> getOpName(MLIRContext &context) {
136     if (!name.empty() && !opName)
137       opName = OperationName(name, &context);
138     return opName;
139   }
getOpNamemlir::detail::OpPassManagerImpl140   Optional<StringRef> getOpName() const {
141     return name.empty() ? Optional<StringRef>() : Optional<StringRef>(name);
142   }
143 
144   /// Return the name used to anchor this pass manager. This is either the name
145   /// of an operation, or the result of `getAnyOpAnchorName()` in the case of an
146   /// op-agnostic pass manager.
getOpAnchorNamemlir::detail::OpPassManagerImpl147   StringRef getOpAnchorName() const {
148     return getOpName().value_or(OpPassManager::getAnyOpAnchorName());
149   }
150 
151   /// Indicate if the current pass manager can be scheduled on the given
152   /// operation type.
153   bool canScheduleOn(MLIRContext &context, OperationName opName);
154 
155   /// The name of the operation that passes of this pass manager operate on.
156   std::string name;
157 
158   /// The cached OperationName (internalized in the context) for the name of the
159   /// operation that passes of this pass manager operate on.
160   Optional<OperationName> opName;
161 
162   /// The set of passes to run as part of this pass manager.
163   std::vector<std::unique_ptr<Pass>> passes;
164 
165   /// The current initialization generation of this pass manager. This is used
166   /// to indicate when a pass manager should be reinitialized.
167   unsigned initializationGeneration;
168 
169   /// Control the implicit nesting of passes that mismatch the name set for this
170   /// OpPassManager.
171   OpPassManager::Nesting nesting;
172 };
173 } // namespace detail
174 } // namespace mlir
175 
mergeInto(OpPassManagerImpl & rhs)176 void OpPassManagerImpl::mergeInto(OpPassManagerImpl &rhs) {
177   assert(name == rhs.name && "merging unrelated pass managers");
178   for (auto &pass : passes)
179     rhs.passes.push_back(std::move(pass));
180   passes.clear();
181 }
182 
nest(OpPassManager && nested)183 OpPassManager &OpPassManagerImpl::nest(OpPassManager &&nested) {
184   auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
185   addPass(std::unique_ptr<Pass>(adaptor));
186   return adaptor->getPassManagers().front();
187 }
188 
addPass(std::unique_ptr<Pass> pass)189 void OpPassManagerImpl::addPass(std::unique_ptr<Pass> pass) {
190   // If this pass runs on a different operation than this pass manager, then
191   // implicitly nest a pass manager for this operation if enabled.
192   Optional<StringRef> pmOpName = getOpName();
193   Optional<StringRef> passOpName = pass->getOpName();
194   if (pmOpName && passOpName && *pmOpName != *passOpName) {
195     if (nesting == OpPassManager::Nesting::Implicit)
196       return nest(*passOpName).addPass(std::move(pass));
197     llvm::report_fatal_error(llvm::Twine("Can't add pass '") + pass->getName() +
198                              "' restricted to '" + *passOpName +
199                              "' on a PassManager intended to run on '" +
200                              getOpAnchorName() + "', did you intend to nest?");
201   }
202 
203   passes.emplace_back(std::move(pass));
204 }
205 
clear()206 void OpPassManagerImpl::clear() { passes.clear(); }
207 
finalizePassList(MLIRContext * ctx)208 LogicalResult OpPassManagerImpl::finalizePassList(MLIRContext *ctx) {
209   auto finalizeAdaptor = [ctx](OpToOpPassAdaptor *adaptor) {
210     for (auto &pm : adaptor->getPassManagers())
211       if (failed(pm.getImpl().finalizePassList(ctx)))
212         return failure();
213     return success();
214   };
215 
216   // Walk the pass list and merge adjacent adaptors.
217   OpToOpPassAdaptor *lastAdaptor = nullptr;
218   for (auto &pass : passes) {
219     // Check to see if this pass is an adaptor.
220     if (auto *currentAdaptor = dyn_cast<OpToOpPassAdaptor>(pass.get())) {
221       // If it is the first adaptor in a possible chain, remember it and
222       // continue.
223       if (!lastAdaptor) {
224         lastAdaptor = currentAdaptor;
225         continue;
226       }
227 
228       // Otherwise, try to merge into the existing adaptor and delete the
229       // current one. If merging fails, just remember this as the last adaptor.
230       if (succeeded(currentAdaptor->tryMergeInto(ctx, *lastAdaptor)))
231         pass.reset();
232       else
233         lastAdaptor = currentAdaptor;
234     } else if (lastAdaptor) {
235       // If this pass isn't an adaptor, finalize it and forget the last adaptor.
236       if (failed(finalizeAdaptor(lastAdaptor)))
237         return failure();
238       lastAdaptor = nullptr;
239     }
240   }
241 
242   // If there was an adaptor at the end of the manager, finalize it as well.
243   if (lastAdaptor && failed(finalizeAdaptor(lastAdaptor)))
244     return failure();
245 
246   // Now that the adaptors have been merged, erase any empty slots corresponding
247   // to the merged adaptors that were nulled-out in the loop above.
248   llvm::erase_if(passes, std::logical_not<std::unique_ptr<Pass>>());
249 
250   // If this is a op-agnostic pass manager, there is nothing left to do.
251   Optional<OperationName> rawOpName = getOpName(*ctx);
252   if (!rawOpName)
253     return success();
254 
255   // Otherwise, verify that all of the passes are valid for the current
256   // operation anchor.
257   Optional<RegisteredOperationName> opName = rawOpName->getRegisteredInfo();
258   for (std::unique_ptr<Pass> &pass : passes) {
259     if (opName && !pass->canScheduleOn(*opName)) {
260       return emitError(UnknownLoc::get(ctx))
261              << "unable to schedule pass '" << pass->getName()
262              << "' on a PassManager intended to run on '" << getOpAnchorName()
263              << "'!";
264     }
265   }
266   return success();
267 }
268 
canScheduleOn(MLIRContext & context,OperationName opName)269 bool OpPassManagerImpl::canScheduleOn(MLIRContext &context,
270                                       OperationName opName) {
271   // If this pass manager is op-specific, we simply check if the provided
272   // operation name is the same as this one.
273   Optional<OperationName> pmOpName = getOpName(context);
274   if (pmOpName)
275     return pmOpName == opName;
276 
277   // Otherwise, this is an op-agnostic pass manager. Check that the operation
278   // can be scheduled on all passes within the manager.
279   Optional<RegisteredOperationName> registeredInfo = opName.getRegisteredInfo();
280   if (!registeredInfo ||
281       !registeredInfo->hasTrait<OpTrait::IsIsolatedFromAbove>())
282     return false;
283   return llvm::all_of(passes, [&](const std::unique_ptr<Pass> &pass) {
284     return pass->canScheduleOn(*registeredInfo);
285   });
286 }
287 
288 //===----------------------------------------------------------------------===//
289 // OpPassManager
290 //===----------------------------------------------------------------------===//
291 
OpPassManager::OpPassManager(Nesting nesting)
    : impl(std::make_unique<OpPassManagerImpl>(nesting)) {}
OpPassManager::OpPassManager(StringRef name, Nesting nesting)
    : impl(std::make_unique<OpPassManagerImpl>(name, nesting)) {}
OpPassManager::OpPassManager(OperationName name, Nesting nesting)
    : impl(std::make_unique<OpPassManagerImpl>(name, nesting)) {}
OpPassManager::OpPassManager(OpPassManager &&rhs) : impl(std::move(rhs.impl)) {}

/// Deep-copy `rhs`: every pass is cloned (see OpPassManagerImpl's copy
/// constructor).
OpPassManager::OpPassManager(const OpPassManager &rhs)
    : impl(std::make_unique<OpPassManagerImpl>(*rhs.impl)) {}

/// Replace the contents of this pass manager with a deep copy of `rhs`.
OpPassManager &OpPassManager::operator=(const OpPassManager &rhs) {
  // Self-assignment must be a no-op. Cloning our own passes and then
  // destroying the originals would leave every clone's `threadingSibling`
  // pointing at a deleted pass.
  if (this == &rhs)
    return *this;
  impl = std::make_unique<OpPassManagerImpl>(*rhs.impl);
  return *this;
}

OpPassManager::~OpPassManager() = default;
306 
begin()307 OpPassManager::pass_iterator OpPassManager::begin() {
308   return MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.begin();
309 }
end()310 OpPassManager::pass_iterator OpPassManager::end() {
311   return MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.end();
312 }
313 
begin() const314 OpPassManager::const_pass_iterator OpPassManager::begin() const {
315   return ArrayRef<std::unique_ptr<Pass>>{impl->passes}.begin();
316 }
end() const317 OpPassManager::const_pass_iterator OpPassManager::end() const {
318   return ArrayRef<std::unique_ptr<Pass>>{impl->passes}.end();
319 }
320 
321 /// Nest a new operation pass manager for the given operation kind under this
322 /// pass manager.
nest(OperationName nestedName)323 OpPassManager &OpPassManager::nest(OperationName nestedName) {
324   return impl->nest(nestedName);
325 }
nest(StringRef nestedName)326 OpPassManager &OpPassManager::nest(StringRef nestedName) {
327   return impl->nest(nestedName);
328 }
nestAny()329 OpPassManager &OpPassManager::nestAny() { return impl->nestAny(); }
330 
331 /// Add the given pass to this pass manager. If this pass has a concrete
332 /// operation type, it must be the same type as this pass manager.
addPass(std::unique_ptr<Pass> pass)333 void OpPassManager::addPass(std::unique_ptr<Pass> pass) {
334   impl->addPass(std::move(pass));
335 }
336 
clear()337 void OpPassManager::clear() { impl->clear(); }
338 
339 /// Returns the number of passes held by this manager.
size() const340 size_t OpPassManager::size() const { return impl->passes.size(); }
341 
342 /// Returns the internal implementation instance.
getImpl()343 OpPassManagerImpl &OpPassManager::getImpl() { return *impl; }
344 
345 /// Return the operation name that this pass manager operates on.
getOpName() const346 Optional<StringRef> OpPassManager::getOpName() const {
347   return impl->getOpName();
348 }
349 
350 /// Return the operation name that this pass manager operates on.
getOpName(MLIRContext & context) const351 Optional<OperationName> OpPassManager::getOpName(MLIRContext &context) const {
352   return impl->getOpName(context);
353 }
354 
getOpAnchorName() const355 StringRef OpPassManager::getOpAnchorName() const {
356   return impl->getOpAnchorName();
357 }
358 
359 /// Prints out the given passes as the textual representation of a pipeline.
printAsTextualPipeline(ArrayRef<std::unique_ptr<Pass>> passes,raw_ostream & os)360 static void printAsTextualPipeline(ArrayRef<std::unique_ptr<Pass>> passes,
361                                    raw_ostream &os) {
362   llvm::interleave(
363       passes,
364       [&](const std::unique_ptr<Pass> &pass) {
365         pass->printAsTextualPipeline(os);
366       },
367       [&] { os << ","; });
368 }
369 
370 /// Prints out the passes of the pass manager as the textual representation
371 /// of pipelines.
printAsTextualPipeline(raw_ostream & os) const372 void OpPassManager::printAsTextualPipeline(raw_ostream &os) const {
373   ::printAsTextualPipeline(impl->passes, os);
374 }
375 
dump()376 void OpPassManager::dump() {
377   llvm::errs() << "Pass Manager with " << impl->passes.size() << " passes: ";
378   ::printAsTextualPipeline(impl->passes, llvm::errs());
379   llvm::errs() << "\n";
380 }
381 
registerDialectsForPipeline(const OpPassManager & pm,DialectRegistry & dialects)382 static void registerDialectsForPipeline(const OpPassManager &pm,
383                                         DialectRegistry &dialects) {
384   for (const Pass &pass : pm.getPasses())
385     pass.getDependentDialects(dialects);
386 }
387 
getDependentDialects(DialectRegistry & dialects) const388 void OpPassManager::getDependentDialects(DialectRegistry &dialects) const {
389   registerDialectsForPipeline(*this, dialects);
390 }
391 
setNesting(Nesting nesting)392 void OpPassManager::setNesting(Nesting nesting) { impl->nesting = nesting; }
393 
getNesting()394 OpPassManager::Nesting OpPassManager::getNesting() { return impl->nesting; }
395 
initialize(MLIRContext * context,unsigned newInitGeneration)396 LogicalResult OpPassManager::initialize(MLIRContext *context,
397                                         unsigned newInitGeneration) {
398   if (impl->initializationGeneration == newInitGeneration)
399     return success();
400   impl->initializationGeneration = newInitGeneration;
401   for (Pass &pass : getPasses()) {
402     // If this pass isn't an adaptor, directly initialize it.
403     auto *adaptor = dyn_cast<OpToOpPassAdaptor>(&pass);
404     if (!adaptor) {
405       if (failed(pass.initialize(context)))
406         return failure();
407       continue;
408     }
409 
410     // Otherwise, initialize each of the adaptors pass managers.
411     for (OpPassManager &adaptorPM : adaptor->getPassManagers())
412       if (failed(adaptorPM.initialize(context, newInitGeneration)))
413         return failure();
414   }
415   return success();
416 }
417 
418 //===----------------------------------------------------------------------===//
419 // OpToOpPassAdaptor
420 //===----------------------------------------------------------------------===//
421 
run(Pass * pass,Operation * op,AnalysisManager am,bool verifyPasses,unsigned parentInitGeneration)422 LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
423                                      AnalysisManager am, bool verifyPasses,
424                                      unsigned parentInitGeneration) {
425   Optional<RegisteredOperationName> opInfo = op->getRegisteredInfo();
426   if (!opInfo)
427     return op->emitOpError()
428            << "trying to schedule a pass on an unregistered operation";
429   if (!opInfo->hasTrait<OpTrait::IsIsolatedFromAbove>())
430     return op->emitOpError() << "trying to schedule a pass on an operation not "
431                                 "marked as 'IsolatedFromAbove'";
432 
433   // Initialize the pass state with a callback for the pass to dynamically
434   // execute a pipeline on the currently visited operation.
435   PassInstrumentor *pi = am.getPassInstrumentor();
436   PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
437                                                         pass};
438   auto dynamicPipelineCallback = [&](OpPassManager &pipeline,
439                                      Operation *root) -> LogicalResult {
440     if (!op->isAncestor(root))
441       return root->emitOpError()
442              << "Trying to schedule a dynamic pipeline on an "
443                 "operation that isn't "
444                 "nested under the current operation the pass is processing";
445     assert(
446         pipeline.getImpl().canScheduleOn(*op->getContext(), root->getName()));
447 
448     // Before running, finalize the passes held by the pipeline.
449     if (failed(pipeline.getImpl().finalizePassList(root->getContext())))
450       return failure();
451 
452     // Initialize the user provided pipeline and execute the pipeline.
453     if (failed(pipeline.initialize(root->getContext(), parentInitGeneration)))
454       return failure();
455     AnalysisManager nestedAm = root == op ? am : am.nest(root);
456     return OpToOpPassAdaptor::runPipeline(pipeline, root, nestedAm,
457                                           verifyPasses, parentInitGeneration,
458                                           pi, &parentInfo);
459   };
460   pass->passState.emplace(op, am, dynamicPipelineCallback);
461 
462   // Instrument before the pass has run.
463   if (pi)
464     pi->runBeforePass(pass, op);
465 
466   // Invoke the virtual runOnOperation method.
467   if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
468     adaptor->runOnOperation(verifyPasses);
469   else
470     pass->runOnOperation();
471   bool passFailed = pass->passState->irAndPassFailed.getInt();
472 
473   // Invalidate any non preserved analyses.
474   am.invalidate(pass->passState->preservedAnalyses);
475 
476   // When verifyPasses is specified, we run the verifier (unless the pass
477   // failed).
478   if (!passFailed && verifyPasses) {
479     bool runVerifierNow = true;
480 
481     // If the pass is an adaptor pass, we don't run the verifier recursively
482     // because the nested operations should have already been verified after
483     // nested passes had run.
484     bool runVerifierRecursively = !isa<OpToOpPassAdaptor>(pass);
485 
486     // Reduce compile time by avoiding running the verifier if the pass didn't
487     // change the IR since the last time the verifier was run:
488     //
489     //  1) If the pass said that it preserved all analyses then it can't have
490     //     permuted the IR.
491     //
492     // We run these checks in EXPENSIVE_CHECKS mode out of caution.
493 #ifndef EXPENSIVE_CHECKS
494     runVerifierNow = !pass->passState->preservedAnalyses.isAll();
495 #endif
496     if (runVerifierNow)
497       passFailed = failed(verify(op, runVerifierRecursively));
498   }
499 
500   // Instrument after the pass has run.
501   if (pi) {
502     if (passFailed)
503       pi->runAfterPassFailed(pass, op);
504     else
505       pi->runAfterPass(pass, op);
506   }
507 
508   // Return if the pass signaled a failure.
509   return failure(passFailed);
510 }
511 
512 /// Run the given operation and analysis manager on a provided op pass manager.
runPipeline(OpPassManager & pm,Operation * op,AnalysisManager am,bool verifyPasses,unsigned parentInitGeneration,PassInstrumentor * instrumentor,const PassInstrumentation::PipelineParentInfo * parentInfo)513 LogicalResult OpToOpPassAdaptor::runPipeline(
514     OpPassManager &pm, Operation *op, AnalysisManager am, bool verifyPasses,
515     unsigned parentInitGeneration, PassInstrumentor *instrumentor,
516     const PassInstrumentation::PipelineParentInfo *parentInfo) {
517   assert((!instrumentor || parentInfo) &&
518          "expected parent info if instrumentor is provided");
519   auto scopeExit = llvm::make_scope_exit([&] {
520     // Clear out any computed operation analyses. These analyses won't be used
521     // any more in this pipeline, and this helps reduce the current working set
522     // of memory. If preserving these analyses becomes important in the future
523     // we can re-evaluate this.
524     am.clear();
525   });
526 
527   // Run the pipeline over the provided operation.
528   if (instrumentor) {
529     instrumentor->runBeforePipeline(pm.getOpName(*op->getContext()),
530                                     *parentInfo);
531   }
532 
533   for (Pass &pass : pm.getPasses())
534     if (failed(run(&pass, op, am, verifyPasses, parentInitGeneration)))
535       return failure();
536 
537   if (instrumentor) {
538     instrumentor->runAfterPipeline(pm.getOpName(*op->getContext()),
539                                    *parentInfo);
540   }
541   return success();
542 }
543 
544 /// Find an operation pass manager with the given anchor name, or nullptr if one
545 /// does not exist.
546 static OpPassManager *
findPassManagerWithAnchor(MutableArrayRef<OpPassManager> mgrs,StringRef name)547 findPassManagerWithAnchor(MutableArrayRef<OpPassManager> mgrs, StringRef name) {
548   auto *it = llvm::find_if(
549       mgrs, [&](OpPassManager &mgr) { return mgr.getOpAnchorName() == name; });
550   return it == mgrs.end() ? nullptr : &*it;
551 }
552 
553 /// Find an operation pass manager that can operate on an operation of the given
554 /// type, or nullptr if one does not exist.
findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,OperationName name,MLIRContext & context)555 static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
556                                          OperationName name,
557                                          MLIRContext &context) {
558   auto *it = llvm::find_if(mgrs, [&](OpPassManager &mgr) {
559     return mgr.getImpl().canScheduleOn(context, name);
560   });
561   return it == mgrs.end() ? nullptr : &*it;
562 }
563 
OpToOpPassAdaptor(OpPassManager && mgr)564 OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr) {
565   mgrs.emplace_back(std::move(mgr));
566 }
567 
getDependentDialects(DialectRegistry & dialects) const568 void OpToOpPassAdaptor::getDependentDialects(DialectRegistry &dialects) const {
569   for (auto &pm : mgrs)
570     pm.getDependentDialects(dialects);
571 }
572 
tryMergeInto(MLIRContext * ctx,OpToOpPassAdaptor & rhs)573 LogicalResult OpToOpPassAdaptor::tryMergeInto(MLIRContext *ctx,
574                                               OpToOpPassAdaptor &rhs) {
575   // Functor used to check if a pass manager is generic, i.e. op-agnostic.
576   auto isGenericPM = [&](OpPassManager &pm) { return !pm.getOpName(); };
577 
578   // Functor used to detect if the given generic pass manager will have a
579   // potential schedule conflict with the given `otherPMs`.
580   auto hasScheduleConflictWith = [&](OpPassManager &genericPM,
581                                      MutableArrayRef<OpPassManager> otherPMs) {
582     return llvm::any_of(otherPMs, [&](OpPassManager &pm) {
583       // If this is a non-generic pass manager, a conflict will arise if a
584       // non-generic pass manager's operation name can be scheduled on the
585       // generic passmanager.
586       if (Optional<OperationName> pmOpName = pm.getOpName(*ctx))
587         return genericPM.getImpl().canScheduleOn(*ctx, *pmOpName);
588       // Otherwise, this is a generic pass manager. We current can't determine
589       // when generic pass managers can be merged, so conservatively assume they
590       // conflict.
591       return true;
592     });
593   };
594 
595   // Check that if either adaptor has a generic pass manager, that pm is
596   // compatible within any non-generic pass managers.
597   //
598   // Check the current adaptor.
599   auto *lhsGenericPMIt = llvm::find_if(mgrs, isGenericPM);
600   if (lhsGenericPMIt != mgrs.end() &&
601       hasScheduleConflictWith(*lhsGenericPMIt, rhs.mgrs))
602     return failure();
603   // Check the rhs adaptor.
604   auto *rhsGenericPMIt = llvm::find_if(rhs.mgrs, isGenericPM);
605   if (rhsGenericPMIt != rhs.mgrs.end() &&
606       hasScheduleConflictWith(*rhsGenericPMIt, mgrs))
607     return failure();
608 
609   for (auto &pm : mgrs) {
610     // If an existing pass manager exists, then merge the given pass manager
611     // into it.
612     if (auto *existingPM =
613             findPassManagerWithAnchor(rhs.mgrs, pm.getOpAnchorName())) {
614       pm.getImpl().mergeInto(existingPM->getImpl());
615     } else {
616       // Otherwise, add the given pass manager to the list.
617       rhs.mgrs.emplace_back(std::move(pm));
618     }
619   }
620   mgrs.clear();
621 
622   // After coalescing, sort the pass managers within rhs by name.
623   auto compareFn = [](const OpPassManager *lhs, const OpPassManager *rhs) {
624     // Order op-specific pass managers first and op-agnostic pass managers last.
625     if (Optional<StringRef> lhsName = lhs->getOpName()) {
626       if (Optional<StringRef> rhsName = rhs->getOpName())
627         return lhsName->compare(*rhsName);
628       return -1; // lhs(op-specific) < rhs(op-agnostic)
629     }
630     return 1; // lhs(op-agnostic) > rhs(op-specific)
631   };
632   llvm::array_pod_sort(rhs.mgrs.begin(), rhs.mgrs.end(), compareFn);
633   return success();
634 }
635 
636 /// Returns the adaptor pass name.
getAdaptorName()637 std::string OpToOpPassAdaptor::getAdaptorName() {
638   std::string name = "Pipeline Collection : [";
639   llvm::raw_string_ostream os(name);
640   llvm::interleaveComma(getPassManagers(), os, [&](OpPassManager &pm) {
641     os << '\'' << pm.getOpAnchorName() << '\'';
642   });
643   os << ']';
644   return os.str();
645 }
646 
runOnOperation()647 void OpToOpPassAdaptor::runOnOperation() {
648   llvm_unreachable(
649       "Unexpected call to Pass::runOnOperation() on OpToOpPassAdaptor");
650 }
651 
652 /// Run the held pipeline over all nested operations.
runOnOperation(bool verifyPasses)653 void OpToOpPassAdaptor::runOnOperation(bool verifyPasses) {
654   if (getContext().isMultithreadingEnabled())
655     runOnOperationAsyncImpl(verifyPasses);
656   else
657     runOnOperationImpl(verifyPasses);
658 }
659 
660 /// Run this pass adaptor synchronously.
runOnOperationImpl(bool verifyPasses)661 void OpToOpPassAdaptor::runOnOperationImpl(bool verifyPasses) {
662   auto am = getAnalysisManager();
663   PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
664                                                         this};
665   auto *instrumentor = am.getPassInstrumentor();
666   for (auto &region : getOperation()->getRegions()) {
667     for (auto &block : region) {
668       for (auto &op : block) {
669         auto *mgr = findPassManagerFor(mgrs, op.getName(), *op.getContext());
670         if (!mgr)
671           continue;
672 
673         // Run the held pipeline over the current operation.
674         unsigned initGeneration = mgr->impl->initializationGeneration;
675         if (failed(runPipeline(*mgr, &op, am.nest(&op), verifyPasses,
676                                initGeneration, instrumentor, &parentInfo)))
677           return signalPassFailure();
678       }
679     }
680   }
681 }
682 
683 /// Utility functor that checks if the two ranges of pass managers have a size
684 /// mismatch.
hasSizeMismatch(ArrayRef<OpPassManager> lhs,ArrayRef<OpPassManager> rhs)685 static bool hasSizeMismatch(ArrayRef<OpPassManager> lhs,
686                             ArrayRef<OpPassManager> rhs) {
687   return lhs.size() != rhs.size() ||
688          llvm::any_of(llvm::seq<size_t>(0, lhs.size()),
689                       [&](size_t i) { return lhs[i].size() != rhs[i].size(); });
690 }
691 
692 /// Run this pass adaptor asynchronously.
runOnOperationAsyncImpl(bool verifyPasses)693 void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
694   AnalysisManager am = getAnalysisManager();
695   MLIRContext *context = &getContext();
696 
697   // Create the async executors if they haven't been created, or if the main
698   // pipeline has changed.
699   if (asyncExecutors.empty() || hasSizeMismatch(asyncExecutors.front(), mgrs))
700     asyncExecutors.assign(context->getThreadPool().getThreadCount(), mgrs);
701 
702   // This struct represents the information for a single operation to be
703   // scheduled on a pass manager.
704   struct OpPMInfo {
705     OpPMInfo(unsigned passManagerIdx, Operation *op, AnalysisManager am)
706         : passManagerIdx(passManagerIdx), op(op), am(am) {}
707 
708     /// The index of the pass manager to schedule the operation on.
709     unsigned passManagerIdx;
710     /// The operation to schedule.
711     Operation *op;
712     /// The analysis manager for the operation.
713     AnalysisManager am;
714   };
715 
716   // Run a prepass over the operation to collect the nested operations to
717   // execute over. This ensures that an analysis manager exists for each
718   // operation, as well as providing a queue of operations to execute over.
719   std::vector<OpPMInfo> opInfos;
720   DenseMap<OperationName, Optional<unsigned>> knownOpPMIdx;
721   for (auto &region : getOperation()->getRegions()) {
722     for (Operation &op : region.getOps()) {
723       // Get the pass manager index for this operation type.
724       auto pmIdxIt = knownOpPMIdx.try_emplace(op.getName(), llvm::None);
725       if (pmIdxIt.second) {
726         if (auto *mgr = findPassManagerFor(mgrs, op.getName(), *context))
727           pmIdxIt.first->second = std::distance(mgrs.begin(), mgr);
728       }
729 
730       // If this operation can be scheduled, add it to the list.
731       if (pmIdxIt.first->second)
732         opInfos.emplace_back(*pmIdxIt.first->second, &op, am.nest(&op));
733     }
734   }
735 
736   // Get the current thread for this adaptor.
737   PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
738                                                         this};
739   auto *instrumentor = am.getPassInstrumentor();
740 
741   // An atomic failure variable for the async executors.
742   std::vector<std::atomic<bool>> activePMs(asyncExecutors.size());
743   std::fill(activePMs.begin(), activePMs.end(), false);
744   auto processFn = [&](OpPMInfo &opInfo) {
745     // Find an executor for this operation.
746     auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
747       bool expectedInactive = false;
748       return isActive.compare_exchange_strong(expectedInactive, true);
749     });
750     unsigned pmIndex = it - activePMs.begin();
751 
752     // Get the pass manager for this operation and execute it.
753     OpPassManager &pm = asyncExecutors[pmIndex][opInfo.passManagerIdx];
754     LogicalResult pipelineResult = runPipeline(
755         pm, opInfo.op, opInfo.am, verifyPasses,
756         pm.impl->initializationGeneration, instrumentor, &parentInfo);
757 
758     // Reset the active bit for this pass manager.
759     activePMs[pmIndex].store(false);
760     return pipelineResult;
761   };
762 
763   // Signal a failure if any of the executors failed.
764   if (failed(failableParallelForEach(context, opInfos, processFn)))
765     signalPassFailure();
766 }
767 
768 //===----------------------------------------------------------------------===//
769 // PassManager
770 //===----------------------------------------------------------------------===//
771 
PassManager(MLIRContext * ctx,Nesting nesting,StringRef operationName)772 PassManager::PassManager(MLIRContext *ctx, Nesting nesting,
773                          StringRef operationName)
774     : OpPassManager(OperationName(operationName, ctx), nesting), context(ctx),
775       initializationKey(DenseMapInfo<llvm::hash_code>::getTombstoneKey()),
776       passTiming(false), verifyPasses(true) {}
777 
778 PassManager::~PassManager() = default;
779 
enableVerifier(bool enabled)780 void PassManager::enableVerifier(bool enabled) { verifyPasses = enabled; }
781 
782 /// Run the passes within this manager on the provided operation.
run(Operation * op)783 LogicalResult PassManager::run(Operation *op) {
784   MLIRContext *context = getContext();
785   assert(op->getName() == getOpName(*context) &&
786          "operation has a different name than the PassManager or is from a "
787          "different context");
788 
789   // Register all dialects for the current pipeline.
790   DialectRegistry dependentDialects;
791   getDependentDialects(dependentDialects);
792   context->appendDialectRegistry(dependentDialects);
793   for (StringRef name : dependentDialects.getDialectNames())
794     context->getOrLoadDialect(name);
795 
796   // Before running, make sure to finalize the pipeline pass list.
797   if (failed(getImpl().finalizePassList(context)))
798     return failure();
799 
800   // Initialize all of the passes within the pass manager with a new generation.
801   llvm::hash_code newInitKey = context->getRegistryHash();
802   if (newInitKey != initializationKey) {
803     if (failed(initialize(context, impl->initializationGeneration + 1)))
804       return failure();
805     initializationKey = newInitKey;
806   }
807 
808   // Construct a top level analysis manager for the pipeline.
809   ModuleAnalysisManager am(op, instrumentor.get());
810 
811   // Notify the context that we start running a pipeline for book keeping.
812   context->enterMultiThreadedExecution();
813 
814   // If reproducer generation is enabled, run the pass manager with crash
815   // handling enabled.
816   LogicalResult result =
817       crashReproGenerator ? runWithCrashRecovery(op, am) : runPasses(op, am);
818 
819   // Notify the context that the run is done.
820   context->exitMultiThreadedExecution();
821 
822   // Dump all of the pass statistics if necessary.
823   if (passStatisticsMode)
824     dumpStatistics();
825   return result;
826 }
827 
828 /// Add the provided instrumentation to the pass manager.
addInstrumentation(std::unique_ptr<PassInstrumentation> pi)829 void PassManager::addInstrumentation(std::unique_ptr<PassInstrumentation> pi) {
830   if (!instrumentor)
831     instrumentor = std::make_unique<PassInstrumentor>();
832 
833   instrumentor->addInstrumentation(std::move(pi));
834 }
835 
runPasses(Operation * op,AnalysisManager am)836 LogicalResult PassManager::runPasses(Operation *op, AnalysisManager am) {
837   return OpToOpPassAdaptor::runPipeline(*this, op, am, verifyPasses,
838                                         impl->initializationGeneration);
839 }
840 
841 //===----------------------------------------------------------------------===//
842 // AnalysisManager
843 //===----------------------------------------------------------------------===//
844 
845 /// Get an analysis manager for the given operation, which must be a proper
846 /// descendant of the current operation represented by this analysis manager.
nest(Operation * op)847 AnalysisManager AnalysisManager::nest(Operation *op) {
848   Operation *currentOp = impl->getOperation();
849   assert(currentOp->isProperAncestor(op) &&
850          "expected valid descendant operation");
851 
852   // Check for the base case where the provided operation is immediately nested.
853   if (currentOp == op->getParentOp())
854     return nestImmediate(op);
855 
856   // Otherwise, we need to collect all ancestors up to the current operation.
857   SmallVector<Operation *, 4> opAncestors;
858   do {
859     opAncestors.push_back(op);
860     op = op->getParentOp();
861   } while (op != currentOp);
862 
863   AnalysisManager result = *this;
864   for (Operation *op : llvm::reverse(opAncestors))
865     result = result.nestImmediate(op);
866   return result;
867 }
868 
869 /// Get an analysis manager for the given immediately nested child operation.
nestImmediate(Operation * op)870 AnalysisManager AnalysisManager::nestImmediate(Operation *op) {
871   assert(impl->getOperation() == op->getParentOp() &&
872          "expected immediate child operation");
873 
874   auto it = impl->childAnalyses.find(op);
875   if (it == impl->childAnalyses.end())
876     it = impl->childAnalyses
877              .try_emplace(op, std::make_unique<NestedAnalysisMap>(op, impl))
878              .first;
879   return {it->second.get()};
880 }
881 
882 /// Invalidate any non preserved analyses.
invalidate(const detail::PreservedAnalyses & pa)883 void detail::NestedAnalysisMap::invalidate(
884     const detail::PreservedAnalyses &pa) {
885   // If all analyses were preserved, then there is nothing to do here.
886   if (pa.isAll())
887     return;
888 
889   // Invalidate the analyses for the current operation directly.
890   analyses.invalidate(pa);
891 
892   // If no analyses were preserved, then just simply clear out the child
893   // analysis results.
894   if (pa.isNone()) {
895     childAnalyses.clear();
896     return;
897   }
898 
899   // Otherwise, invalidate each child analysis map.
900   SmallVector<NestedAnalysisMap *, 8> mapsToInvalidate(1, this);
901   while (!mapsToInvalidate.empty()) {
902     auto *map = mapsToInvalidate.pop_back_val();
903     for (auto &analysisPair : map->childAnalyses) {
904       analysisPair.second->invalidate(pa);
905       if (!analysisPair.second->childAnalyses.empty())
906         mapsToInvalidate.push_back(analysisPair.second.get());
907     }
908   }
909 }
910 
911 //===----------------------------------------------------------------------===//
912 // PassInstrumentation
913 //===----------------------------------------------------------------------===//
914 
915 PassInstrumentation::~PassInstrumentation() = default;
916 
runBeforePipeline(Optional<OperationName> name,const PipelineParentInfo & parentInfo)917 void PassInstrumentation::runBeforePipeline(
918     Optional<OperationName> name, const PipelineParentInfo &parentInfo) {}
919 
runAfterPipeline(Optional<OperationName> name,const PipelineParentInfo & parentInfo)920 void PassInstrumentation::runAfterPipeline(
921     Optional<OperationName> name, const PipelineParentInfo &parentInfo) {}
922 
923 //===----------------------------------------------------------------------===//
924 // PassInstrumentor
925 //===----------------------------------------------------------------------===//
926 
927 namespace mlir {
928 namespace detail {
929 struct PassInstrumentorImpl {
930   /// Mutex to keep instrumentation access thread-safe.
931   llvm::sys::SmartMutex<true> mutex;
932 
933   /// Set of registered instrumentations.
934   std::vector<std::unique_ptr<PassInstrumentation>> instrumentations;
935 };
936 } // namespace detail
937 } // namespace mlir
938 
PassInstrumentor()939 PassInstrumentor::PassInstrumentor() : impl(new PassInstrumentorImpl()) {}
940 PassInstrumentor::~PassInstrumentor() = default;
941 
942 /// See PassInstrumentation::runBeforePipeline for details.
runBeforePipeline(Optional<OperationName> name,const PassInstrumentation::PipelineParentInfo & parentInfo)943 void PassInstrumentor::runBeforePipeline(
944     Optional<OperationName> name,
945     const PassInstrumentation::PipelineParentInfo &parentInfo) {
946   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
947   for (auto &instr : impl->instrumentations)
948     instr->runBeforePipeline(name, parentInfo);
949 }
950 
951 /// See PassInstrumentation::runAfterPipeline for details.
runAfterPipeline(Optional<OperationName> name,const PassInstrumentation::PipelineParentInfo & parentInfo)952 void PassInstrumentor::runAfterPipeline(
953     Optional<OperationName> name,
954     const PassInstrumentation::PipelineParentInfo &parentInfo) {
955   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
956   for (auto &instr : llvm::reverse(impl->instrumentations))
957     instr->runAfterPipeline(name, parentInfo);
958 }
959 
960 /// See PassInstrumentation::runBeforePass for details.
runBeforePass(Pass * pass,Operation * op)961 void PassInstrumentor::runBeforePass(Pass *pass, Operation *op) {
962   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
963   for (auto &instr : impl->instrumentations)
964     instr->runBeforePass(pass, op);
965 }
966 
967 /// See PassInstrumentation::runAfterPass for details.
runAfterPass(Pass * pass,Operation * op)968 void PassInstrumentor::runAfterPass(Pass *pass, Operation *op) {
969   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
970   for (auto &instr : llvm::reverse(impl->instrumentations))
971     instr->runAfterPass(pass, op);
972 }
973 
974 /// See PassInstrumentation::runAfterPassFailed for details.
runAfterPassFailed(Pass * pass,Operation * op)975 void PassInstrumentor::runAfterPassFailed(Pass *pass, Operation *op) {
976   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
977   for (auto &instr : llvm::reverse(impl->instrumentations))
978     instr->runAfterPassFailed(pass, op);
979 }
980 
981 /// See PassInstrumentation::runBeforeAnalysis for details.
runBeforeAnalysis(StringRef name,TypeID id,Operation * op)982 void PassInstrumentor::runBeforeAnalysis(StringRef name, TypeID id,
983                                          Operation *op) {
984   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
985   for (auto &instr : impl->instrumentations)
986     instr->runBeforeAnalysis(name, id, op);
987 }
988 
989 /// See PassInstrumentation::runAfterAnalysis for details.
runAfterAnalysis(StringRef name,TypeID id,Operation * op)990 void PassInstrumentor::runAfterAnalysis(StringRef name, TypeID id,
991                                         Operation *op) {
992   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
993   for (auto &instr : llvm::reverse(impl->instrumentations))
994     instr->runAfterAnalysis(name, id, op);
995 }
996 
997 /// Add the given instrumentation to the collection.
addInstrumentation(std::unique_ptr<PassInstrumentation> pi)998 void PassInstrumentor::addInstrumentation(
999     std::unique_ptr<PassInstrumentation> pi) {
1000   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1001   impl->instrumentations.emplace_back(std::move(pi));
1002 }
1003