//===- Inliner.cpp - Pass to inline function calls ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a basic inlining algorithm that operates bottom up over
// the Strongly Connected Components (SCCs) of the CallGraph. This enables a
// more incremental propagation of inlining decisions from the leaves to the
// roots of the callgraph.
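//
// For example (illustrative): if `main` calls `mid` and `mid` calls `leaf`,
// the SCC containing `leaf` is visited first, so `leaf` is considered for
// inlining into `mid` before `mid` is itself considered for inlining into
// `main`; each caller thus sees the already-simplified form of its callees.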
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/Threading.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/DebugStringHelper.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "inlining"

using namespace mlir;

/// This function implements the default inliner optimization pipeline.
static void defaultInlinerOptPipeline(OpPassManager &pm) {
  pm.addPass(createCanonicalizerPass());
}

//===----------------------------------------------------------------------===//
// Symbol Use Tracking
//===----------------------------------------------------------------------===//

/// Walk all of the used symbol callgraph nodes referenced by the given op.
static void walkReferencedSymbolNodes(
    Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
    DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
    function_ref<void(CallGraphNode *, Operation *)> callback) {
  auto symbolUses = SymbolTable::getSymbolUses(op);
  assert(symbolUses && "expected uses to be valid");

  Operation *symbolTableOp = op->getParentOp();
  for (const SymbolTable::SymbolUse &use : *symbolUses) {
    auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
    CallGraphNode *&node = refIt.first->second;

    // If this is the first instance of this reference, try to resolve a
    // callgraph node for it.
    if (refIt.second) {
      auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
                                                           use.getSymbolRef());
      auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
      if (!callableOp)
        continue;
      node = cg.lookupNode(callableOp.getCallableRegion());
    }
    if (node)
      callback(node, use.getUser());
  }
}

//===----------------------------------------------------------------------===//
// CGUseList
//===----------------------------------------------------------------------===//

namespace {
/// This struct tracks the uses of callgraph nodes that can be dropped when
/// use_empty. It directly tracks and manages a use-list for all of the
/// call-graph nodes. This is necessary because many callgraph nodes are
/// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
/// class.
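///
/// For example (illustrative): a private `func.func @helper` is referenced
/// from call sites only via SymbolRefAttr, so it has no SSA use-list. This
/// struct counts those symbol references instead; once the final call to
/// @helper is inlined, its count drops to zero and the callable is discarded.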
struct CGUseList {
  /// This struct tracks the uses of callgraph nodes within a specific
  /// operation.
  struct CGUser {
    /// Any nodes referenced in the top-level attribute list of this user. We
    /// use a set here because the number of references does not matter.
    DenseSet<CallGraphNode *> topLevelUses;

    /// Uses of nodes referenced by nested operations.
    DenseMap<CallGraphNode *, int> innerUses;
  };

  CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);

  /// Drop uses of nodes referred to by the given call operation that resides
  /// within 'userNode'.
  void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);

  /// Remove the given node from the use list.
  void eraseNode(CallGraphNode *node);

  /// Returns true if the given callgraph node has no uses and can be pruned.
  bool isDead(CallGraphNode *node) const;

  /// Returns true if the given callgraph node has a single use and can be
  /// discarded.
  bool hasOneUseAndDiscardable(CallGraphNode *node) const;

  /// Recompute the uses held by the given callgraph node.
  void recomputeUses(CallGraphNode *node, CallGraph &cg);

  /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy
  /// of 'lhs' into 'rhs'.
  void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);

private:
  /// Decrement the uses of discardable nodes referenced by the given user.
  void decrementDiscardableUses(CGUser &uses);

  /// A mapping between a discardable callgraph node (that is a symbol) and the
  /// number of uses for this node.
  DenseMap<CallGraphNode *, int> discardableSymNodeUses;

  /// A mapping between a callgraph node and the symbol callgraph nodes that it
  /// uses.
  DenseMap<CallGraphNode *, CGUser> nodeUses;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // namespace

CGUseList::CGUseList(Operation *op, CallGraph &cg,
                     SymbolTableCollection &symbolTable)
    : symbolTable(symbolTable) {
  /// A set of callgraph nodes that are always known to be live during inlining.
  DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;

  // Walk each of the symbol tables looking for discardable callgraph nodes.
  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
    for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
      // If this is a callgraph operation, check to see if it is discardable.
      if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
        if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
          SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
          if (symbol && (allUsesVisible || symbol.isPrivate()) &&
              symbol.canDiscardOnUseEmpty()) {
            discardableSymNodeUses.try_emplace(node, 0);
          }
          continue;
        }
      }
      // Otherwise, check for any referenced nodes. These will be always-live.
      walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
                                [](CallGraphNode *, Operation *) {});
    }
  };
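  // Note: all symbol uses are visible when the operation being processed has
  // no parent block, i.e. it is the outermost operation; any use of a symbol
  // it contains must then be nested within it.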
  SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
                                walkFn);

  // Drop the use information for any discardable nodes that are always live.
  for (auto &it : alwaysLiveNodes)
    discardableSymNodeUses.erase(it.second);

  // Compute the uses for each of the callable nodes in the graph.
  for (CallGraphNode *node : cg)
    recomputeUses(node, cg);
}

void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
                             CallGraph &cg) {
  auto &userRefs = nodeUses[userNode].innerUses;
  auto walkFn = [&](CallGraphNode *node, Operation *user) {
    auto parentIt = userRefs.find(node);
    if (parentIt == userRefs.end())
      return;
    --parentIt->second;
    --discardableSymNodeUses[node];
  };
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
}

void CGUseList::eraseNode(CallGraphNode *node) {
  // Drop all child nodes.
  for (auto &edge : *node)
    if (edge.isChild())
      eraseNode(edge.getTarget());

  // Drop the uses held by this node and erase it.
  auto useIt = nodeUses.find(node);
  assert(useIt != nodeUses.end() && "expected node to be valid");
  decrementDiscardableUses(useIt->getSecond());
  nodeUses.erase(useIt);
  discardableSymNodeUses.erase(node);
}

bool CGUseList::isDead(CallGraphNode *node) const {
  // If the parent operation isn't a symbol, simply check normal SSA deadness.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!isa<SymbolOpInterface>(nodeOp))
    return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->use_empty();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
}

bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
  // If this isn't a symbol node, check for side-effects and SSA use count.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!isa<SymbolOpInterface>(nodeOp))
    return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->hasOneUse();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
}

void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
  Operation *parentOp = node->getCallableRegion()->getParentOp();
  CGUser &uses = nodeUses[node];
  decrementDiscardableUses(uses);

  // Collect the new discardable uses within this node.
  uses = CGUser();
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
    auto discardSymIt = discardableSymNodeUses.find(refNode);
    if (discardSymIt == discardableSymNodeUses.end())
      return;

    if (user != parentOp)
      ++uses.innerUses[refNode];
    else if (!uses.topLevelUses.insert(refNode).second)
      return;
    ++discardSymIt->second;
  };
  walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
}

void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
  auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
  for (auto &useIt : lhsUses.innerUses) {
    rhsUses.innerUses[useIt.first] += useIt.second;
    discardableSymNodeUses[useIt.first] += useIt.second;
  }
}

void CGUseList::decrementDiscardableUses(CGUser &uses) {
  for (CallGraphNode *node : uses.topLevelUses)
    --discardableSymNodeUses[node];
  for (auto &it : uses.innerUses)
    discardableSymNodeUses[it.first] -= it.second;
}

//===----------------------------------------------------------------------===//
// CallGraph traversal
//===----------------------------------------------------------------------===//

namespace {
/// This class represents a specific callgraph SCC.
class CallGraphSCC {
public:
  CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
      : parentIterator(parentIterator) {}
  /// Return a range over the nodes within this SCC.
  std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
  std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }

  /// Reset the nodes of this SCC with those provided.
  void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }

  /// Remove the given node from this SCC.
  void remove(CallGraphNode *node) {
    auto it = llvm::find(nodes, node);
    if (it != nodes.end()) {
      nodes.erase(it);
      parentIterator.ReplaceNode(node, nullptr);
    }
  }

private:
  std::vector<CallGraphNode *> nodes;
  llvm::scc_iterator<const CallGraph *> &parentIterator;
};
} // namespace

/// Run a given transformation over the SCCs of the callgraph in a bottom up
/// traversal.
static LogicalResult runTransformOnCGSCCs(
    const CallGraph &cg,
    function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
  llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
  CallGraphSCC currentSCC(cgi);
  while (!cgi.isAtEnd()) {
    // Copy the current SCC and increment so that the transformer can modify
    // the SCC without invalidating our iterator.
    currentSCC.reset(*cgi);
    ++cgi;
    if (failed(sccTransformer(currentSCC)))
      return failure();
  }
  return success();
}

namespace {
/// This struct represents a resolved call to a given callgraph node. Given
/// that the call does not actually contain a direct reference to the
/// Region (CallGraphNode) that it is dispatching to, we need to resolve them
/// explicitly.
struct ResolvedCall {
  ResolvedCall(CallOpInterface call, CallGraphNode *sourceNode,
               CallGraphNode *targetNode)
      : call(call), sourceNode(sourceNode), targetNode(targetNode) {}
  CallOpInterface call;
  CallGraphNode *sourceNode, *targetNode;
};
} // namespace

/// Collect all of the call operations within the given range of blocks. If
/// `traverseNestedCGNodes` is true, this will also collect call operations
/// inside of nested callgraph nodes.
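///
/// Illustrative example: calls inside an `scf.if` nested within the source
/// node are always collected, whereas calls inside a nested callable that
/// forms its own callgraph node are collected only if `traverseNestedCGNodes`
/// is true.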
static void collectCallOps(iterator_range<Region::iterator> blocks,
                           CallGraphNode *sourceNode, CallGraph &cg,
                           SymbolTableCollection &symbolTable,
                           SmallVectorImpl<ResolvedCall> &calls,
                           bool traverseNestedCGNodes) {
  SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
  auto addToWorklist = [&](CallGraphNode *node,
                           iterator_range<Region::iterator> blocks) {
    for (Block &block : blocks)
      worklist.emplace_back(&block, node);
  };

  addToWorklist(sourceNode, blocks);
  while (!worklist.empty()) {
    Block *block;
    std::tie(block, sourceNode) = worklist.pop_back_val();

    for (Operation &op : *block) {
      if (auto call = dyn_cast<CallOpInterface>(op)) {
        // TODO: Support inlining nested call references.
        CallInterfaceCallable callable = call.getCallableForCallee();
        if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
          if (!symRef.isa<FlatSymbolRefAttr>())
            continue;
        }

        CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
        if (!targetNode->isExternal())
          calls.emplace_back(call, sourceNode, targetNode);
        continue;
      }

      // If this is not a call, traverse the nested regions. If
      // `traverseNestedCGNodes` is false, then don't traverse nested call graph
      // regions.
      for (auto &nestedRegion : op.getRegions()) {
        CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
        if (traverseNestedCGNodes || !nestedNode)
          addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// Inliner
//===----------------------------------------------------------------------===//

#ifndef NDEBUG
static std::string getNodeName(CallOpInterface op) {
  if (auto sym = op.getCallableForCallee().dyn_cast<SymbolRefAttr>())
    return debugString(op);
  return "_unnamed_callee_";
}
#endif

/// Return true if the specified `inlineHistoryID` indicates an inline history
/// that already includes `node`.
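///
/// Illustrative example: if entry 0 records that a call to B was inlined (no
/// parent history), and entry 1 records that a call to C exposed by that step
/// was inlined (parent 0), then a new call to B produced by these steps
/// carries history ID 1; walking the chain 1 -> 0 finds B and returns true,
/// preventing endless recursive re-inlining of B.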
static bool inlineHistoryIncludes(
    CallGraphNode *node, Optional<size_t> inlineHistoryID,
    MutableArrayRef<std::pair<CallGraphNode *, Optional<size_t>>>
        inlineHistory) {
  while (inlineHistoryID.has_value()) {
    assert(inlineHistoryID.value() < inlineHistory.size() &&
           "Invalid inline history ID");
    if (inlineHistory[inlineHistoryID.value()].first == node)
      return true;
    inlineHistoryID = inlineHistory[inlineHistoryID.value()].second;
  }
  return false;
}

namespace {
/// This class provides a specialization of the main inlining interface.
struct Inliner : public InlinerInterface {
  Inliner(MLIRContext *context, CallGraph &cg,
          SymbolTableCollection &symbolTable)
      : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}

  /// Process a set of blocks that have been inlined. This callback is invoked
  /// *before* inlined terminator operations have been processed.
  void
  processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
    // Find the closest callgraph node from the first block.
    CallGraphNode *node;
    Region *region = inlinedBlocks.begin()->getParent();
    while (!(node = cg.lookupNode(region))) {
      region = region->getParentRegion();
      assert(region && "expected valid parent node");
    }

    collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
                   /*traverseNestedCGNodes=*/true);
  }

  /// Mark the given callgraph node for deletion.
  void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }

  /// This method properly disposes of callables that became dead during
  /// inlining. This should not be called while iterating over the SCCs.
  void eraseDeadCallables() {
    for (CallGraphNode *node : deadNodes)
      node->getCallableRegion()->getParentOp()->erase();
  }

  /// The set of callables known to be dead.
  SmallPtrSet<CallGraphNode *, 8> deadNodes;

  /// The current set of call instructions to consider for inlining.
  SmallVector<ResolvedCall, 8> calls;

  /// The callgraph being operated on.
  CallGraph &cg;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // namespace

/// Returns true if the given call should be inlined.
static bool shouldInline(ResolvedCall &resolvedCall) {
  // Don't allow inlining terminator calls. We currently don't support this
  // case.
  if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>())
    return false;

  // Don't allow inlining if the target is an ancestor of the call. This
  // prevents inlining recursively.
  if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
          resolvedCall.call->getParentRegion()))
    return false;

  // Otherwise, inline.
  return true;
}

/// Attempt to inline calls within the given scc. This function returns
/// success if any calls were inlined, failure otherwise.
static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
                                      CallGraphSCC &currentSCC) {
  CallGraph &cg = inliner.cg;
  auto &calls = inliner.calls;

  // A set of dead nodes to remove after inlining.
  llvm::SmallSetVector<CallGraphNode *, 1> deadNodes;

  // Collect all of the direct calls within the nodes of the current SCC. We
  // don't traverse nested callgraph nodes, because they are handled
  // separately, likely within a different SCC.
  for (CallGraphNode *node : currentSCC) {
    if (node->isExternal())
      continue;

    // Don't collect calls if the node is already dead.
    if (useList.isDead(node)) {
      deadNodes.insert(node);
    } else {
      collectCallOps(*node->getCallableRegion(), node, cg, inliner.symbolTable,
                     calls, /*traverseNestedCGNodes=*/false);
    }
  }

  // When inlining a callee produces new call sites, we want to keep track of
  // the fact that they were inlined from the callee. This allows us to avoid
  // infinite inlining.
  using InlineHistoryT = Optional<size_t>;
  SmallVector<std::pair<CallGraphNode *, InlineHistoryT>, 8> inlineHistory;
  std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});

  LLVM_DEBUG({
    llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
    for (unsigned i = 0, e = calls.size(); i < e; ++i)
      llvm::dbgs() << "  " << i << ". " << calls[i].call << ",\n";
    llvm::dbgs() << "}\n";
  });

  // Try to inline each of the call operations. Don't cache the end iterator
  // here as more calls may be added during inlining.
  bool inlinedAnyCalls = false;
  for (unsigned i = 0; i < calls.size(); ++i) {
    if (deadNodes.contains(calls[i].sourceNode))
      continue;
    ResolvedCall it = calls[i];

    InlineHistoryT inlineHistoryID = callHistory[i];
    bool inHistory =
        inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory);
    bool doInline = !inHistory && shouldInline(it);
    CallOpInterface call = it.call;
    LLVM_DEBUG({
      if (doInline)
        llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
      else
        llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
    });
    if (!doInline)
      continue;

    unsigned prevSize = calls.size();
    Region *targetRegion = it.targetNode->getCallableRegion();

    // If this is the last call to the target node and the node is discardable,
    // then inline it in-place and delete the node if successful.
    bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);

    LogicalResult inlineResult = inlineCall(
        inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
        targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
    if (failed(inlineResult)) {
      LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
      continue;
    }
    inlinedAnyCalls = true;

    // Create an inline history entry for this inlined call, so that we
    // remember that new callsites came about due to inlining the callee.
    InlineHistoryT newInlineHistoryID{inlineHistory.size()};
    inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID));

    auto historyToString = [](InlineHistoryT h) {
      return h.has_value() ? std::to_string(h.value()) : "root";
    };
    (void)historyToString;
    LLVM_DEBUG(llvm::dbgs()
               << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
               << getNodeName(call) << ", " << historyToString(inlineHistoryID)
               << "]\n");

    for (unsigned k = prevSize; k != calls.size(); ++k) {
      callHistory.push_back(newInlineHistoryID);
      LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[k].call
                              << "}\n   with historyID = " << newInlineHistoryID
                              << ", added due to inlining of\n  call {" << call
                              << "}\n with historyID = "
                              << historyToString(inlineHistoryID) << "\n");
    }

    // If the inlining was successful, merge the new uses into the source node.
    useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
    useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);

    // Then erase the call.
    call.erase();

    // If we inlined in place, mark the node for deletion.
    if (inlineInPlace) {
      useList.eraseNode(it.targetNode);
      deadNodes.insert(it.targetNode);
    }
  }

  for (CallGraphNode *node : deadNodes) {
    currentSCC.remove(node);
    inliner.markForDeletion(node);
  }
  calls.clear();
  return success(inlinedAnyCalls);
}

//===----------------------------------------------------------------------===//
// InlinerPass
//===----------------------------------------------------------------------===//

namespace {
class InlinerPass : public InlinerBase<InlinerPass> {
public:
  InlinerPass();
  InlinerPass(const InlinerPass &) = default;
  InlinerPass(std::function<void(OpPassManager &)> defaultPipeline);
  InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
              llvm::StringMap<OpPassManager> opPipelines);
  void runOnOperation() override;

private:
  /// Attempt to inline calls within the given scc, and run simplifications,
  /// until a fixed point is reached. This allows for the inlining of newly
  /// devirtualized calls. Returns failure if there was a fatal error during
  /// inlining.
  LogicalResult inlineSCC(Inliner &inliner, CGUseList &useList,
                          CallGraphSCC &currentSCC, MLIRContext *context);

  /// Optimize the nodes within the given SCC with one of the held optimization
  /// pass pipelines. Returns failure if an error occurred during the
  /// optimization of the SCC, success otherwise.
  LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
                            CallGraphSCC &currentSCC, MLIRContext *context);

  /// Optimize the nodes within the given SCC in parallel. Returns failure if
  /// an error occurred during the optimization of the SCC, success otherwise.
  LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
                                 MLIRContext *context);

  /// Optimize the given callable node with one of the pass managers provided
  /// with `pipelines`, or the default pipeline. Returns failure if an error
  /// occurred during the optimization of the callable, success otherwise.
  LogicalResult optimizeCallable(CallGraphNode *node,
                                 llvm::StringMap<OpPassManager> &pipelines);

  /// Attempt to initialize the options of this pass from the given string.
  /// Derived classes may override this method to hook into the point at which
  /// options are initialized, but should generally always invoke this base
  /// class variant.
  LogicalResult initializeOptions(StringRef options) override;

  /// An optional function that constructs a default optimization pipeline for
  /// a given operation.
  std::function<void(OpPassManager &)> defaultPipeline;
  /// A map of operation names to pass pipelines to use when optimizing
  /// callable operations of these types. This provides a specialized pipeline
  /// instead of the default. The vector size is the number of threads used
  /// during optimization.
  SmallVector<llvm::StringMap<OpPassManager>, 8> opPipelines;
};
} // namespace

InlinerPass::InlinerPass() : InlinerPass(defaultInlinerOptPipeline) {}
InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline)
    : defaultPipeline(std::move(defaultPipeline)) {
  opPipelines.push_back({});

  // Initialize the pass options with the provided arguments. Note that the
  // constructor parameter was moved from in the initializer list above, so we
  // must query the member variable here.
  if (this->defaultPipeline) {
    OpPassManager fakePM("__mlir_fake_pm_op");
    this->defaultPipeline(fakePM);
    llvm::raw_string_ostream strStream(defaultPipelineStr);
    fakePM.printAsTextualPipeline(strStream);
  }
}

InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
                         llvm::StringMap<OpPassManager> opPipelines)
    : InlinerPass(std::move(defaultPipeline)) {
  if (opPipelines.empty())
    return;

  // Update the option for the op specific optimization pipelines.
  for (auto &it : opPipelines)
    opPipelineList.addValue(it.second);
  this->opPipelines.emplace_back(std::move(opPipelines));
}

void InlinerPass::runOnOperation() {
  CallGraph &cg = getAnalysis<CallGraph>();
  auto *context = &getContext();

  // The inliner should only be run on operations that define a symbol table,
  // as the callgraph will need to resolve references.
  Operation *op = getOperation();
  if (!op->hasTrait<OpTrait::SymbolTable>()) {
    op->emitOpError() << " was scheduled to run under the inliner, but does "
                         "not define a symbol table";
    return signalPassFailure();
  }

  // Run the inline transform in post-order over the SCCs in the callgraph.
  SymbolTableCollection symbolTable;
  Inliner inliner(context, cg, symbolTable);
  CGUseList useList(getOperation(), cg, symbolTable);
  LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
    return inlineSCC(inliner, useList, scc, context);
  });
  if (failed(result))
    return signalPassFailure();

  // After inlining, make sure to erase any callables proven to be dead.
  inliner.eraseDeadCallables();
}

LogicalResult InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
                                     CallGraphSCC &currentSCC,
                                     MLIRContext *context) {
  // Continuously simplify and inline until we either reach a fixed point, or
  // hit the maximum iteration count. Simplifying early helps to refine the cost
  // model, and in future iterations may devirtualize new calls.
  unsigned iterationCount = 0;
  do {
    if (failed(optimizeSCC(inliner.cg, useList, currentSCC, context)))
      return failure();
    if (failed(inlineCallsInSCC(inliner, useList, currentSCC)))
      break;
  } while (++iterationCount < maxInliningIterations);
  return success();
}

LogicalResult InlinerPass::optimizeSCC(CallGraph &cg, CGUseList &useList,
                                       CallGraphSCC &currentSCC,
                                       MLIRContext *context) {
  // Collect the sets of nodes to simplify.
  SmallVector<CallGraphNode *, 4> nodesToVisit;
  for (auto *node : currentSCC) {
    if (node->isExternal())
      continue;

    // Don't simplify nodes with children. Nodes with children require special
    // handling as we may remove the node during simplification. In the future,
    // we should be able to handle this case with proper node deletion tracking.
    if (node->hasChildren())
      continue;

    // We also won't apply simplifications to nodes that can't have passes
    // scheduled on them.
    auto *region = node->getCallableRegion();
    if (!region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
      continue;
    nodesToVisit.push_back(node);
  }
  if (nodesToVisit.empty())
    return success();

  // Optimize each of the nodes within the SCC in parallel.
  if (failed(optimizeSCCAsync(nodesToVisit, context)))
    return failure();

  // Recompute the uses held by each of the nodes.
  for (CallGraphNode *node : nodesToVisit)
    useList.recomputeUses(node, cg);
  return success();
}

LogicalResult
InlinerPass::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
                              MLIRContext *ctx) {
  // We must maintain a fixed pool of pass managers which is at least as large
  // as the maximum parallelism of the failableParallelForEach below.
  // Note: The number of pass managers here needs to remain constant
  // to prevent issues with pass instrumentations that rely on having the same
  // pass manager for the main thread.
  size_t numThreads = ctx->getNumThreads();
  if (opPipelines.size() < numThreads) {
    // Reserve before resizing so that we can use a reference to the first
    // element.
    opPipelines.reserve(numThreads);
    opPipelines.resize(numThreads, opPipelines.front());
  }

  // Ensure an analysis manager has been constructed for each of the nodes.
  // This prevents thread races when running the nested pipelines.
  for (CallGraphNode *node : nodesToVisit)
    getAnalysisManager().nest(node->getCallableRegion()->getParentOp());

  // A pool of "active" flags, one per pass manager, used by the async
  // executors to claim a pass manager for the current thread.
  std::vector<std::atomic<bool>> activePMs(opPipelines.size());
  std::fill(activePMs.begin(), activePMs.end(), false);
  return failableParallelForEach(ctx, nodesToVisit, [&](CallGraphNode *node) {
    // Find a pass manager for this operation.
    auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
      bool expectedInactive = false;
      return isActive.compare_exchange_strong(expectedInactive, true);
    });
    assert(it != activePMs.end() &&
           "could not find inactive pass manager for thread");
    unsigned pmIndex = it - activePMs.begin();

    // Optimize this callable node.
    LogicalResult result = optimizeCallable(node, opPipelines[pmIndex]);

    // Reset the active bit for this pass manager.
    activePMs[pmIndex].store(false);
    return result;
  });
}

LogicalResult
InlinerPass::optimizeCallable(CallGraphNode *node,
                              llvm::StringMap<OpPassManager> &pipelines) {
  Operation *callable = node->getCallableRegion()->getParentOp();
  StringRef opName = callable->getName().getStringRef();
  auto pipelineIt = pipelines.find(opName);
  if (pipelineIt == pipelines.end()) {
    // If a pipeline didn't exist, use the default if possible.
    if (!defaultPipeline)
      return success();

    OpPassManager defaultPM(opName);
    defaultPipeline(defaultPM);
    pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
  }
  return runPipeline(pipelineIt->second, callable);
}

LogicalResult InlinerPass::initializeOptions(StringRef options) {
  if (failed(Pass::initializeOptions(options)))
    return failure();

  // Initialize the default pipeline builder to use the option string.
  // TODO: Use a generic pass manager for default pipelines, and remove this.
  if (!defaultPipelineStr.empty()) {
    std::string defaultPipelineCopy = defaultPipelineStr;
    defaultPipeline = [=](OpPassManager &pm) {
      (void)parsePassPipeline(defaultPipelineCopy, pm);
    };
  } else if (defaultPipelineStr.getNumOccurrences()) {
    defaultPipeline = nullptr;
  }

  // Initialize the op specific pass pipelines.
  llvm::StringMap<OpPassManager> pipelines;
  for (OpPassManager pipeline : opPipelineList)
    if (!pipeline.empty())
      pipelines.try_emplace(pipeline.getOpAnchorName(), pipeline);
  opPipelines.assign({std::move(pipelines)});

  return success();
}

std::unique_ptr<Pass> mlir::createInlinerPass() {
  return std::make_unique<InlinerPass>();
}
std::unique_ptr<Pass>
mlir::createInlinerPass(llvm::StringMap<OpPassManager> opPipelines) {
  return std::make_unique<InlinerPass>(defaultInlinerOptPipeline,
                                       std::move(opPipelines));
}
std::unique_ptr<Pass> mlir::createInlinerPass(
    llvm::StringMap<OpPassManager> opPipelines,
    std::function<void(OpPassManager &)> defaultPipelineBuilder) {
  return std::make_unique<InlinerPass>(std::move(defaultPipelineBuilder),
                                       std::move(opPipelines));
}
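
// Example usage (illustrative sketch, not part of the pass itself): a client
// that wants op-specific cleanups during inlining might register the inliner
// with a pipeline keyed on `func.func`; other callables then fall back to the
// default pipeline. The `passManager` name below is assumed to be the
// client's top-level PassManager.
//
//   llvm::StringMap<OpPassManager> opPipelines;
//   OpPassManager funcPM("func.func");
//   funcPM.addPass(createCanonicalizerPass());
//   opPipelines.try_emplace("func.func", std::move(funcPM));
//   passManager.addPass(createInlinerPass(std::move(opPipelines)));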