1 //===- Inliner.cpp - Pass to inline function calls ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a basic inlining algorithm that operates bottom up over
// the Strongly Connected Components (SCCs) of the CallGraph. This enables a
// more incremental propagation of inlining decisions from the leaves to the
// roots of the callgraph.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "PassDetail.h"
17 #include "mlir/Analysis/CallGraph.h"
18 #include "mlir/IR/Threading.h"
19 #include "mlir/Interfaces/CallInterfaces.h"
20 #include "mlir/Interfaces/SideEffectInterfaces.h"
21 #include "mlir/Pass/PassManager.h"
22 #include "mlir/Transforms/InliningUtils.h"
23 #include "mlir/Transforms/Passes.h"
24 #include "llvm/ADT/SCCIterator.h"
25 #include "llvm/Support/Debug.h"
26 
27 #define DEBUG_TYPE "inlining"
28 
29 using namespace mlir;
30 
31 /// This function implements the default inliner optimization pipeline.
32 static void defaultInlinerOptPipeline(OpPassManager &pm) {
33   pm.addPass(createCanonicalizerPass());
34 }
35 
36 //===----------------------------------------------------------------------===//
37 // Symbol Use Tracking
38 //===----------------------------------------------------------------------===//
39 
40 /// Walk all of the used symbol callgraph nodes referenced with the given op.
41 static void walkReferencedSymbolNodes(
42     Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
43     DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
44     function_ref<void(CallGraphNode *, Operation *)> callback) {
45   auto symbolUses = SymbolTable::getSymbolUses(op);
46   assert(symbolUses && "expected uses to be valid");
47 
48   Operation *symbolTableOp = op->getParentOp();
49   for (const SymbolTable::SymbolUse &use : *symbolUses) {
50     auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
51     CallGraphNode *&node = refIt.first->second;
52 
53     // If this is the first instance of this reference, try to resolve a
54     // callgraph node for it.
55     if (refIt.second) {
56       auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
57                                                            use.getSymbolRef());
58       auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
59       if (!callableOp)
60         continue;
61       node = cg.lookupNode(callableOp.getCallableRegion());
62     }
63     if (node)
64       callback(node, use.getUser());
65   }
66 }
67 
//===----------------------------------------------------------------------===//
// CGUseList
//===----------------------------------------------------------------------===//
70 
71 namespace {
72 /// This struct tracks the uses of callgraph nodes that can be dropped when
73 /// use_empty. It directly tracks and manages a use-list for all of the
74 /// call-graph nodes. This is necessary because many callgraph nodes are
75 /// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
76 /// class.
77 struct CGUseList {
78   /// This struct tracks the uses of callgraph nodes within a specific
79   /// operation.
80   struct CGUser {
81     /// Any nodes referenced in the top-level attribute list of this user. We
82     /// use a set here because the number of references does not matter.
83     DenseSet<CallGraphNode *> topLevelUses;
84 
85     /// Uses of nodes referenced by nested operations.
86     DenseMap<CallGraphNode *, int> innerUses;
87   };
88 
89   CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);
90 
91   /// Drop uses of nodes referred to by the given call operation that resides
92   /// within 'userNode'.
93   void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);
94 
95   /// Remove the given node from the use list.
96   void eraseNode(CallGraphNode *node);
97 
98   /// Returns true if the given callgraph node has no uses and can be pruned.
99   bool isDead(CallGraphNode *node) const;
100 
101   /// Returns true if the given callgraph node has a single use and can be
102   /// discarded.
103   bool hasOneUseAndDiscardable(CallGraphNode *node) const;
104 
105   /// Recompute the uses held by the given callgraph node.
106   void recomputeUses(CallGraphNode *node, CallGraph &cg);
107 
108   /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy
109   /// of 'lhs' into 'rhs'.
110   void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);
111 
112 private:
113   /// Decrement the uses of discardable nodes referenced by the given user.
114   void decrementDiscardableUses(CGUser &uses);
115 
116   /// A mapping between a discardable callgraph node (that is a symbol) and the
117   /// number of uses for this node.
118   DenseMap<CallGraphNode *, int> discardableSymNodeUses;
119 
120   /// A mapping between a callgraph node and the symbol callgraph nodes that it
121   /// uses.
122   DenseMap<CallGraphNode *, CGUser> nodeUses;
123 
124   /// A symbol table to use when resolving call lookups.
125   SymbolTableCollection &symbolTable;
126 };
127 } // namespace
128 
129 CGUseList::CGUseList(Operation *op, CallGraph &cg,
130                      SymbolTableCollection &symbolTable)
131     : symbolTable(symbolTable) {
132   /// A set of callgraph nodes that are always known to be live during inlining.
133   DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;
134 
135   // Walk each of the symbol tables looking for discardable callgraph nodes.
136   auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
137     for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
138       // If this is a callgraph operation, check to see if it is discardable.
139       if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
140         if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
141           SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
142           if (symbol && (allUsesVisible || symbol.isPrivate()) &&
143               symbol.canDiscardOnUseEmpty()) {
144             discardableSymNodeUses.try_emplace(node, 0);
145           }
146           continue;
147         }
148       }
149       // Otherwise, check for any referenced nodes. These will be always-live.
150       walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
151                                 [](CallGraphNode *, Operation *) {});
152     }
153   };
154   SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
155                                 walkFn);
156 
157   // Drop the use information for any discardable nodes that are always live.
158   for (auto &it : alwaysLiveNodes)
159     discardableSymNodeUses.erase(it.second);
160 
161   // Compute the uses for each of the callable nodes in the graph.
162   for (CallGraphNode *node : cg)
163     recomputeUses(node, cg);
164 }
165 
166 void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
167                              CallGraph &cg) {
168   auto &userRefs = nodeUses[userNode].innerUses;
169   auto walkFn = [&](CallGraphNode *node, Operation *user) {
170     auto parentIt = userRefs.find(node);
171     if (parentIt == userRefs.end())
172       return;
173     --parentIt->second;
174     --discardableSymNodeUses[node];
175   };
176   DenseMap<Attribute, CallGraphNode *> resolvedRefs;
177   walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
178 }
179 
180 void CGUseList::eraseNode(CallGraphNode *node) {
181   // Drop all child nodes.
182   for (auto &edge : *node)
183     if (edge.isChild())
184       eraseNode(edge.getTarget());
185 
186   // Drop the uses held by this node and erase it.
187   auto useIt = nodeUses.find(node);
188   assert(useIt != nodeUses.end() && "expected node to be valid");
189   decrementDiscardableUses(useIt->getSecond());
190   nodeUses.erase(useIt);
191   discardableSymNodeUses.erase(node);
192 }
193 
194 bool CGUseList::isDead(CallGraphNode *node) const {
195   // If the parent operation isn't a symbol, simply check normal SSA deadness.
196   Operation *nodeOp = node->getCallableRegion()->getParentOp();
197   if (!isa<SymbolOpInterface>(nodeOp))
198     return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->use_empty();
199 
200   // Otherwise, check the number of symbol uses.
201   auto symbolIt = discardableSymNodeUses.find(node);
202   return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
203 }
204 
205 bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
206   // If this isn't a symbol node, check for side-effects and SSA use count.
207   Operation *nodeOp = node->getCallableRegion()->getParentOp();
208   if (!isa<SymbolOpInterface>(nodeOp))
209     return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->hasOneUse();
210 
211   // Otherwise, check the number of symbol uses.
212   auto symbolIt = discardableSymNodeUses.find(node);
213   return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
214 }
215 
216 void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
217   Operation *parentOp = node->getCallableRegion()->getParentOp();
218   CGUser &uses = nodeUses[node];
219   decrementDiscardableUses(uses);
220 
221   // Collect the new discardable uses within this node.
222   uses = CGUser();
223   DenseMap<Attribute, CallGraphNode *> resolvedRefs;
224   auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
225     auto discardSymIt = discardableSymNodeUses.find(refNode);
226     if (discardSymIt == discardableSymNodeUses.end())
227       return;
228 
229     if (user != parentOp)
230       ++uses.innerUses[refNode];
231     else if (!uses.topLevelUses.insert(refNode).second)
232       return;
233     ++discardSymIt->second;
234   };
235   walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
236 }
237 
238 void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
239   auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
240   for (auto &useIt : lhsUses.innerUses) {
241     rhsUses.innerUses[useIt.first] += useIt.second;
242     discardableSymNodeUses[useIt.first] += useIt.second;
243   }
244 }
245 
246 void CGUseList::decrementDiscardableUses(CGUser &uses) {
247   for (CallGraphNode *node : uses.topLevelUses)
248     --discardableSymNodeUses[node];
249   for (auto &it : uses.innerUses)
250     discardableSymNodeUses[it.first] -= it.second;
251 }
252 
253 //===----------------------------------------------------------------------===//
254 // CallGraph traversal
255 //===----------------------------------------------------------------------===//
256 
257 namespace {
258 /// This class represents a specific callgraph SCC.
259 class CallGraphSCC {
260 public:
261   CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
262       : parentIterator(parentIterator) {}
263   /// Return a range over the nodes within this SCC.
264   std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
265   std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }
266 
267   /// Reset the nodes of this SCC with those provided.
268   void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
269 
270   /// Remove the given node from this SCC.
271   void remove(CallGraphNode *node) {
272     auto it = llvm::find(nodes, node);
273     if (it != nodes.end()) {
274       nodes.erase(it);
275       parentIterator.ReplaceNode(node, nullptr);
276     }
277   }
278 
279 private:
280   std::vector<CallGraphNode *> nodes;
281   llvm::scc_iterator<const CallGraph *> &parentIterator;
282 };
283 } // namespace
284 
285 /// Run a given transformation over the SCCs of the callgraph in a bottom up
286 /// traversal.
287 static LogicalResult runTransformOnCGSCCs(
288     const CallGraph &cg,
289     function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
290   llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
291   CallGraphSCC currentSCC(cgi);
292   while (!cgi.isAtEnd()) {
293     // Copy the current SCC and increment so that the transformer can modify the
294     // SCC without invalidating our iterator.
295     currentSCC.reset(*cgi);
296     ++cgi;
297     if (failed(sccTransformer(currentSCC)))
298       return failure();
299   }
300   return success();
301 }
302 
303 namespace {
304 /// This struct represents a resolved call to a given callgraph node. Given that
305 /// the call does not actually contain a direct reference to the
306 /// Region(CallGraphNode) that it is dispatching to, we need to resolve them
307 /// explicitly.
308 struct ResolvedCall {
309   ResolvedCall(CallOpInterface call, CallGraphNode *sourceNode,
310                CallGraphNode *targetNode)
311       : call(call), sourceNode(sourceNode), targetNode(targetNode) {}
312   CallOpInterface call;
313   CallGraphNode *sourceNode, *targetNode;
314 };
315 } // namespace
316 
317 /// Collect all of the callable operations within the given range of blocks. If
318 /// `traverseNestedCGNodes` is true, this will also collect call operations
319 /// inside of nested callgraph nodes.
320 static void collectCallOps(iterator_range<Region::iterator> blocks,
321                            CallGraphNode *sourceNode, CallGraph &cg,
322                            SymbolTableCollection &symbolTable,
323                            SmallVectorImpl<ResolvedCall> &calls,
324                            bool traverseNestedCGNodes) {
325   SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
326   auto addToWorklist = [&](CallGraphNode *node,
327                            iterator_range<Region::iterator> blocks) {
328     for (Block &block : blocks)
329       worklist.emplace_back(&block, node);
330   };
331 
332   addToWorklist(sourceNode, blocks);
333   while (!worklist.empty()) {
334     Block *block;
335     std::tie(block, sourceNode) = worklist.pop_back_val();
336 
337     for (Operation &op : *block) {
338       if (auto call = dyn_cast<CallOpInterface>(op)) {
339         // TODO: Support inlining nested call references.
340         CallInterfaceCallable callable = call.getCallableForCallee();
341         if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
342           if (!symRef.isa<FlatSymbolRefAttr>())
343             continue;
344         }
345 
346         CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
347         if (!targetNode->isExternal())
348           calls.emplace_back(call, sourceNode, targetNode);
349         continue;
350       }
351 
352       // If this is not a call, traverse the nested regions. If
353       // `traverseNestedCGNodes` is false, then don't traverse nested call graph
354       // regions.
355       for (auto &nestedRegion : op.getRegions()) {
356         CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
357         if (traverseNestedCGNodes || !nestedNode)
358           addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
359       }
360     }
361   }
362 }
363 
364 //===----------------------------------------------------------------------===//
365 // Inliner
366 //===----------------------------------------------------------------------===//
367 namespace {
368 /// This class provides a specialization of the main inlining interface.
369 struct Inliner : public InlinerInterface {
370   Inliner(MLIRContext *context, CallGraph &cg,
371           SymbolTableCollection &symbolTable)
372       : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}
373 
374   /// Process a set of blocks that have been inlined. This callback is invoked
375   /// *before* inlined terminator operations have been processed.
376   void
377   processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
378     // Find the closest callgraph node from the first block.
379     CallGraphNode *node;
380     Region *region = inlinedBlocks.begin()->getParent();
381     while (!(node = cg.lookupNode(region))) {
382       region = region->getParentRegion();
383       assert(region && "expected valid parent node");
384     }
385 
386     collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
387                    /*traverseNestedCGNodes=*/true);
388   }
389 
390   /// Mark the given callgraph node for deletion.
391   void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }
392 
393   /// This method properly disposes of callables that became dead during
394   /// inlining. This should not be called while iterating over the SCCs.
395   void eraseDeadCallables() {
396     for (CallGraphNode *node : deadNodes)
397       node->getCallableRegion()->getParentOp()->erase();
398   }
399 
400   /// The set of callables known to be dead.
401   SmallPtrSet<CallGraphNode *, 8> deadNodes;
402 
403   /// The current set of call instructions to consider for inlining.
404   SmallVector<ResolvedCall, 8> calls;
405 
406   /// The callgraph being operated on.
407   CallGraph &cg;
408 
409   /// A symbol table to use when resolving call lookups.
410   SymbolTableCollection &symbolTable;
411 };
412 } // namespace
413 
414 /// Returns true if the given call should be inlined.
415 static bool shouldInline(ResolvedCall &resolvedCall) {
416   // Don't allow inlining terminator calls. We currently don't support this
417   // case.
418   if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>())
419     return false;
420 
421   // Don't allow inlining if the target is an ancestor of the call. This
422   // prevents inlining recursively.
423   if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
424           resolvedCall.call->getParentRegion()))
425     return false;
426 
427   // Otherwise, inline.
428   return true;
429 }
430 
/// Attempt to inline calls within the given scc. This function returns
/// success if any calls were inlined, failure otherwise.
///
/// Calls are gathered from every node of `currentSCC` that is not already
/// dead into `inliner.calls`, without traversing nested callgraph nodes.
/// Inlining may append further calls to that same vector (via the
/// `processInlinedBlocks` callback), so it is walked by index. Nodes found dead
/// up front, and discardable targets that were inlined in place, are removed
/// from `currentSCC` and passed to `inliner.markForDeletion`. `inliner.calls`
/// is cleared before returning.
static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
                                      CallGraphSCC &currentSCC) {
  CallGraph &cg = inliner.cg;
  auto &calls = inliner.calls;

  // A set of dead nodes to remove after inlining.
  llvm::SmallSetVector<CallGraphNode *, 1> deadNodes;

  // Collect all of the direct calls within the nodes of the current SCC. We
  // don't traverse nested callgraph nodes, because they are handled separately
  // likely within a different SCC.
  for (CallGraphNode *node : currentSCC) {
    if (node->isExternal())
      continue;

    // Don't collect calls if the node is already dead.
    if (useList.isDead(node)) {
      deadNodes.insert(node);
    } else {
      collectCallOps(*node->getCallableRegion(), node, cg, inliner.symbolTable,
                     calls, /*traverseNestedCGNodes=*/false);
    }
  }

  // Try to inline each of the call operations. Don't cache the end iterator
  // here as more calls may be added during inlining.
  bool inlinedAnyCalls = false;
  for (unsigned i = 0; i != calls.size(); ++i) {
    // Calls living inside a node that has been found dead (or was consumed by
    // an in-place inline) are skipped.
    if (deadNodes.contains(calls[i].sourceNode))
      continue;
    // Copy the entry: inlining below may append to `calls`, which can
    // reallocate its storage and invalidate references into it.
    ResolvedCall it = calls[i];
    bool doInline = shouldInline(it);
    CallOpInterface call = it.call;
    LLVM_DEBUG({
      if (doInline)
        llvm::dbgs() << "* Inlining call: " << call << "\n";
      else
        llvm::dbgs() << "* Not inlining call: " << call << "\n";
    });
    if (!doInline)
      continue;
    Region *targetRegion = it.targetNode->getCallableRegion();

    // If this is the last call to the target node and the node is discardable,
    // then inline it in-place and delete the node if successful.
    bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);

    LogicalResult inlineResult = inlineCall(
        inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
        targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
    if (failed(inlineResult)) {
      LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
      continue;
    }
    inlinedAnyCalls = true;

    // If the inlining was successful, Merge the new uses into the source node.
    // The call's own references are dropped first, then the target's inner
    // uses are credited to the source node.
    useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
    useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);

    // then erase the call.
    call.erase();

    // If we inlined in place, mark the node for deletion.
    if (inlineInPlace) {
      useList.eraseNode(it.targetNode);
      deadNodes.insert(it.targetNode);
    }
  }

  // Detach dead nodes from the SCC and hand them to the inliner, which erases
  // them later via `eraseDeadCallables`.
  for (CallGraphNode *node : deadNodes) {
    currentSCC.remove(node);
    inliner.markForDeletion(node);
  }
  calls.clear();
  return success(inlinedAnyCalls);
}
510 
511 //===----------------------------------------------------------------------===//
512 // InlinerPass
513 //===----------------------------------------------------------------------===//
514 
515 namespace {
516 class InlinerPass : public InlinerBase<InlinerPass> {
517 public:
518   InlinerPass();
519   InlinerPass(const InlinerPass &) = default;
520   InlinerPass(std::function<void(OpPassManager &)> defaultPipeline);
521   InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
522               llvm::StringMap<OpPassManager> opPipelines);
523   void runOnOperation() override;
524 
525 private:
526   /// Attempt to inline calls within the given scc, and run simplifications,
527   /// until a fixed point is reached. This allows for the inlining of newly
528   /// devirtualized calls. Returns failure if there was a fatal error during
529   /// inlining.
530   LogicalResult inlineSCC(Inliner &inliner, CGUseList &useList,
531                           CallGraphSCC &currentSCC, MLIRContext *context);
532 
533   /// Optimize the nodes within the given SCC with one of the held optimization
534   /// pass pipelines. Returns failure if an error occurred during the
535   /// optimization of the SCC, success otherwise.
536   LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
537                             CallGraphSCC &currentSCC, MLIRContext *context);
538 
539   /// Optimize the nodes within the given SCC in parallel. Returns failure if an
540   /// error occurred during the optimization of the SCC, success otherwise.
541   LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
542                                  MLIRContext *context);
543 
544   /// Optimize the given callable node with one of the pass managers provided
545   /// with `pipelines`, or the default pipeline. Returns failure if an error
546   /// occurred during the optimization of the callable, success otherwise.
547   LogicalResult optimizeCallable(CallGraphNode *node,
548                                  llvm::StringMap<OpPassManager> &pipelines);
549 
550   /// Attempt to initialize the options of this pass from the given string.
551   /// Derived classes may override this method to hook into the point at which
552   /// options are initialized, but should generally always invoke this base
553   /// class variant.
554   LogicalResult initializeOptions(StringRef options) override;
555 
556   /// An optional function that constructs a default optimization pipeline for
557   /// a given operation.
558   std::function<void(OpPassManager &)> defaultPipeline;
559   /// A map of operation names to pass pipelines to use when optimizing
560   /// callable operations of these types. This provides a specialized pipeline
561   /// instead of the default. The vector size is the number of threads used
562   /// during optimization.
563   SmallVector<llvm::StringMap<OpPassManager>, 8> opPipelines;
564 };
565 } // namespace
566 
567 InlinerPass::InlinerPass() : InlinerPass(defaultInlinerOptPipeline) {}
568 InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline)
569     : defaultPipeline(std::move(defaultPipeline)) {
570   opPipelines.push_back({});
571 
572   // Initialize the pass options with the provided arguments.
573   if (defaultPipeline) {
574     OpPassManager fakePM("__mlir_fake_pm_op");
575     defaultPipeline(fakePM);
576     llvm::raw_string_ostream strStream(defaultPipelineStr);
577     fakePM.printAsTextualPipeline(strStream);
578   }
579 }
580 
581 InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
582                          llvm::StringMap<OpPassManager> opPipelines)
583     : InlinerPass(std::move(defaultPipeline)) {
584   if (opPipelines.empty())
585     return;
586 
587   // Update the option for the op specific optimization pipelines.
588   for (auto &it : opPipelines)
589     opPipelineList.addValue(it.second);
590   this->opPipelines.emplace_back(std::move(opPipelines));
591 }
592 
593 void InlinerPass::runOnOperation() {
594   CallGraph &cg = getAnalysis<CallGraph>();
595   auto *context = &getContext();
596 
597   // The inliner should only be run on operations that define a symbol table,
598   // as the callgraph will need to resolve references.
599   Operation *op = getOperation();
600   if (!op->hasTrait<OpTrait::SymbolTable>()) {
601     op->emitOpError() << " was scheduled to run under the inliner, but does "
602                          "not define a symbol table";
603     return signalPassFailure();
604   }
605 
606   // Run the inline transform in post-order over the SCCs in the callgraph.
607   SymbolTableCollection symbolTable;
608   Inliner inliner(context, cg, symbolTable);
609   CGUseList useList(getOperation(), cg, symbolTable);
610   LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
611     return inlineSCC(inliner, useList, scc, context);
612   });
613   if (failed(result))
614     return signalPassFailure();
615 
616   // After inlining, make sure to erase any callables proven to be dead.
617   inliner.eraseDeadCallables();
618 }
619 
620 LogicalResult InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
621                                      CallGraphSCC &currentSCC,
622                                      MLIRContext *context) {
623   // Continuously simplify and inline until we either reach a fixed point, or
624   // hit the maximum iteration count. Simplifying early helps to refine the cost
625   // model, and in future iterations may devirtualize new calls.
626   unsigned iterationCount = 0;
627   do {
628     if (failed(optimizeSCC(inliner.cg, useList, currentSCC, context)))
629       return failure();
630     if (failed(inlineCallsInSCC(inliner, useList, currentSCC)))
631       break;
632   } while (++iterationCount < maxInliningIterations);
633   return success();
634 }
635 
636 LogicalResult InlinerPass::optimizeSCC(CallGraph &cg, CGUseList &useList,
637                                        CallGraphSCC &currentSCC,
638                                        MLIRContext *context) {
639   // Collect the sets of nodes to simplify.
640   SmallVector<CallGraphNode *, 4> nodesToVisit;
641   for (auto *node : currentSCC) {
642     if (node->isExternal())
643       continue;
644 
645     // Don't simplify nodes with children. Nodes with children require special
646     // handling as we may remove the node during simplification. In the future,
647     // we should be able to handle this case with proper node deletion tracking.
648     if (node->hasChildren())
649       continue;
650 
651     // We also won't apply simplifications to nodes that can't have passes
652     // scheduled on them.
653     auto *region = node->getCallableRegion();
654     if (!region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
655       continue;
656     nodesToVisit.push_back(node);
657   }
658   if (nodesToVisit.empty())
659     return success();
660 
661   // Optimize each of the nodes within the SCC in parallel.
662   if (failed(optimizeSCCAsync(nodesToVisit, context)))
663     return failure();
664 
665   // Recompute the uses held by each of the nodes.
666   for (CallGraphNode *node : nodesToVisit)
667     useList.recomputeUses(node, cg);
668   return success();
669 }
670 
/// Run the optimization pipelines over `nodesToVisit` using
/// `failableParallelForEach`. Each task claims one map from `opPipelines` for
/// the duration of its callable's optimization. Returns failure if any
/// callable failed to optimize.
LogicalResult
InlinerPass::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
                              MLIRContext *ctx) {
  // We must maintain a fixed pool of pass managers which is at least as large
  // as the maximum parallelism of the failableParallelForEach below.
  // Note: The number of pass managers here needs to remain constant
  // to prevent issues with pass instrumentations that rely on having the same
  // pass manager for the main thread.
  size_t numThreads = ctx->getNumThreads();
  if (opPipelines.size() < numThreads) {
    // Reserve before resizing so that we can use a reference to the first
    // element.
    opPipelines.reserve(numThreads);
    opPipelines.resize(numThreads, opPipelines.front());
  }

  // Ensure an analysis manager has been constructed for each of the nodes.
  // This prevents thread races when running the nested pipelines.
  for (CallGraphNode *node : nodesToVisit)
    getAnalysisManager().nest(node->getCallableRegion()->getParentOp());

  // One "in use" flag per entry of `opPipelines`. A task claims a slot by
  // atomically flipping its flag from false to true and releases it when done,
  // so no two concurrently running tasks share a pipeline map.
  std::vector<std::atomic<bool>> activePMs(opPipelines.size());
  std::fill(activePMs.begin(), activePMs.end(), false);
  return failableParallelForEach(ctx, nodesToVisit, [&](CallGraphNode *node) {
    // Find a pass manager for this operation.
    auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
      bool expectedInactive = false;
      return isActive.compare_exchange_strong(expectedInactive, true);
    });
    assert(it != activePMs.end() &&
           "could not find inactive pass manager for thread");
    unsigned pmIndex = it - activePMs.begin();

    // Optimize this callable node.
    LogicalResult result = optimizeCallable(node, opPipelines[pmIndex]);

    // Reset the active bit for this pass manager.
    activePMs[pmIndex].store(false);
    return result;
  });
}
713 
714 LogicalResult
715 InlinerPass::optimizeCallable(CallGraphNode *node,
716                               llvm::StringMap<OpPassManager> &pipelines) {
717   Operation *callable = node->getCallableRegion()->getParentOp();
718   StringRef opName = callable->getName().getStringRef();
719   auto pipelineIt = pipelines.find(opName);
720   if (pipelineIt == pipelines.end()) {
721     // If a pipeline didn't exist, use the default if possible.
722     if (!defaultPipeline)
723       return success();
724 
725     OpPassManager defaultPM(opName);
726     defaultPipeline(defaultPM);
727     pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
728   }
729   return runPipeline(pipelineIt->second, callable);
730 }
731 
732 LogicalResult InlinerPass::initializeOptions(StringRef options) {
733   if (failed(Pass::initializeOptions(options)))
734     return failure();
735 
736   // Initialize the default pipeline builder to use the option string.
737   // TODO: Use a generic pass manager for default pipelines, and remove this.
738   if (!defaultPipelineStr.empty()) {
739     std::string defaultPipelineCopy = defaultPipelineStr;
740     defaultPipeline = [=](OpPassManager &pm) {
741       (void)parsePassPipeline(defaultPipelineCopy, pm);
742     };
743   } else if (defaultPipelineStr.getNumOccurrences()) {
744     defaultPipeline = nullptr;
745   }
746 
747   // Initialize the op specific pass pipelines.
748   llvm::StringMap<OpPassManager> pipelines;
749   for (OpPassManager pipeline : opPipelineList)
750     if (!pipeline.empty())
751       pipelines.try_emplace(pipeline.getOpAnchorName(), pipeline);
752   opPipelines.assign({std::move(pipelines)});
753 
754   return success();
755 }
756 
757 std::unique_ptr<Pass> mlir::createInlinerPass() {
758   return std::make_unique<InlinerPass>();
759 }
760 std::unique_ptr<Pass>
761 mlir::createInlinerPass(llvm::StringMap<OpPassManager> opPipelines) {
762   return std::make_unique<InlinerPass>(defaultInlinerOptPipeline,
763                                        std::move(opPipelines));
764 }
765 std::unique_ptr<Pass> mlir::createInlinerPass(
766     llvm::StringMap<OpPassManager> opPipelines,
767     std::function<void(OpPassManager &)> defaultPipelineBuilder) {
768   return std::make_unique<InlinerPass>(std::move(defaultPipelineBuilder),
769                                        std::move(opPipelines));
770 }
771