//===- Inliner.cpp - Pass to inline function calls ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a basic inlining algorithm that operates bottom up over
// the strongly connected components (SCCs) of the CallGraph. This enables a
// more incremental propagation of inlining decisions from the leaves to the
// roots of the callgraph.
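//
// For example, given a callgraph `foo -> bar -> baz` (hypothetical function
// names), the SCC traversal visits {baz}, then {bar}, then {foo}, so the
// result of inlining `baz` into `bar` is already visible when `bar` is later
// considered as a callee of `foo`.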
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/Threading.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "inlining"

using namespace mlir;

/// This function implements the default inliner optimization pipeline.
static void defaultInlinerOptPipeline(OpPassManager &pm) {
  pm.addPass(createCanonicalizerPass());
}

//===----------------------------------------------------------------------===//
// Symbol Use Tracking
//===----------------------------------------------------------------------===//

/// Walk all of the symbol callgraph nodes referenced by the given operation.
static void walkReferencedSymbolNodes(
    Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
    DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
    function_ref<void(CallGraphNode *, Operation *)> callback) {
  auto symbolUses = SymbolTable::getSymbolUses(op);
  assert(symbolUses && "expected uses to be valid");

  Operation *symbolTableOp = op->getParentOp();
  for (const SymbolTable::SymbolUse &use : *symbolUses) {
    auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
    CallGraphNode *&node = refIt.first->second;

    // If this is the first instance of this reference, try to resolve a
    // callgraph node for it.
    if (refIt.second) {
      auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
                                                           use.getSymbolRef());
      auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
      if (!callableOp)
        continue;
      node = cg.lookupNode(callableOp.getCallableRegion());
    }
    if (node)
      callback(node, use.getUser());
  }
}

//===----------------------------------------------------------------------===//
// CGUseList
//===----------------------------------------------------------------------===//

namespace {
/// This struct tracks the uses of callgraph nodes so that nodes may be
/// dropped once they become use-empty. It directly tracks and manages a
/// use-list for all of the callgraph nodes. This is necessary because many
/// callgraph nodes are referenced by SymbolRefAttr, which has no mechanism
/// akin to the SSA `Use` class.
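///
/// As a hypothetical illustration, in:
///
///   module {
///     func private @callee() { return }
///     func @caller() {
///       call @callee() : () -> ()
///       return
///     }
///   }
///
/// the only link from @caller to @callee is the SymbolRefAttr `@callee` on
/// the call operation, so the use counts tracked here must be maintained
/// explicitly.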
struct CGUseList {
  /// This struct tracks the uses of callgraph nodes within a specific
  /// operation.
  struct CGUser {
    /// Any nodes referenced in the top-level attribute list of this user. We
    /// use a set here because the number of references does not matter.
    DenseSet<CallGraphNode *> topLevelUses;

    /// Uses of nodes referenced by nested operations.
    DenseMap<CallGraphNode *, int> innerUses;
  };

  CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);

  /// Drop uses of nodes referred to by the given call operation that resides
  /// within 'userNode'.
  void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);

  /// Remove the given node from the use list.
  void eraseNode(CallGraphNode *node);

  /// Returns true if the given callgraph node has no uses and can be pruned.
  bool isDead(CallGraphNode *node) const;

  /// Returns true if the given callgraph node has a single use and can be
  /// discarded.
  bool hasOneUseAndDiscardable(CallGraphNode *node) const;

  /// Recompute the uses held by the given callgraph node.
  void recomputeUses(CallGraphNode *node, CallGraph &cg);

  /// Merge the uses of 'lhs' with the uses of 'rhs' after inlining a copy of
  /// 'lhs' into 'rhs'.
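  /// Note that only inner (nested) uses are merged: inlining clones only the
  /// body of 'lhs', so top-level attribute references of 'lhs' are not
  /// duplicated into 'rhs'.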
  void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);

private:
  /// Decrement the uses of discardable nodes referenced by the given user.
  void decrementDiscardableUses(CGUser &uses);

  /// A mapping between a discardable callgraph node (that is a symbol) and the
  /// number of uses for this node.
  DenseMap<CallGraphNode *, int> discardableSymNodeUses;

  /// A mapping between a callgraph node and the symbol callgraph nodes that it
  /// uses.
  DenseMap<CallGraphNode *, CGUser> nodeUses;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // end anonymous namespace

CGUseList::CGUseList(Operation *op, CallGraph &cg,
                     SymbolTableCollection &symbolTable)
    : symbolTable(symbolTable) {
  // A map of symbol references to the callgraph nodes that are always known
  // to be live during inlining.
  DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;

  // Walk each of the symbol tables looking for discardable callgraph nodes.
  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
    for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
      // If this is a callgraph operation, check to see if it is discardable.
      if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
        if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
          SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
          if (symbol && (allUsesVisible || symbol.isPrivate()) &&
              symbol.canDiscardOnUseEmpty()) {
            discardableSymNodeUses.try_emplace(node, 0);
          }
          continue;
        }
      }
      // Otherwise, check for any referenced nodes. These will be always-live.
      walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
                                [](CallGraphNode *, Operation *) {});
    }
  };
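  // If 'op' is not nested within a block (e.g. a top-level ModuleOp), every
  // use of the symbols it contains is necessarily visible.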
  SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
                                walkFn);

  // Drop the use information for any discardable nodes that are always live.
  for (auto &it : alwaysLiveNodes)
    discardableSymNodeUses.erase(it.second);

  // Compute the uses for each of the callable nodes in the graph.
  for (CallGraphNode *node : cg)
    recomputeUses(node, cg);
}

void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
                             CallGraph &cg) {
  auto &userRefs = nodeUses[userNode].innerUses;
  auto walkFn = [&](CallGraphNode *node, Operation *user) {
    auto parentIt = userRefs.find(node);
    if (parentIt == userRefs.end())
      return;
    --parentIt->second;
    --discardableSymNodeUses[node];
  };
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
}

void CGUseList::eraseNode(CallGraphNode *node) {
  // Drop all child nodes.
  for (auto &edge : *node)
    if (edge.isChild())
      eraseNode(edge.getTarget());

  // Drop the uses held by this node and erase it.
  auto useIt = nodeUses.find(node);
  assert(useIt != nodeUses.end() && "expected node to be valid");
  decrementDiscardableUses(useIt->getSecond());
  nodeUses.erase(useIt);
  discardableSymNodeUses.erase(node);
}

bool CGUseList::isDead(CallGraphNode *node) const {
  // If the parent operation isn't a symbol, simply check normal SSA deadness.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!isa<SymbolOpInterface>(nodeOp))
    return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->use_empty();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
}

bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
  // If this isn't a symbol node, check for side-effects and SSA use count.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!isa<SymbolOpInterface>(nodeOp))
    return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->hasOneUse();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
}

void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
  Operation *parentOp = node->getCallableRegion()->getParentOp();
  CGUser &uses = nodeUses[node];
  decrementDiscardableUses(uses);

  // Collect the new discardable uses within this node.
  uses = CGUser();
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
    auto discardSymIt = discardableSymNodeUses.find(refNode);
    if (discardSymIt == discardableSymNodeUses.end())
      return;

    if (user != parentOp)
      ++uses.innerUses[refNode];
    else if (!uses.topLevelUses.insert(refNode).second)
      return;
    ++discardSymIt->second;
  };
  walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
}

void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
  auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
  for (auto &useIt : lhsUses.innerUses) {
    rhsUses.innerUses[useIt.first] += useIt.second;
    discardableSymNodeUses[useIt.first] += useIt.second;
  }
}

void CGUseList::decrementDiscardableUses(CGUser &uses) {
  for (CallGraphNode *node : uses.topLevelUses)
    --discardableSymNodeUses[node];
  for (auto &it : uses.innerUses)
    discardableSymNodeUses[it.first] -= it.second;
}

//===----------------------------------------------------------------------===//
// CallGraph traversal
//===----------------------------------------------------------------------===//

namespace {
/// This class represents a specific callgraph SCC.
class CallGraphSCC {
public:
  CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
      : parentIterator(parentIterator) {}
  /// Return a range over the nodes within this SCC.
  std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
  std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }

  /// Reset the nodes of this SCC with those provided.
  void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }

  /// Remove the given node from this SCC.
  void remove(CallGraphNode *node) {
    auto it = llvm::find(nodes, node);
    if (it != nodes.end()) {
      nodes.erase(it);
      parentIterator.ReplaceNode(node, nullptr);
    }
  }

private:
  std::vector<CallGraphNode *> nodes;
  llvm::scc_iterator<const CallGraph *> &parentIterator;
};
} // end anonymous namespace

/// Run a given transformation over the SCCs of the callgraph in a bottom up
/// traversal.
static LogicalResult runTransformOnCGSCCs(
    const CallGraph &cg,
    function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
  llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
  CallGraphSCC currentSCC(cgi);
  while (!cgi.isAtEnd()) {
    // Copy the current SCC and increment so that the transformer can modify
    // the SCC without invalidating our iterator.
    currentSCC.reset(*cgi);
    ++cgi;
    if (failed(sccTransformer(currentSCC)))
      return failure();
  }
  return success();
}

namespace {
/// This struct represents a resolved call to a given callgraph node. Given
/// that the call does not actually contain a direct reference to the Region
/// (CallGraphNode) that it dispatches to, the target must be resolved
/// explicitly.
struct ResolvedCall {
  ResolvedCall(CallOpInterface call, CallGraphNode *sourceNode,
               CallGraphNode *targetNode)
      : call(call), sourceNode(sourceNode), targetNode(targetNode) {}
  CallOpInterface call;
  CallGraphNode *sourceNode, *targetNode;
};
} // end anonymous namespace

/// Collect all of the call operations within the given range of blocks. If
/// `traverseNestedCGNodes` is true, this will also collect call operations
/// inside of nested callgraph nodes.
static void collectCallOps(iterator_range<Region::iterator> blocks,
                           CallGraphNode *sourceNode, CallGraph &cg,
                           SymbolTableCollection &symbolTable,
                           SmallVectorImpl<ResolvedCall> &calls,
                           bool traverseNestedCGNodes) {
  SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
  auto addToWorklist = [&](CallGraphNode *node,
                           iterator_range<Region::iterator> blocks) {
    for (Block &block : blocks)
      worklist.emplace_back(&block, node);
  };

  addToWorklist(sourceNode, blocks);
  while (!worklist.empty()) {
    Block *block;
    std::tie(block, sourceNode) = worklist.pop_back_val();

    for (Operation &op : *block) {
      if (auto call = dyn_cast<CallOpInterface>(op)) {
        // TODO: Support inlining nested call references.
        CallInterfaceCallable callable = call.getCallableForCallee();
        if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
          if (!symRef.isa<FlatSymbolRefAttr>())
            continue;
        }

        CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
        if (!targetNode->isExternal())
          calls.emplace_back(call, sourceNode, targetNode);
        continue;
      }

      // If this is not a call, traverse the nested regions. Regions that form
      // their own callgraph nodes are only traversed if
      // `traverseNestedCGNodes` is true.
      for (auto &nestedRegion : op.getRegions()) {
        CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
        if (traverseNestedCGNodes || !nestedNode)
          addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// Inliner
//===----------------------------------------------------------------------===//
namespace {
/// This class provides a specialization of the main inlining interface.
struct Inliner : public InlinerInterface {
  Inliner(MLIRContext *context, CallGraph &cg,
          SymbolTableCollection &symbolTable)
      : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}

  /// Process a set of blocks that have been inlined. This callback is invoked
  /// *before* inlined terminator operations have been processed.
  void
  processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
    // Find the closest callgraph node from the first block.
    CallGraphNode *node;
    Region *region = inlinedBlocks.begin()->getParent();
    while (!(node = cg.lookupNode(region))) {
      region = region->getParentRegion();
      assert(region && "expected valid parent node");
    }

    collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
                   /*traverseNestedCGNodes=*/true);
  }

  /// Mark the given callgraph node for deletion.
  void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }

  /// This method properly disposes of callables that became dead during
  /// inlining. This should not be called while iterating over the SCCs.
  void eraseDeadCallables() {
    for (CallGraphNode *node : deadNodes)
      node->getCallableRegion()->getParentOp()->erase();
  }

  /// The set of callables known to be dead.
  SmallPtrSet<CallGraphNode *, 8> deadNodes;

  /// The current set of call instructions to consider for inlining.
  SmallVector<ResolvedCall, 8> calls;

  /// The callgraph being operated on.
  CallGraph &cg;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // end anonymous namespace

/// Returns true if the given call should be inlined.
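/// As a hypothetical illustration, a directly recursive call such as
///
///   func @fact(%n : i32) -> i32 {
///     ...
///     %r = call @fact(%m) : (i32) -> i32
///     ...
///   }
///
/// is rejected, because the callable region of @fact is an ancestor of the
/// call's parent region.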
static bool shouldInline(ResolvedCall &resolvedCall) {
  // Don't allow inlining terminator calls. We currently don't support this
  // case.
  if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>())
    return false;

  // Don't allow inlining if the target is an ancestor of the call. This
  // prevents recursive inlining.
  if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
          resolvedCall.call->getParentRegion()))
    return false;

  // Otherwise, inline.
  return true;
}

/// Attempt to inline calls within the given SCC. This function returns
/// success if any calls were inlined, failure otherwise.
static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
                                      CallGraphSCC &currentSCC) {
  CallGraph &cg = inliner.cg;
  auto &calls = inliner.calls;

  // A set of dead nodes to remove after inlining.
  SmallVector<CallGraphNode *, 1> deadNodes;

  // Collect all of the direct calls within the nodes of the current SCC. We
  // don't traverse nested callgraph nodes, because they are handled
  // separately, likely within a different SCC.
  for (CallGraphNode *node : currentSCC) {
    if (node->isExternal())
      continue;

    // Don't collect calls if the node is already dead.
    if (useList.isDead(node)) {
      deadNodes.push_back(node);
    } else {
      collectCallOps(*node->getCallableRegion(), node, cg, inliner.symbolTable,
                     calls, /*traverseNestedCGNodes=*/false);
    }
  }

  // Try to inline each of the call operations. Don't cache the end iterator
  // here as more calls may be added during inlining.
  bool inlinedAnyCalls = false;
  for (unsigned i = 0; i != calls.size(); ++i) {
    ResolvedCall it = calls[i];
    bool doInline = shouldInline(it);
    CallOpInterface call = it.call;
    LLVM_DEBUG({
      if (doInline)
        llvm::dbgs() << "* Inlining call: " << call << "\n";
      else
        llvm::dbgs() << "* Not inlining call: " << call << "\n";
    });
    if (!doInline)
      continue;
    Region *targetRegion = it.targetNode->getCallableRegion();

    // If this is the last call to the target node and the node is discardable,
    // then inline it in-place and delete the node if successful.
    bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);

    LogicalResult inlineResult = inlineCall(
        inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
        targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
    if (failed(inlineResult)) {
      LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
      continue;
    }
    inlinedAnyCalls = true;

    // Inlining was successful: merge the new uses into the source node.
    useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
    useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);

    // Then erase the call.
    call.erase();

    // If we inlined in place, mark the node for deletion.
    if (inlineInPlace) {
      useList.eraseNode(it.targetNode);
      deadNodes.push_back(it.targetNode);
    }
  }

  for (CallGraphNode *node : deadNodes) {
    currentSCC.remove(node);
    inliner.markForDeletion(node);
  }
  calls.clear();
  return success(inlinedAnyCalls);
}

//===----------------------------------------------------------------------===//
// InlinerPass
//===----------------------------------------------------------------------===//

namespace {
class InlinerPass : public InlinerBase<InlinerPass> {
public:
  InlinerPass();
  InlinerPass(const InlinerPass &) = default;
  InlinerPass(std::function<void(OpPassManager &)> defaultPipeline);
  InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
              llvm::StringMap<OpPassManager> opPipelines);
  void runOnOperation() override;

private:
  /// Attempt to inline calls within the given SCC, and run simplifications,
  /// until a fixed point is reached. This allows for the inlining of newly
  /// devirtualized calls. Returns failure if there was a fatal error during
  /// inlining.
  LogicalResult inlineSCC(Inliner &inliner, CGUseList &useList,
                          CallGraphSCC &currentSCC, MLIRContext *context);

  /// Optimize the nodes within the given SCC with one of the held optimization
  /// pass pipelines. Returns failure if an error occurred during the
  /// optimization of the SCC, success otherwise.
  LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
                            CallGraphSCC &currentSCC, MLIRContext *context);

  /// Optimize the nodes within the given SCC in parallel. Returns failure if
  /// an error occurred during the optimization of the SCC, success otherwise.
  LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
                                 MLIRContext *context);

  /// Optimize the given callable node with one of the pass managers provided
  /// with `pipelines`, or the default pipeline. Returns failure if an error
  /// occurred during the optimization of the callable, success otherwise.
  LogicalResult optimizeCallable(CallGraphNode *node,
                                 llvm::StringMap<OpPassManager> &pipelines);

  /// Attempt to initialize the options of this pass from the given string.
  /// Derived classes may override this method to hook into the point at which
  /// options are initialized, but should generally always invoke this base
  /// class variant.
  LogicalResult initializeOptions(StringRef options) override;

  /// An optional function that constructs a default optimization pipeline for
  /// a given operation.
  std::function<void(OpPassManager &)> defaultPipeline;
  /// A map of operation names to pass pipelines to use when optimizing
  /// callable operations of these types. This provides a specialized pipeline
  /// instead of the default. The vector size is the number of threads used
  /// during optimization.
  SmallVector<llvm::StringMap<OpPassManager>, 8> opPipelines;
};
} // end anonymous namespace

InlinerPass::InlinerPass() : InlinerPass(defaultInlinerOptPipeline) {}
InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline)
    : defaultPipeline(defaultPipeline) {
  opPipelines.push_back({});

  // Initialize the pass options with the provided arguments.
  if (defaultPipeline) {
    OpPassManager fakePM("__mlir_fake_pm_op");
    defaultPipeline(fakePM);
    llvm::raw_string_ostream strStream(defaultPipelineStr);
    fakePM.printAsTextualPipeline(strStream);
  }
}

InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
                         llvm::StringMap<OpPassManager> opPipelines)
    : InlinerPass(std::move(defaultPipeline)) {
  if (opPipelines.empty())
    return;

  // Update the option for the op specific optimization pipelines.
  for (auto &it : opPipelines) {
    std::string pipeline;
    llvm::raw_string_ostream pipelineOS(pipeline);
    pipelineOS << it.getKey() << "(";
    it.second.printAsTextualPipeline(pipelineOS);
    pipelineOS << ")";
    opPipelineStrs.addValue(pipeline);
  }
  this->opPipelines.emplace_back(std::move(opPipelines));
}

void InlinerPass::runOnOperation() {
  CallGraph &cg = getAnalysis<CallGraph>();
  auto *context = &getContext();

  // The inliner should only be run on operations that define a symbol table,
  // as the callgraph will need to resolve references.
  Operation *op = getOperation();
  if (!op->hasTrait<OpTrait::SymbolTable>()) {
    op->emitOpError() << " was scheduled to run under the inliner, but does "
                         "not define a symbol table";
    return signalPassFailure();
  }

  // Run the inline transform in post-order over the SCCs in the callgraph.
  SymbolTableCollection symbolTable;
  Inliner inliner(context, cg, symbolTable);
  CGUseList useList(getOperation(), cg, symbolTable);
  LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
    return inlineSCC(inliner, useList, scc, context);
  });
  if (failed(result))
    return signalPassFailure();

  // After inlining, make sure to erase any callables proven to be dead.
  inliner.eraseDeadCallables();
}

LogicalResult InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
                                     CallGraphSCC &currentSCC,
                                     MLIRContext *context) {
  // Continuously simplify and inline until we either reach a fixed point, or
  // hit the maximum iteration count. Simplifying early helps to refine the
  // cost model, and in future iterations may devirtualize new calls.
  unsigned iterationCount = 0;
  do {
    if (failed(optimizeSCC(inliner.cg, useList, currentSCC, context)))
      return failure();
    if (failed(inlineCallsInSCC(inliner, useList, currentSCC)))
      break;
  } while (++iterationCount < maxInliningIterations);
  return success();
}

LogicalResult InlinerPass::optimizeSCC(CallGraph &cg, CGUseList &useList,
                                       CallGraphSCC &currentSCC,
                                       MLIRContext *context) {
  // Collect the sets of nodes to simplify.
  SmallVector<CallGraphNode *, 4> nodesToVisit;
  for (auto *node : currentSCC) {
    if (node->isExternal())
      continue;

    // Don't simplify nodes with children. Nodes with children require special
    // handling as we may remove the node during simplification. In the future,
    // we should be able to handle this case with proper node deletion tracking.
    if (node->hasChildren())
      continue;

    // We also won't apply simplifications to nodes that can't have passes
    // scheduled on them.
    auto *region = node->getCallableRegion();
    if (!region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
      continue;
    nodesToVisit.push_back(node);
  }
  if (nodesToVisit.empty())
    return success();

  // Optimize each of the nodes within the SCC in parallel.
  if (failed(optimizeSCCAsync(nodesToVisit, context)))
    return failure();

  // Recompute the uses held by each of the nodes.
  for (CallGraphNode *node : nodesToVisit)
    useList.recomputeUses(node, cg);
  return success();
}

LogicalResult
InlinerPass::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
                              MLIRContext *ctx) {
  // Ensure that there are enough pipeline maps for the optimizer to run in
  // parallel. Note: The number of pass managers here needs to remain constant
  // to prevent issues with pass instrumentations that rely on having the same
  // pass manager for the main thread.
  size_t numThreads = llvm::hardware_concurrency().compute_thread_count();
  if (opPipelines.size() < numThreads) {
    // Reserve before resizing so that we can use a reference to the first
    // element.
    opPipelines.reserve(numThreads);
    opPipelines.resize(numThreads, opPipelines.front());
  }

  // Ensure an analysis manager has been constructed for each of the nodes.
  // This prevents thread races when running the nested pipelines.
  for (CallGraphNode *node : nodesToVisit)
    getAnalysisManager().nest(node->getCallableRegion()->getParentOp());

  // A vector of flags indicating which of the pass managers are currently in
  // use by an async executor.
  std::vector<std::atomic<bool>> activePMs(opPipelines.size());
  std::fill(activePMs.begin(), activePMs.end(), false);
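  // Each worker transiently claims a pass manager by atomically setting the
  // first inactive flag it finds, and releases it once the node has been
  // optimized, so no two threads ever share a pass manager concurrently.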
  return failableParallelForEach(ctx, nodesToVisit, [&](CallGraphNode *node) {
    // Find a pass manager for this operation.
    auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
      bool expectedInactive = false;
      return isActive.compare_exchange_strong(expectedInactive, true);
    });
    unsigned pmIndex = it - activePMs.begin();

    // Optimize this callable node.
    LogicalResult result = optimizeCallable(node, opPipelines[pmIndex]);

    // Reset the active bit for this pass manager.
    activePMs[pmIndex].store(false);
    return result;
  });
}

LogicalResult
InlinerPass::optimizeCallable(CallGraphNode *node,
                              llvm::StringMap<OpPassManager> &pipelines) {
  Operation *callable = node->getCallableRegion()->getParentOp();
  StringRef opName = callable->getName().getStringRef();
  auto pipelineIt = pipelines.find(opName);
  if (pipelineIt == pipelines.end()) {
    // If no pipeline exists for this operation, use the default if available.
    if (!defaultPipeline)
      return success();

    OpPassManager defaultPM(opName);
    defaultPipeline(defaultPM);
    pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
  }
  return runPipeline(pipelineIt->second, callable);
}

LogicalResult InlinerPass::initializeOptions(StringRef options) {
  if (failed(Pass::initializeOptions(options)))
    return failure();

  // Initialize the default pipeline builder to use the option string.
  if (!defaultPipelineStr.empty()) {
    std::string defaultPipelineCopy = defaultPipelineStr;
    defaultPipeline = [=](OpPassManager &pm) {
      (void)parsePassPipeline(defaultPipelineCopy, pm);
    };
  } else if (defaultPipelineStr.getNumOccurrences()) {
    defaultPipeline = nullptr;
  }

  // Initialize the op specific pass pipelines.
  llvm::StringMap<OpPassManager> pipelines;
  for (StringRef pipeline : opPipelineStrs) {
    // Skip empty pipelines.
    if (pipeline.empty())
      continue;

    // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
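    // For example (assuming the named passes are registered):
    //   builtin.func(canonicalize,cse)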
    size_t pipelineStart = pipeline.find_first_of('(');
    if (pipelineStart == StringRef::npos || !pipeline.consume_back(")"))
      return failure();
    StringRef opName = pipeline.take_front(pipelineStart);
    OpPassManager pm(opName);
    if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm)))
      return failure();
    pipelines.try_emplace(opName, std::move(pm));
  }
  opPipelines.assign({std::move(pipelines)});

  return success();
}

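// A minimal usage sketch for the factory functions below (hypothetical setup;
// assumes the canonicalizer pass is linked in):
//
//   llvm::StringMap<OpPassManager> pipelines;
//   OpPassManager funcPM("builtin.func");
//   funcPM.addPass(createCanonicalizerPass());
//   pipelines.try_emplace("builtin.func", std::move(funcPM));
//   passManager.addPass(createInlinerPass(std::move(pipelines)));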
std::unique_ptr<Pass> mlir::createInlinerPass() {
  return std::make_unique<InlinerPass>();
}
std::unique_ptr<Pass>
mlir::createInlinerPass(llvm::StringMap<OpPassManager> opPipelines) {
  return std::make_unique<InlinerPass>(defaultInlinerOptPipeline,
                                       std::move(opPipelines));
}
std::unique_ptr<Pass> mlir::createInlinerPass(
    llvm::StringMap<OpPassManager> opPipelines,
    std::function<void(OpPassManager &)> defaultPipelineBuilder) {
  return std::make_unique<InlinerPass>(std::move(defaultPipelineBuilder),
                                       std::move(opPipelines));
}
782