//===- Inliner.cpp - Pass to inline function calls -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a basic inlining algorithm that operates bottom up over
// the Strongly Connected Components (SCCs) of the CallGraph. This enables a
// more incremental propagation of inlining decisions from the leaves to the
// roots of the callgraph.
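//
// As an illustrative example (not IR from a test), given the callgraph formed
// by the IR below, `@leaf` is its own SCC and is processed before `@root`, so
// `@leaf` has already been fully processed by the time calls to it are
// considered for inlining:
//
//   func @leaf() {
//     return
//   }
//   func @root() {
//     call @leaf() : () -> ()
//     return
//   }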
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "PassDetail.h"
17 #include "mlir/Analysis/CallGraph.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Interfaces/SideEffectInterfaces.h"
20 #include "mlir/Transforms/InliningUtils.h"
21 #include "mlir/Transforms/Passes.h"
22 #include "llvm/ADT/SCCIterator.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/Parallel.h"
25 
26 #define DEBUG_TYPE "inlining"
27 
28 using namespace mlir;
29 
30 //===----------------------------------------------------------------------===//
31 // Symbol Use Tracking
32 //===----------------------------------------------------------------------===//
33 
/// Walk all of the symbol callgraph nodes referenced by the given operation.
35 static void walkReferencedSymbolNodes(
36     Operation *op, CallGraph &cg,
37     DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
38     function_ref<void(CallGraphNode *, Operation *)> callback) {
39   auto symbolUses = SymbolTable::getSymbolUses(op);
40   assert(symbolUses && "expected uses to be valid");
41 
42   Operation *symbolTableOp = op->getParentOp();
43   for (const SymbolTable::SymbolUse &use : *symbolUses) {
44     auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
45     CallGraphNode *&node = refIt.first->second;
46 
47     // If this is the first instance of this reference, try to resolve a
48     // callgraph node for it.
49     if (refIt.second) {
50       auto *symbolOp = SymbolTable::lookupNearestSymbolFrom(symbolTableOp,
51                                                             use.getSymbolRef());
52       auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
53       if (!callableOp)
54         continue;
55       node = cg.lookupNode(callableOp.getCallableRegion());
56     }
57     if (node)
58       callback(node, use.getUser());
59   }
60 }
61 
62 //===----------------------------------------------------------------------===//
// CGUseList
//===----------------------------------------------------------------------===//
64 
65 namespace {
/// This struct tracks the uses of callgraph nodes that can be dropped when
/// use-empty, i.e. when they no longer have any remaining uses. It directly
/// tracks and manages a use-list for all of the call-graph nodes. This is
/// necessary because many callgraph nodes are referenced by SymbolRefAttr,
/// which has no mechanism akin to the SSA `Use` class.
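///
/// For example (illustrative), a symbol with private visibility that is only
/// referenced by a single call is tracked with a use count of one; once that
/// call is inlined away and the count drops to zero, the callee becomes
/// discardable:
///
///   func @callee() attributes {sym_visibility = "private"} {
///     return
///   }
///   func @caller() {
///     call @callee() : () -> ()  // The only tracked use of @callee.
///     return
///   }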
71 struct CGUseList {
72   /// This struct tracks the uses of callgraph nodes within a specific
73   /// operation.
74   struct CGUser {
75     /// Any nodes referenced in the top-level attribute list of this user. We
76     /// use a set here because the number of references does not matter.
77     DenseSet<CallGraphNode *> topLevelUses;
78 
79     /// Uses of nodes referenced by nested operations.
80     DenseMap<CallGraphNode *, int> innerUses;
81   };
82 
83   CGUseList(Operation *op, CallGraph &cg);
84 
85   /// Drop uses of nodes referred to by the given call operation that resides
86   /// within 'userNode'.
87   void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);
88 
89   /// Remove the given node from the use list.
90   void eraseNode(CallGraphNode *node);
91 
92   /// Returns true if the given callgraph node has no uses and can be pruned.
93   bool isDead(CallGraphNode *node) const;
94 
95   /// Returns true if the given callgraph node has a single use and can be
96   /// discarded.
97   bool hasOneUseAndDiscardable(CallGraphNode *node) const;
98 
99   /// Recompute the uses held by the given callgraph node.
100   void recomputeUses(CallGraphNode *node, CallGraph &cg);
101 
102   /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy
103   /// of 'lhs' into 'rhs'.
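  /// For example, if the inlined copy of 'lhs' contained two nested references
  /// to a discardable symbol, the corresponding inner use count of 'rhs' and
  /// that symbol's discardable use count both increase by two.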
104   void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);
105 
106 private:
107   /// Decrement the uses of discardable nodes referenced by the given user.
108   void decrementDiscardableUses(CGUser &uses);
109 
110   /// A mapping between a discardable callgraph node (that is a symbol) and the
111   /// number of uses for this node.
112   DenseMap<CallGraphNode *, int> discardableSymNodeUses;
113   /// A mapping between a callgraph node and the symbol callgraph nodes that it
114   /// uses.
115   DenseMap<CallGraphNode *, CGUser> nodeUses;
116 };
117 } // end anonymous namespace
118 
119 CGUseList::CGUseList(Operation *op, CallGraph &cg) {
  // A map of symbol references to callgraph nodes that are always known to be
  // live during inlining.
  DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;
122 
123   // Walk each of the symbol tables looking for discardable callgraph nodes.
124   auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
125     for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
      // If this is a callable operation, check to see if it is discardable.
127       if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
128         if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
129           SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
130           if (symbol && (allUsesVisible || symbol.isPrivate()) &&
131               symbol.canDiscardOnUseEmpty()) {
132             discardableSymNodeUses.try_emplace(node, 0);
133           }
134           continue;
135         }
136       }
137       // Otherwise, check for any referenced nodes. These will be always-live.
138       walkReferencedSymbolNodes(&op, cg, alwaysLiveNodes,
139                                 [](CallGraphNode *, Operation *) {});
140     }
141   };
142   SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
143                                 walkFn);
144 
145   // Drop the use information for any discardable nodes that are always live.
146   for (auto &it : alwaysLiveNodes)
147     discardableSymNodeUses.erase(it.second);
148 
149   // Compute the uses for each of the callable nodes in the graph.
150   for (CallGraphNode *node : cg)
151     recomputeUses(node, cg);
152 }
153 
154 void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
155                              CallGraph &cg) {
156   auto &userRefs = nodeUses[userNode].innerUses;
157   auto walkFn = [&](CallGraphNode *node, Operation *user) {
158     auto parentIt = userRefs.find(node);
159     if (parentIt == userRefs.end())
160       return;
161     --parentIt->second;
162     --discardableSymNodeUses[node];
163   };
164   DenseMap<Attribute, CallGraphNode *> resolvedRefs;
165   walkReferencedSymbolNodes(callOp, cg, resolvedRefs, walkFn);
166 }
167 
168 void CGUseList::eraseNode(CallGraphNode *node) {
169   // Drop all child nodes.
170   for (auto &edge : *node)
171     if (edge.isChild())
172       eraseNode(edge.getTarget());
173 
174   // Drop the uses held by this node and erase it.
175   auto useIt = nodeUses.find(node);
176   assert(useIt != nodeUses.end() && "expected node to be valid");
177   decrementDiscardableUses(useIt->getSecond());
178   nodeUses.erase(useIt);
179   discardableSymNodeUses.erase(node);
180 }
181 
182 bool CGUseList::isDead(CallGraphNode *node) const {
183   // If the parent operation isn't a symbol, simply check normal SSA deadness.
184   Operation *nodeOp = node->getCallableRegion()->getParentOp();
185   if (!isa<SymbolOpInterface>(nodeOp))
186     return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->use_empty();
187 
188   // Otherwise, check the number of symbol uses.
189   auto symbolIt = discardableSymNodeUses.find(node);
190   return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
191 }
192 
193 bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
194   // If this isn't a symbol node, check for side-effects and SSA use count.
195   Operation *nodeOp = node->getCallableRegion()->getParentOp();
196   if (!isa<SymbolOpInterface>(nodeOp))
197     return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->hasOneUse();
198 
199   // Otherwise, check the number of symbol uses.
200   auto symbolIt = discardableSymNodeUses.find(node);
201   return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
202 }
203 
204 void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
205   Operation *parentOp = node->getCallableRegion()->getParentOp();
206   CGUser &uses = nodeUses[node];
207   decrementDiscardableUses(uses);
208 
209   // Collect the new discardable uses within this node.
210   uses = CGUser();
211   DenseMap<Attribute, CallGraphNode *> resolvedRefs;
212   auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
213     auto discardSymIt = discardableSymNodeUses.find(refNode);
214     if (discardSymIt == discardableSymNodeUses.end())
215       return;
216 
217     if (user != parentOp)
218       ++uses.innerUses[refNode];
219     else if (!uses.topLevelUses.insert(refNode).second)
220       return;
221     ++discardSymIt->second;
222   };
223   walkReferencedSymbolNodes(parentOp, cg, resolvedRefs, walkFn);
224 }
225 
226 void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
227   auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
228   for (auto &useIt : lhsUses.innerUses) {
229     rhsUses.innerUses[useIt.first] += useIt.second;
230     discardableSymNodeUses[useIt.first] += useIt.second;
231   }
232 }
233 
234 void CGUseList::decrementDiscardableUses(CGUser &uses) {
235   for (CallGraphNode *node : uses.topLevelUses)
236     --discardableSymNodeUses[node];
237   for (auto &it : uses.innerUses)
238     discardableSymNodeUses[it.first] -= it.second;
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // CallGraph traversal
243 //===----------------------------------------------------------------------===//
244 
245 namespace {
246 /// This class represents a specific callgraph SCC.
247 class CallGraphSCC {
248 public:
249   CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
250       : parentIterator(parentIterator) {}
251   /// Return a range over the nodes within this SCC.
252   std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
253   std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }
254 
255   /// Reset the nodes of this SCC with those provided.
256   void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
257 
258   /// Remove the given node from this SCC.
259   void remove(CallGraphNode *node) {
260     auto it = llvm::find(nodes, node);
261     if (it != nodes.end()) {
262       nodes.erase(it);
263       parentIterator.ReplaceNode(node, nullptr);
264     }
265   }
266 
267 private:
268   std::vector<CallGraphNode *> nodes;
269   llvm::scc_iterator<const CallGraph *> &parentIterator;
270 };
271 } // end anonymous namespace
272 
/// Run a given transformation over the SCCs of the callgraph in a bottom-up
/// traversal.
275 static void
276 runTransformOnCGSCCs(const CallGraph &cg,
277                      function_ref<void(CallGraphSCC &)> sccTransformer) {
278   llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
279   CallGraphSCC currentSCC(cgi);
280   while (!cgi.isAtEnd()) {
281     // Copy the current SCC and increment so that the transformer can modify the
282     // SCC without invalidating our iterator.
283     currentSCC.reset(*cgi);
284     ++cgi;
285     sccTransformer(currentSCC);
286   }
287 }
288 
289 namespace {
/// This struct represents a resolved call to a given callgraph node. Given that
/// the call does not actually contain a direct reference to the
/// Region (i.e. the CallGraphNode) that it is dispatching to, we need to
/// resolve it explicitly.
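///
/// For example (illustrative), the call below only holds the symbol reference
/// `@foo`; it is resolved here to the callgraph node for the body region of
/// `@foo`:
///
///   %0 = call @foo(%arg) : (i32) -> i32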
294 struct ResolvedCall {
295   ResolvedCall(CallOpInterface call, CallGraphNode *sourceNode,
296                CallGraphNode *targetNode)
297       : call(call), sourceNode(sourceNode), targetNode(targetNode) {}
298   CallOpInterface call;
299   CallGraphNode *sourceNode, *targetNode;
300 };
301 } // end anonymous namespace
302 
/// Collect all of the call operations within the given range of blocks. If
/// `traverseNestedCGNodes` is true, this will also collect call operations
/// inside nested callgraph nodes.
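///
/// For example (illustrative), with `traverseNestedCGNodes` set to false, the
/// call below is still collected and attributed to the enclosing function's
/// node, because the region of the `scf.if` is not itself a callgraph node:
///
///   func @foo(%cond : i1) {
///     scf.if %cond {
///       call @bar() : () -> ()
///     }
///     return
///   }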
306 static void collectCallOps(iterator_range<Region::iterator> blocks,
307                            CallGraphNode *sourceNode, CallGraph &cg,
308                            SmallVectorImpl<ResolvedCall> &calls,
309                            bool traverseNestedCGNodes) {
310   SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
311   auto addToWorklist = [&](CallGraphNode *node,
312                            iterator_range<Region::iterator> blocks) {
313     for (Block &block : blocks)
314       worklist.emplace_back(&block, node);
315   };
316 
317   addToWorklist(sourceNode, blocks);
318   while (!worklist.empty()) {
319     Block *block;
320     std::tie(block, sourceNode) = worklist.pop_back_val();
321 
322     for (Operation &op : *block) {
323       if (auto call = dyn_cast<CallOpInterface>(op)) {
324         // TODO: Support inlining nested call references.
325         CallInterfaceCallable callable = call.getCallableForCallee();
326         if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
327           if (!symRef.isa<FlatSymbolRefAttr>())
328             continue;
329         }
330 
331         CallGraphNode *targetNode = cg.resolveCallable(call);
332         if (!targetNode->isExternal())
333           calls.emplace_back(call, sourceNode, targetNode);
334         continue;
335       }
336 
      // If this is not a call, traverse the nested regions. If
      // `traverseNestedCGNodes` is false, regions that form nested callgraph
      // nodes are skipped.
340       for (auto &nestedRegion : op.getRegions()) {
341         CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
342         if (traverseNestedCGNodes || !nestedNode)
343           addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
344       }
345     }
346   }
347 }
348 
349 //===----------------------------------------------------------------------===//
350 // Inliner
351 //===----------------------------------------------------------------------===//
352 namespace {
353 /// This class provides a specialization of the main inlining interface.
354 struct Inliner : public InlinerInterface {
355   Inliner(MLIRContext *context, CallGraph &cg)
356       : InlinerInterface(context), cg(cg) {}
357 
358   /// Process a set of blocks that have been inlined. This callback is invoked
359   /// *before* inlined terminator operations have been processed.
360   void
361   processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
362     // Find the closest callgraph node from the first block.
363     CallGraphNode *node;
364     Region *region = inlinedBlocks.begin()->getParent();
365     while (!(node = cg.lookupNode(region))) {
366       region = region->getParentRegion();
367       assert(region && "expected valid parent node");
368     }
369 
370     collectCallOps(inlinedBlocks, node, cg, calls,
371                    /*traverseNestedCGNodes=*/true);
372   }
373 
374   /// Mark the given callgraph node for deletion.
375   void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }
376 
377   /// This method properly disposes of callables that became dead during
378   /// inlining. This should not be called while iterating over the SCCs.
379   void eraseDeadCallables() {
380     for (CallGraphNode *node : deadNodes)
381       node->getCallableRegion()->getParentOp()->erase();
382   }
383 
384   /// The set of callables known to be dead.
385   SmallPtrSet<CallGraphNode *, 8> deadNodes;
386 
  /// The current set of call operations to consider for inlining.
388   SmallVector<ResolvedCall, 8> calls;
389 
390   /// The callgraph being operated on.
391   CallGraph &cg;
392 };
393 } // namespace
394 
395 /// Returns true if the given call should be inlined.
396 static bool shouldInline(ResolvedCall &resolvedCall) {
  // Don't allow inlining terminator calls, as this case is not currently
  // supported.
399   if (resolvedCall.call.getOperation()->isKnownTerminator())
400     return false;
401 
  // Don't allow inlining if the target is an ancestor of the call. This
  // prevents inlining recursive calls.
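  //
  // For example (illustrative, not IR from this file), a directly recursive
  // call such as the one below is never inlined, because the callee's region
  // is an ancestor of the call:
  //
  //   func @fib(%n : index) -> index {
  //     ...
  //     %r = call @fib(%m) : (index) -> index
  //     ...
  //   }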
404   if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
405           resolvedCall.call.getParentRegion()))
406     return false;
407 
408   // Otherwise, inline.
409   return true;
410 }
411 
/// Attempt to inline calls within the given SCC. This function returns
/// success if any calls were inlined, failure otherwise.
414 static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
415                                       CallGraphSCC &currentSCC) {
416   CallGraph &cg = inliner.cg;
417   auto &calls = inliner.calls;
418 
419   // A set of dead nodes to remove after inlining.
420   SmallVector<CallGraphNode *, 1> deadNodes;
421 
  // Collect all of the direct calls within the nodes of the current SCC. We
  // don't traverse nested callgraph nodes, because they are handled
  // separately, likely within a different SCC.
425   for (CallGraphNode *node : currentSCC) {
426     if (node->isExternal())
427       continue;
428 
429     // Don't collect calls if the node is already dead.
430     if (useList.isDead(node))
431       deadNodes.push_back(node);
432     else
433       collectCallOps(*node->getCallableRegion(), node, cg, calls,
434                      /*traverseNestedCGNodes=*/false);
435   }
436 
437   // Try to inline each of the call operations. Don't cache the end iterator
438   // here as more calls may be added during inlining.
439   bool inlinedAnyCalls = false;
440   for (unsigned i = 0; i != calls.size(); ++i) {
441     ResolvedCall it = calls[i];
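    // Note: inlining may append new entries to 'calls' (via
    // processInlinedBlocks) and reallocate its storage, so the resolved call
    // is taken by value rather than by reference into the vector.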
442     bool doInline = shouldInline(it);
443     CallOpInterface call = it.call;
444     LLVM_DEBUG({
445       if (doInline)
446         llvm::dbgs() << "* Inlining call: " << call << "\n";
447       else
448         llvm::dbgs() << "* Not inlining call: " << call << "\n";
449     });
450     if (!doInline)
451       continue;
452     Region *targetRegion = it.targetNode->getCallableRegion();
453 
454     // If this is the last call to the target node and the node is discardable,
455     // then inline it in-place and delete the node if successful.
456     bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);
457 
458     LogicalResult inlineResult = inlineCall(
459         inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
460         targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
461     if (failed(inlineResult)) {
462       LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
463       continue;
464     }
465     inlinedAnyCalls = true;
466 
    // The inlining succeeded; merge the new uses into the source node.
468     useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
469     useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);
470 
    // Then erase the call.
472     call.erase();
473 
474     // If we inlined in place, mark the node for deletion.
475     if (inlineInPlace) {
476       useList.eraseNode(it.targetNode);
477       deadNodes.push_back(it.targetNode);
478     }
479   }
480 
481   for (CallGraphNode *node : deadNodes) {
482     currentSCC.remove(node);
483     inliner.markForDeletion(node);
484   }
485   calls.clear();
486   return success(inlinedAnyCalls);
487 }
488 
489 /// Canonicalize the nodes within the given SCC with the given set of
490 /// canonicalization patterns.
491 static void canonicalizeSCC(CallGraph &cg, CGUseList &useList,
492                             CallGraphSCC &currentSCC, MLIRContext *context,
493                             const OwningRewritePatternList &canonPatterns) {
  // Collect the set of nodes to canonicalize.
495   SmallVector<CallGraphNode *, 4> nodesToCanonicalize;
496   for (auto *node : currentSCC) {
    // Don't canonicalize the external node; it has no valid callable region.
498     if (node->isExternal())
499       continue;
500 
501     // Don't canonicalize nodes with children. Nodes with children
502     // require special handling as we may remove the node during
503     // canonicalization. In the future, we should be able to handle this
504     // case with proper node deletion tracking.
505     if (node->hasChildren())
506       continue;
507 
    // We also won't apply canonicalizations for nodes that are not
    // isolated. This avoids potentially mutating the regions of nodes defined
    // above; it is also a requirement of the 'applyPatternsAndFoldGreedily'
    // driver.
512     auto *region = node->getCallableRegion();
513     if (!region->getParentOp()->isKnownIsolatedFromAbove())
514       continue;
515     nodesToCanonicalize.push_back(node);
516   }
517   if (nodesToCanonicalize.empty())
518     return;
519 
520   // Canonicalize each of the nodes within the SCC in parallel.
521   // NOTE: This is simple now, because we don't enable canonicalizing nodes
522   // within children. When we remove this restriction, this logic will need to
523   // be reworked.
524   if (context->isMultithreadingEnabled()) {
525     ParallelDiagnosticHandler canonicalizationHandler(context);
526     llvm::parallelForEachN(
527         /*Begin=*/0, /*End=*/nodesToCanonicalize.size(), [&](size_t index) {
528           // Set the order for this thread so that diagnostics will be properly
529           // ordered.
530           canonicalizationHandler.setOrderIDForThread(index);
531 
532           // Apply the canonicalization patterns to this region.
533           auto *node = nodesToCanonicalize[index];
534           applyPatternsAndFoldGreedily(*node->getCallableRegion(),
535                                        canonPatterns);
536 
537           // Make sure to reset the order ID for the diagnostic handler, as this
538           // thread may be used in a different context.
539           canonicalizationHandler.eraseOrderIDForThread();
540         });
541   } else {
542     for (CallGraphNode *node : nodesToCanonicalize)
543       applyPatternsAndFoldGreedily(*node->getCallableRegion(), canonPatterns);
544   }
545 
546   // Recompute the uses held by each of the nodes.
547   for (CallGraphNode *node : nodesToCanonicalize)
548     useList.recomputeUses(node, cg);
549 }
550 
551 //===----------------------------------------------------------------------===//
552 // InlinerPass
553 //===----------------------------------------------------------------------===//
554 
555 namespace {
556 struct InlinerPass : public InlinerBase<InlinerPass> {
557   void runOnOperation() override;
558 
  /// Attempt to inline calls within the given SCC, and run canonicalizations
560   /// with the given patterns, until a fixed point is reached. This allows for
561   /// the inlining of newly devirtualized calls.
562   void inlineSCC(Inliner &inliner, CGUseList &useList, CallGraphSCC &currentSCC,
563                  MLIRContext *context,
564                  const OwningRewritePatternList &canonPatterns);
565 };
566 } // end anonymous namespace
567 
568 void InlinerPass::runOnOperation() {
569   CallGraph &cg = getAnalysis<CallGraph>();
570   auto *context = &getContext();
571 
572   // The inliner should only be run on operations that define a symbol table,
573   // as the callgraph will need to resolve references.
574   Operation *op = getOperation();
575   if (!op->hasTrait<OpTrait::SymbolTable>()) {
576     op->emitOpError() << " was scheduled to run under the inliner, but does "
577                          "not define a symbol table";
578     return signalPassFailure();
579   }
580 
581   // Collect a set of canonicalization patterns to use when simplifying
582   // callable regions within an SCC.
583   OwningRewritePatternList canonPatterns;
584   for (auto *op : context->getRegisteredOperations())
585     op->getCanonicalizationPatterns(canonPatterns, context);
586 
587   // Run the inline transform in post-order over the SCCs in the callgraph.
588   Inliner inliner(context, cg);
589   CGUseList useList(getOperation(), cg);
590   runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
591     inlineSCC(inliner, useList, scc, context, canonPatterns);
592   });
593 
594   // After inlining, make sure to erase any callables proven to be dead.
595   inliner.eraseDeadCallables();
596 }
597 
598 void InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
599                             CallGraphSCC &currentSCC, MLIRContext *context,
600                             const OwningRewritePatternList &canonPatterns) {
601   // If we successfully inlined any calls, run some simplifications on the
602   // nodes of the scc. Continue attempting to inline until we reach a fixed
603   // point, or a maximum iteration count. We canonicalize here as it may
604   // devirtualize new calls, as well as give us a better cost model.
605   unsigned iterationCount = 0;
606   while (succeeded(inlineCallsInSCC(inliner, useList, currentSCC))) {
607     // If we aren't allowing simplifications or the max iteration count was
608     // reached, then bail out early.
609     if (disableCanonicalization || ++iterationCount >= maxInliningIterations)
610       break;
611     canonicalizeSCC(inliner.cg, useList, currentSCC, context, canonPatterns);
612   }
613 }
614 
615 std::unique_ptr<Pass> mlir::createInlinerPass() {
616   return std::make_unique<InlinerPass>();
617 }
618