//===- Inliner.cpp - Pass to inline function calls ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a basic inlining algorithm that operates bottom up over
// the Strongly Connected Components (SCCs) of the CallGraph. This enables a
// more incremental propagation of inlining decisions from the leaves to the
// roots of the callgraph.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Parallel.h"

#define DEBUG_TYPE "inlining"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Symbol Use Tracking
//===----------------------------------------------------------------------===//

/// Walk all of the symbol callgraph nodes referenced by the given operation.
static void walkReferencedSymbolNodes(
    Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
    DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
    function_ref<void(CallGraphNode *, Operation *)> callback) {
  auto symbolUses = SymbolTable::getSymbolUses(op);
  assert(symbolUses && "expected uses to be valid");

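  // Resolve references relative to the parent of `op`; lookups walk up from
  // there to the nearest enclosing symbol table.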
  Operation *symbolTableOp = op->getParentOp();
  for (const SymbolTable::SymbolUse &use : *symbolUses) {
    auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
    CallGraphNode *&node = refIt.first->second;

    // If this is the first instance of this reference, try to resolve a
    // callgraph node for it.
    if (refIt.second) {
      auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
                                                           use.getSymbolRef());
      auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
      if (!callableOp)
        continue;
      node = cg.lookupNode(callableOp.getCallableRegion());
    }
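    // Unresolvable references remain cached as null nodes, so repeated uses of
    // the same reference are skipped cheaply.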
    if (node)
      callback(node, use.getUser());
  }
}

//===----------------------------------------------------------------------===//
// CGUseList
//===----------------------------------------------------------------------===//

namespace {
/// This struct tracks the uses of callgraph nodes that can be dropped when
/// they become use-empty. It directly tracks and manages a use-list for all of
/// the callgraph nodes. This is necessary because many callgraph nodes are
/// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
/// class.
struct CGUseList {
  /// This struct tracks the uses of callgraph nodes within a specific
  /// operation.
  struct CGUser {
    /// Any nodes referenced in the top-level attribute list of this user. We
    /// use a set here because the number of references does not matter.
    DenseSet<CallGraphNode *> topLevelUses;

    /// Uses of nodes referenced by nested operations.
    DenseMap<CallGraphNode *, int> innerUses;
  };

  CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);

  /// Drop uses of nodes referred to by the given call operation that resides
  /// within 'userNode'.
  void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);

  /// Remove the given node from the use list.
  void eraseNode(CallGraphNode *node);

  /// Returns true if the given callgraph node has no uses and can be pruned.
  bool isDead(CallGraphNode *node) const;

  /// Returns true if the given callgraph node has a single use and can be
  /// discarded.
  bool hasOneUseAndDiscardable(CallGraphNode *node) const;

  /// Recompute the uses held by the given callgraph node.
  void recomputeUses(CallGraphNode *node, CallGraph &cg);

  /// Merge the uses of 'lhs' with the uses of 'rhs' after inlining a copy of
  /// 'lhs' into 'rhs'.
  void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);

private:
  /// Decrement the uses of discardable nodes referenced by the given user.
  void decrementDiscardableUses(CGUser &uses);

  /// A mapping between a discardable callgraph node (that is a symbol) and the
  /// number of uses for this node.
  DenseMap<CallGraphNode *, int> discardableSymNodeUses;

  /// A mapping between a callgraph node and the symbol callgraph nodes that it
  /// uses.
  DenseMap<CallGraphNode *, CGUser> nodeUses;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // end anonymous namespace

CGUseList::CGUseList(Operation *op, CallGraph &cg,
                     SymbolTableCollection &symbolTable)
    : symbolTable(symbolTable) {
  // A set of callgraph nodes that are always known to be live during inlining.
  DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;

  // Walk each of the symbol tables looking for discardable callgraph nodes.
  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
    for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
      // If this is a callgraph operation, check to see if it is discardable.
      if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
        if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
          SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
          if (symbol && (allUsesVisible || symbol.isPrivate()) &&
              symbol.canDiscardOnUseEmpty()) {
            discardableSymNodeUses.try_emplace(node, 0);
          }
          continue;
        }
      }
      // Otherwise, check for any referenced nodes. These will be always-live.
      walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
                                [](CallGraphNode *, Operation *) {});
    }
  };
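  // An operation with no parent block is the top-level (root) operation, in
  // which case all of its symbol uses are known to be visible.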
  SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
                                walkFn);

  // Drop the use information for any discardable nodes that are always live.
  for (auto &it : alwaysLiveNodes)
    discardableSymNodeUses.erase(it.second);

  // Compute the uses for each of the callable nodes in the graph.
  for (CallGraphNode *node : cg)
    recomputeUses(node, cg);
}

void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
                             CallGraph &cg) {
  auto &userRefs = nodeUses[userNode].innerUses;
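  // For each symbol referenced by the call, decrement both the user's
  // per-node count and the global discardable-use count.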
  auto walkFn = [&](CallGraphNode *node, Operation *user) {
    auto parentIt = userRefs.find(node);
    if (parentIt == userRefs.end())
      return;
    --parentIt->second;
    --discardableSymNodeUses[node];
  };
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
}

void CGUseList::eraseNode(CallGraphNode *node) {
  // Drop all child nodes.
  for (auto &edge : *node)
    if (edge.isChild())
      eraseNode(edge.getTarget());

  // Drop the uses held by this node and erase it.
  auto useIt = nodeUses.find(node);
  assert(useIt != nodeUses.end() && "expected node to be valid");
  decrementDiscardableUses(useIt->getSecond());
  nodeUses.erase(useIt);
  discardableSymNodeUses.erase(node);
}

bool CGUseList::isDead(CallGraphNode *node) const {
  // If the parent operation isn't a symbol, simply check normal SSA deadness.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!isa<SymbolOpInterface>(nodeOp))
    return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->use_empty();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
}

bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
  // If this isn't a symbol node, check for side-effects and SSA use count.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!isa<SymbolOpInterface>(nodeOp))
    return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->hasOneUse();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
}

void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
  Operation *parentOp = node->getCallableRegion()->getParentOp();
  CGUser &uses = nodeUses[node];
  decrementDiscardableUses(uses);

  // Collect the new discardable uses within this node.
  uses = CGUser();
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
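  // Top-level references, i.e. those on the callable operation itself, are
  // counted once; references from nested operations are counted per use so
  // that dropCallUses can decrement them individually.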
  auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
    auto discardSymIt = discardableSymNodeUses.find(refNode);
    if (discardSymIt == discardableSymNodeUses.end())
      return;

    if (user != parentOp)
      ++uses.innerUses[refNode];
    else if (!uses.topLevelUses.insert(refNode).second)
      return;
    ++discardSymIt->second;
  };
  walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
}

void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
  auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
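  // Only the inner uses are merged; the top-level uses of 'lhs' are attached
  // to the callable operation itself, which is not cloned into 'rhs'.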
  for (auto &useIt : lhsUses.innerUses) {
    rhsUses.innerUses[useIt.first] += useIt.second;
    discardableSymNodeUses[useIt.first] += useIt.second;
  }
}

void CGUseList::decrementDiscardableUses(CGUser &uses) {
  for (CallGraphNode *node : uses.topLevelUses)
    --discardableSymNodeUses[node];
  for (auto &it : uses.innerUses)
    discardableSymNodeUses[it.first] -= it.second;
}

//===----------------------------------------------------------------------===//
// CallGraph traversal
//===----------------------------------------------------------------------===//

namespace {
/// This class represents a specific callgraph SCC.
class CallGraphSCC {
public:
  CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
      : parentIterator(parentIterator) {}
  /// Return a range over the nodes within this SCC.
  std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
  std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }

  /// Reset the nodes of this SCC with those provided.
  void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }

  /// Remove the given node from this SCC.
  void remove(CallGraphNode *node) {
    auto it = llvm::find(nodes, node);
    if (it != nodes.end()) {
      nodes.erase(it);
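      // Inform the parent SCC iterator that the node is gone so that it is
      // not revisited in later SCCs.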
      parentIterator.ReplaceNode(node, nullptr);
    }
  }

private:
  std::vector<CallGraphNode *> nodes;
  llvm::scc_iterator<const CallGraph *> &parentIterator;
};
} // end anonymous namespace

/// Run a given transformation over the SCCs of the callgraph in a bottom-up
/// traversal.
static void
runTransformOnCGSCCs(const CallGraph &cg,
                     function_ref<void(CallGraphSCC &)> sccTransformer) {
  llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
  CallGraphSCC currentSCC(cgi);
  while (!cgi.isAtEnd()) {
    // Copy the current SCC and increment so that the transformer can modify
    // the SCC without invalidating our iterator.
    currentSCC.reset(*cgi);
    ++cgi;
    sccTransformer(currentSCC);
  }
}

namespace {
/// This struct represents a resolved call to a given callgraph node. Given
/// that the call does not actually contain a direct reference to the region
/// (CallGraphNode) that it is dispatching to, we need to resolve it
/// explicitly.
struct ResolvedCall {
  ResolvedCall(CallOpInterface call, CallGraphNode *sourceNode,
               CallGraphNode *targetNode)
      : call(call), sourceNode(sourceNode), targetNode(targetNode) {}
  CallOpInterface call;
  CallGraphNode *sourceNode, *targetNode;
};
} // end anonymous namespace

/// Collect all of the call operations within the given range of blocks. If
/// `traverseNestedCGNodes` is true, this will also collect call operations
/// inside of nested callgraph nodes.
static void collectCallOps(iterator_range<Region::iterator> blocks,
                           CallGraphNode *sourceNode, CallGraph &cg,
                           SymbolTableCollection &symbolTable,
                           SmallVectorImpl<ResolvedCall> &calls,
                           bool traverseNestedCGNodes) {
  SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
  auto addToWorklist = [&](CallGraphNode *node,
                           iterator_range<Region::iterator> blocks) {
    for (Block &block : blocks)
      worklist.emplace_back(&block, node);
  };

  addToWorklist(sourceNode, blocks);
  while (!worklist.empty()) {
    Block *block;
    std::tie(block, sourceNode) = worklist.pop_back_val();

    for (Operation &op : *block) {
      if (auto call = dyn_cast<CallOpInterface>(op)) {
        // TODO: Support inlining nested call references.
        CallInterfaceCallable callable = call.getCallableForCallee();
        if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
          if (!symRef.isa<FlatSymbolRefAttr>())
            continue;
        }

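        // Resolve the callee; calls whose target is external (no callable
        // region available) cannot be inlined.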
        CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
        if (!targetNode->isExternal())
          calls.emplace_back(call, sourceNode, targetNode);
        continue;
      }

      // If this is not a call, traverse the nested regions. If
      // `traverseNestedCGNodes` is false, regions that form callgraph nodes of
      // their own are skipped.
      for (auto &nestedRegion : op.getRegions()) {
        CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
        if (traverseNestedCGNodes || !nestedNode)
          addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// Inliner
//===----------------------------------------------------------------------===//
namespace {
/// This class provides a specialization of the main inlining interface.
struct Inliner : public InlinerInterface {
  Inliner(MLIRContext *context, CallGraph &cg,
          SymbolTableCollection &symbolTable)
      : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}

  /// Process a set of blocks that have been inlined. This callback is invoked
  /// *before* inlined terminator operations have been processed.
  void
  processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
    // Find the closest callgraph node from the first block.
    CallGraphNode *node;
    Region *region = inlinedBlocks.begin()->getParent();
    while (!(node = cg.lookupNode(region))) {
      region = region->getParentRegion();
      assert(region && "expected valid parent node");
    }

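    // Collect any new calls, including those within nested callgraph nodes,
    // so that they are also considered for inlining.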
    collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
                   /*traverseNestedCGNodes=*/true);
  }

  /// Mark the given callgraph node for deletion.
  void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }

  /// This method properly disposes of callables that became dead during
  /// inlining. This should not be called while iterating over the SCCs.
  void eraseDeadCallables() {
    for (CallGraphNode *node : deadNodes)
      node->getCallableRegion()->getParentOp()->erase();
  }

  /// The set of callables known to be dead.
  SmallPtrSet<CallGraphNode *, 8> deadNodes;

  /// The current set of call operations to consider for inlining.
  SmallVector<ResolvedCall, 8> calls;

  /// The callgraph being operated on.
  CallGraph &cg;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // end anonymous namespace

/// Returns true if the given call should be inlined.
static bool shouldInline(ResolvedCall &resolvedCall) {
  // Don't allow inlining terminator calls. We currently don't support this
  // case.
  if (resolvedCall.call.getOperation()->isKnownTerminator())
    return false;

  // Don't allow inlining if the target is an ancestor of the call. This
  // prevents the inlining of directly recursive calls.
  if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
          resolvedCall.call.getParentRegion()))
    return false;

  // Otherwise, inline.
  return true;
}

/// Attempt to inline calls within the given SCC. This function returns
/// success if any calls were inlined, failure otherwise.
static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
                                      CallGraphSCC &currentSCC) {
  CallGraph &cg = inliner.cg;
  auto &calls = inliner.calls;

  // A set of dead nodes to remove after inlining.
  SmallVector<CallGraphNode *, 1> deadNodes;

  // Collect all of the direct calls within the nodes of the current SCC. We
  // don't traverse nested callgraph nodes, because they are handled
  // separately, likely within a different SCC.
  for (CallGraphNode *node : currentSCC) {
    if (node->isExternal())
      continue;

    // Don't collect calls if the node is already dead.
    if (useList.isDead(node)) {
      deadNodes.push_back(node);
    } else {
      collectCallOps(*node->getCallableRegion(), node, cg, inliner.symbolTable,
                     calls, /*traverseNestedCGNodes=*/false);
    }
  }

  // Try to inline each of the call operations. Don't cache the end iterator
  // here as more calls may be added during inlining.
  bool inlinedAnyCalls = false;
  for (unsigned i = 0; i != calls.size(); ++i) {
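    // Copy the resolved call; the `calls` vector may be reallocated as new
    // calls are appended during inlining.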
    ResolvedCall it = calls[i];
    bool doInline = shouldInline(it);
    CallOpInterface call = it.call;
    LLVM_DEBUG({
      if (doInline)
        llvm::dbgs() << "* Inlining call: " << call << "\n";
      else
        llvm::dbgs() << "* Not inlining call: " << call << "\n";
    });
    if (!doInline)
      continue;
    Region *targetRegion = it.targetNode->getCallableRegion();

    // If this is the last call to the target node and the node is discardable,
    // then inline it in-place and delete the node if successful.
    bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);

    LogicalResult inlineResult = inlineCall(
        inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
        targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
    if (failed(inlineResult)) {
      LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
      continue;
    }
    inlinedAnyCalls = true;

    // Inlining was successful: merge the new uses into the source node, then
    // erase the call.
    useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
    useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);
    call.erase();

    // If we inlined in place, mark the node for deletion.
    if (inlineInPlace) {
      useList.eraseNode(it.targetNode);
      deadNodes.push_back(it.targetNode);
    }
  }

  for (CallGraphNode *node : deadNodes) {
    currentSCC.remove(node);
    inliner.markForDeletion(node);
  }
  calls.clear();
  return success(inlinedAnyCalls);
}

/// Canonicalize the nodes within the given SCC with the given set of
/// canonicalization patterns.
static void canonicalizeSCC(CallGraph &cg, CGUseList &useList,
                            CallGraphSCC &currentSCC, MLIRContext *context,
                            const OwningRewritePatternList &canonPatterns) {
  // Collect the set of nodes to canonicalize.
  SmallVector<CallGraphNode *, 4> nodesToCanonicalize;
  for (auto *node : currentSCC) {
    // Don't canonicalize the external node; it has no valid callable region.
    if (node->isExternal())
      continue;

    // Don't canonicalize nodes with children, as they require special handling
    // given that we may remove the node during canonicalization. In the
    // future, we should be able to handle this case with proper node deletion
    // tracking.
    if (node->hasChildren())
      continue;

    // We also won't apply canonicalizations to nodes that are not isolated
    // from above. This avoids potentially mutating the regions of nodes
    // defined above; it is also a requirement of the
    // 'applyPatternsAndFoldGreedily' driver.
    auto *region = node->getCallableRegion();
    if (!region->getParentOp()->isKnownIsolatedFromAbove())
      continue;
    nodesToCanonicalize.push_back(node);
  }
  if (nodesToCanonicalize.empty())
    return;

  // Canonicalize each of the nodes within the SCC in parallel.
  // NOTE: This is simple now, because we don't enable canonicalizing nodes
  // within children. When we remove this restriction, this logic will need to
  // be reworked.
  if (context->isMultithreadingEnabled()) {
    ParallelDiagnosticHandler canonicalizationHandler(context);
    llvm::parallelForEachN(
        /*Begin=*/0, /*End=*/nodesToCanonicalize.size(), [&](size_t index) {
          // Set the order for this thread so that diagnostics will be properly
          // ordered.
          canonicalizationHandler.setOrderIDForThread(index);

          // Apply the canonicalization patterns to this region.
          auto *node = nodesToCanonicalize[index];
          applyPatternsAndFoldGreedily(*node->getCallableRegion(),
                                       canonPatterns);

          // Make sure to reset the order ID for the diagnostic handler, as
          // this thread may be used in a different context.
          canonicalizationHandler.eraseOrderIDForThread();
        });
  } else {
    for (CallGraphNode *node : nodesToCanonicalize)
      applyPatternsAndFoldGreedily(*node->getCallableRegion(), canonPatterns);
  }

  // Recompute the uses held by each of the nodes.
  for (CallGraphNode *node : nodesToCanonicalize)
    useList.recomputeUses(node, cg);
}

//===----------------------------------------------------------------------===//
// InlinerPass
//===----------------------------------------------------------------------===//

namespace {
struct InlinerPass : public InlinerBase<InlinerPass> {
  void runOnOperation() override;

  /// Attempt to inline calls within the given SCC, and run canonicalizations
  /// with the given patterns, until a fixed point is reached. This allows for
  /// the inlining of newly devirtualized calls.
  void inlineSCC(Inliner &inliner, CGUseList &useList, CallGraphSCC &currentSCC,
                 MLIRContext *context,
                 const OwningRewritePatternList &canonPatterns);
};
} // end anonymous namespace

void InlinerPass::runOnOperation() {
  CallGraph &cg = getAnalysis<CallGraph>();
  auto *context = &getContext();

  // The inliner should only be run on operations that define a symbol table,
  // as the callgraph will need to resolve references.
  Operation *op = getOperation();
  if (!op->hasTrait<OpTrait::SymbolTable>()) {
    op->emitOpError() << " was scheduled to run under the inliner, but does "
                         "not define a symbol table";
    return signalPassFailure();
  }

  // Collect a set of canonicalization patterns to use when simplifying
  // callable regions within an SCC.
  OwningRewritePatternList canonPatterns;
  for (auto *op : context->getRegisteredOperations())
    op->getCanonicalizationPatterns(canonPatterns, context);

  // Run the inline transform in post-order over the SCCs in the callgraph.
  SymbolTableCollection symbolTable;
  Inliner inliner(context, cg, symbolTable);
  CGUseList useList(getOperation(), cg, symbolTable);
  runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
    inlineSCC(inliner, useList, scc, context, canonPatterns);
  });

  // After inlining, make sure to erase any callables proven to be dead.
  inliner.eraseDeadCallables();
}

void InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
                            CallGraphSCC &currentSCC, MLIRContext *context,
                            const OwningRewritePatternList &canonPatterns) {
  // If we successfully inlined any calls, run some simplifications on the
  // nodes of the SCC. Continue attempting to inline until we reach a fixed
  // point, or a maximum iteration count. We canonicalize here as it may
  // devirtualize new calls, as well as give us a better cost model.
  unsigned iterationCount = 0;
  while (succeeded(inlineCallsInSCC(inliner, useList, currentSCC))) {
    // If we aren't allowing simplifications or the max iteration count was
    // reached, then bail out early.
    if (disableCanonicalization || ++iterationCount >= maxInliningIterations)
      break;
    canonicalizeSCC(inliner.cg, useList, currentSCC, context, canonPatterns);
  }
}

std::unique_ptr<Pass> mlir::createInlinerPass() {
  return std::make_unique<InlinerPass>();
}