10ba00878SRiver Riddle //===- Inliner.cpp - Pass to inline function calls ------------------------===//
20ba00878SRiver Riddle //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60ba00878SRiver Riddle //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
8a20d96e4SRiver Riddle //
9a20d96e4SRiver Riddle // This file implements a basic inlining algorithm that operates bottom up over
// the Strongly Connected Components (SCCs) of the CallGraph. This enables a
// more incremental propagation of inlining decisions from the leaves to the
// roots of the callgraph.
13a20d96e4SRiver Riddle //
14a20d96e4SRiver Riddle //===----------------------------------------------------------------------===//
150ba00878SRiver Riddle
161834ad4aSRiver Riddle #include "PassDetail.h"
17a20d96e4SRiver Riddle #include "mlir/Analysis/CallGraph.h"
186569cf2aSRiver Riddle #include "mlir/IR/Threading.h"
1906057248SRiver Riddle #include "mlir/Interfaces/CallInterfaces.h"
20eb623ae8SStephen Neuendorffer #include "mlir/Interfaces/SideEffectInterfaces.h"
21d7eba200SRiver Riddle #include "mlir/Pass/PassManager.h"
22*c2ecf162SJaved Absar #include "mlir/Support/DebugStringHelper.h"
230ba00878SRiver Riddle #include "mlir/Transforms/InliningUtils.h"
240ba00878SRiver Riddle #include "mlir/Transforms/Passes.h"
25a20d96e4SRiver Riddle #include "llvm/ADT/SCCIterator.h"
26553f794bSSean Silva #include "llvm/Support/Debug.h"
270ba00878SRiver Riddle
28553f794bSSean Silva #define DEBUG_TYPE "inlining"
29553f794bSSean Silva
300ba00878SRiver Riddle using namespace mlir;
310ba00878SRiver Riddle
32d7eba200SRiver Riddle /// This function implements the default inliner optimization pipeline.
defaultInlinerOptPipeline(OpPassManager & pm)33d7eba200SRiver Riddle static void defaultInlinerOptPipeline(OpPassManager &pm) {
34d7eba200SRiver Riddle pm.addPass(createCanonicalizerPass());
35d7eba200SRiver Riddle }
36d7eba200SRiver Riddle
37a20d96e4SRiver Riddle //===----------------------------------------------------------------------===//
384be504a9SRiver Riddle // Symbol Use Tracking
394be504a9SRiver Riddle //===----------------------------------------------------------------------===//
404be504a9SRiver Riddle
414be504a9SRiver Riddle /// Walk all of the used symbol callgraph nodes referenced with the given op.
walkReferencedSymbolNodes(Operation * op,CallGraph & cg,SymbolTableCollection & symbolTable,DenseMap<Attribute,CallGraphNode * > & resolvedRefs,function_ref<void (CallGraphNode *,Operation *)> callback)424be504a9SRiver Riddle static void walkReferencedSymbolNodes(
43a5ea6045SRiver Riddle Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
444be504a9SRiver Riddle DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
454be504a9SRiver Riddle function_ref<void(CallGraphNode *, Operation *)> callback) {
464be504a9SRiver Riddle auto symbolUses = SymbolTable::getSymbolUses(op);
474be504a9SRiver Riddle assert(symbolUses && "expected uses to be valid");
484be504a9SRiver Riddle
494be504a9SRiver Riddle Operation *symbolTableOp = op->getParentOp();
504be504a9SRiver Riddle for (const SymbolTable::SymbolUse &use : *symbolUses) {
514be504a9SRiver Riddle auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
524be504a9SRiver Riddle CallGraphNode *&node = refIt.first->second;
534be504a9SRiver Riddle
544be504a9SRiver Riddle // If this is the first instance of this reference, try to resolve a
554be504a9SRiver Riddle // callgraph node for it.
564be504a9SRiver Riddle if (refIt.second) {
57a5ea6045SRiver Riddle auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
584be504a9SRiver Riddle use.getSymbolRef());
594be504a9SRiver Riddle auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
604be504a9SRiver Riddle if (!callableOp)
614be504a9SRiver Riddle continue;
624be504a9SRiver Riddle node = cg.lookupNode(callableOp.getCallableRegion());
634be504a9SRiver Riddle }
644be504a9SRiver Riddle if (node)
654be504a9SRiver Riddle callback(node, use.getUser());
664be504a9SRiver Riddle }
674be504a9SRiver Riddle }
684be504a9SRiver Riddle
694be504a9SRiver Riddle //===----------------------------------------------------------------------===//
// CGUseList
//===----------------------------------------------------------------------===//
714be504a9SRiver Riddle
724be504a9SRiver Riddle namespace {
734be504a9SRiver Riddle /// This struct tracks the uses of callgraph nodes that can be dropped when
744be504a9SRiver Riddle /// use_empty. It directly tracks and manages a use-list for all of the
754be504a9SRiver Riddle /// call-graph nodes. This is necessary because many callgraph nodes are
764be504a9SRiver Riddle /// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
774be504a9SRiver Riddle /// class.
784be504a9SRiver Riddle struct CGUseList {
794be504a9SRiver Riddle /// This struct tracks the uses of callgraph nodes within a specific
804be504a9SRiver Riddle /// operation.
814be504a9SRiver Riddle struct CGUser {
824be504a9SRiver Riddle /// Any nodes referenced in the top-level attribute list of this user. We
834be504a9SRiver Riddle /// use a set here because the number of references does not matter.
844be504a9SRiver Riddle DenseSet<CallGraphNode *> topLevelUses;
854be504a9SRiver Riddle
864be504a9SRiver Riddle /// Uses of nodes referenced by nested operations.
874be504a9SRiver Riddle DenseMap<CallGraphNode *, int> innerUses;
884be504a9SRiver Riddle };
894be504a9SRiver Riddle
90a5ea6045SRiver Riddle CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);
914be504a9SRiver Riddle
924be504a9SRiver Riddle /// Drop uses of nodes referred to by the given call operation that resides
934be504a9SRiver Riddle /// within 'userNode'.
944be504a9SRiver Riddle void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);
954be504a9SRiver Riddle
964be504a9SRiver Riddle /// Remove the given node from the use list.
974be504a9SRiver Riddle void eraseNode(CallGraphNode *node);
984be504a9SRiver Riddle
994be504a9SRiver Riddle /// Returns true if the given callgraph node has no uses and can be pruned.
1004be504a9SRiver Riddle bool isDead(CallGraphNode *node) const;
1014be504a9SRiver Riddle
1024be504a9SRiver Riddle /// Returns true if the given callgraph node has a single use and can be
1034be504a9SRiver Riddle /// discarded.
1044be504a9SRiver Riddle bool hasOneUseAndDiscardable(CallGraphNode *node) const;
1054be504a9SRiver Riddle
1064be504a9SRiver Riddle /// Recompute the uses held by the given callgraph node.
1074be504a9SRiver Riddle void recomputeUses(CallGraphNode *node, CallGraph &cg);
1084be504a9SRiver Riddle
1094be504a9SRiver Riddle /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy
1104be504a9SRiver Riddle /// of 'lhs' into 'rhs'.
1114be504a9SRiver Riddle void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);
1124be504a9SRiver Riddle
1134be504a9SRiver Riddle private:
1144be504a9SRiver Riddle /// Decrement the uses of discardable nodes referenced by the given user.
1154be504a9SRiver Riddle void decrementDiscardableUses(CGUser &uses);
1164be504a9SRiver Riddle
1174be504a9SRiver Riddle /// A mapping between a discardable callgraph node (that is a symbol) and the
1184be504a9SRiver Riddle /// number of uses for this node.
1194be504a9SRiver Riddle DenseMap<CallGraphNode *, int> discardableSymNodeUses;
120a5ea6045SRiver Riddle
1214be504a9SRiver Riddle /// A mapping between a callgraph node and the symbol callgraph nodes that it
1224be504a9SRiver Riddle /// uses.
1234be504a9SRiver Riddle DenseMap<CallGraphNode *, CGUser> nodeUses;
124a5ea6045SRiver Riddle
125a5ea6045SRiver Riddle /// A symbol table to use when resolving call lookups.
126a5ea6045SRiver Riddle SymbolTableCollection &symbolTable;
1274be504a9SRiver Riddle };
128be0a7e9fSMehdi Amini } // namespace
1294be504a9SRiver Riddle
CGUseList(Operation * op,CallGraph & cg,SymbolTableCollection & symbolTable)130a5ea6045SRiver Riddle CGUseList::CGUseList(Operation *op, CallGraph &cg,
131a5ea6045SRiver Riddle SymbolTableCollection &symbolTable)
132a5ea6045SRiver Riddle : symbolTable(symbolTable) {
1334be504a9SRiver Riddle /// A set of callgraph nodes that are always known to be live during inlining.
1344be504a9SRiver Riddle DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;
1354be504a9SRiver Riddle
1364be504a9SRiver Riddle // Walk each of the symbol tables looking for discardable callgraph nodes.
1374be504a9SRiver Riddle auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
1381e4faf23SRiver Riddle for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
1394be504a9SRiver Riddle // If this is a callgraph operation, check to see if it is discardable.
1404be504a9SRiver Riddle if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
1414be504a9SRiver Riddle if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
1427c221a7dSRiver Riddle SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
1437c221a7dSRiver Riddle if (symbol && (allUsesVisible || symbol.isPrivate()) &&
1447c221a7dSRiver Riddle symbol.canDiscardOnUseEmpty()) {
1454be504a9SRiver Riddle discardableSymNodeUses.try_emplace(node, 0);
1467c221a7dSRiver Riddle }
1474be504a9SRiver Riddle continue;
1484be504a9SRiver Riddle }
1494be504a9SRiver Riddle }
1504be504a9SRiver Riddle // Otherwise, check for any referenced nodes. These will be always-live.
151a5ea6045SRiver Riddle walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
1524be504a9SRiver Riddle [](CallGraphNode *, Operation *) {});
1534be504a9SRiver Riddle }
1544be504a9SRiver Riddle };
155a90151d6SRiver Riddle SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
156a90151d6SRiver Riddle walkFn);
1574be504a9SRiver Riddle
1584be504a9SRiver Riddle // Drop the use information for any discardable nodes that are always live.
1594be504a9SRiver Riddle for (auto &it : alwaysLiveNodes)
1604be504a9SRiver Riddle discardableSymNodeUses.erase(it.second);
1614be504a9SRiver Riddle
1624be504a9SRiver Riddle // Compute the uses for each of the callable nodes in the graph.
1634be504a9SRiver Riddle for (CallGraphNode *node : cg)
1644be504a9SRiver Riddle recomputeUses(node, cg);
1654be504a9SRiver Riddle }
1664be504a9SRiver Riddle
dropCallUses(CallGraphNode * userNode,Operation * callOp,CallGraph & cg)1674be504a9SRiver Riddle void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
1684be504a9SRiver Riddle CallGraph &cg) {
1694be504a9SRiver Riddle auto &userRefs = nodeUses[userNode].innerUses;
1704be504a9SRiver Riddle auto walkFn = [&](CallGraphNode *node, Operation *user) {
1714be504a9SRiver Riddle auto parentIt = userRefs.find(node);
1724be504a9SRiver Riddle if (parentIt == userRefs.end())
1734be504a9SRiver Riddle return;
1744be504a9SRiver Riddle --parentIt->second;
1754be504a9SRiver Riddle --discardableSymNodeUses[node];
1764be504a9SRiver Riddle };
1774be504a9SRiver Riddle DenseMap<Attribute, CallGraphNode *> resolvedRefs;
178a5ea6045SRiver Riddle walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
1794be504a9SRiver Riddle }
1804be504a9SRiver Riddle
eraseNode(CallGraphNode * node)1814be504a9SRiver Riddle void CGUseList::eraseNode(CallGraphNode *node) {
1824be504a9SRiver Riddle // Drop all child nodes.
1834be504a9SRiver Riddle for (auto &edge : *node)
1844be504a9SRiver Riddle if (edge.isChild())
1854be504a9SRiver Riddle eraseNode(edge.getTarget());
1864be504a9SRiver Riddle
1874be504a9SRiver Riddle // Drop the uses held by this node and erase it.
1884be504a9SRiver Riddle auto useIt = nodeUses.find(node);
1894be504a9SRiver Riddle assert(useIt != nodeUses.end() && "expected node to be valid");
1904be504a9SRiver Riddle decrementDiscardableUses(useIt->getSecond());
1914be504a9SRiver Riddle nodeUses.erase(useIt);
1924be504a9SRiver Riddle discardableSymNodeUses.erase(node);
1934be504a9SRiver Riddle }
1944be504a9SRiver Riddle
isDead(CallGraphNode * node) const1954be504a9SRiver Riddle bool CGUseList::isDead(CallGraphNode *node) const {
1964be504a9SRiver Riddle // If the parent operation isn't a symbol, simply check normal SSA deadness.
1974be504a9SRiver Riddle Operation *nodeOp = node->getCallableRegion()->getParentOp();
1987c221a7dSRiver Riddle if (!isa<SymbolOpInterface>(nodeOp))
1994be504a9SRiver Riddle return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->use_empty();
2004be504a9SRiver Riddle
2014be504a9SRiver Riddle // Otherwise, check the number of symbol uses.
2024be504a9SRiver Riddle auto symbolIt = discardableSymNodeUses.find(node);
2034be504a9SRiver Riddle return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
2044be504a9SRiver Riddle }
2054be504a9SRiver Riddle
hasOneUseAndDiscardable(CallGraphNode * node) const2064be504a9SRiver Riddle bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
2074be504a9SRiver Riddle // If this isn't a symbol node, check for side-effects and SSA use count.
2084be504a9SRiver Riddle Operation *nodeOp = node->getCallableRegion()->getParentOp();
2097c221a7dSRiver Riddle if (!isa<SymbolOpInterface>(nodeOp))
2104be504a9SRiver Riddle return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->hasOneUse();
2114be504a9SRiver Riddle
2124be504a9SRiver Riddle // Otherwise, check the number of symbol uses.
2134be504a9SRiver Riddle auto symbolIt = discardableSymNodeUses.find(node);
2144be504a9SRiver Riddle return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
2154be504a9SRiver Riddle }
2164be504a9SRiver Riddle
recomputeUses(CallGraphNode * node,CallGraph & cg)2174be504a9SRiver Riddle void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
2184be504a9SRiver Riddle Operation *parentOp = node->getCallableRegion()->getParentOp();
2194be504a9SRiver Riddle CGUser &uses = nodeUses[node];
2204be504a9SRiver Riddle decrementDiscardableUses(uses);
2214be504a9SRiver Riddle
2224be504a9SRiver Riddle // Collect the new discardable uses within this node.
2234be504a9SRiver Riddle uses = CGUser();
2244be504a9SRiver Riddle DenseMap<Attribute, CallGraphNode *> resolvedRefs;
2254be504a9SRiver Riddle auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
2264be504a9SRiver Riddle auto discardSymIt = discardableSymNodeUses.find(refNode);
2274be504a9SRiver Riddle if (discardSymIt == discardableSymNodeUses.end())
2284be504a9SRiver Riddle return;
2294be504a9SRiver Riddle
2304be504a9SRiver Riddle if (user != parentOp)
2314be504a9SRiver Riddle ++uses.innerUses[refNode];
2324be504a9SRiver Riddle else if (!uses.topLevelUses.insert(refNode).second)
2334be504a9SRiver Riddle return;
2344be504a9SRiver Riddle ++discardSymIt->second;
2354be504a9SRiver Riddle };
236a5ea6045SRiver Riddle walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
2374be504a9SRiver Riddle }
2384be504a9SRiver Riddle
mergeUsesAfterInlining(CallGraphNode * lhs,CallGraphNode * rhs)2394be504a9SRiver Riddle void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
2404be504a9SRiver Riddle auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
2414be504a9SRiver Riddle for (auto &useIt : lhsUses.innerUses) {
2424be504a9SRiver Riddle rhsUses.innerUses[useIt.first] += useIt.second;
2434be504a9SRiver Riddle discardableSymNodeUses[useIt.first] += useIt.second;
2444be504a9SRiver Riddle }
2454be504a9SRiver Riddle }
2464be504a9SRiver Riddle
decrementDiscardableUses(CGUser & uses)2474be504a9SRiver Riddle void CGUseList::decrementDiscardableUses(CGUser &uses) {
2484be504a9SRiver Riddle for (CallGraphNode *node : uses.topLevelUses)
2494be504a9SRiver Riddle --discardableSymNodeUses[node];
2504be504a9SRiver Riddle for (auto &it : uses.innerUses)
2514be504a9SRiver Riddle discardableSymNodeUses[it.first] -= it.second;
2524be504a9SRiver Riddle }
2534be504a9SRiver Riddle
2544be504a9SRiver Riddle //===----------------------------------------------------------------------===//
255a20d96e4SRiver Riddle // CallGraph traversal
256a20d96e4SRiver Riddle //===----------------------------------------------------------------------===//
257a20d96e4SRiver Riddle
258f4ef77cbSRiver Riddle namespace {
259f4ef77cbSRiver Riddle /// This class represents a specific callgraph SCC.
260f4ef77cbSRiver Riddle class CallGraphSCC {
261f4ef77cbSRiver Riddle public:
CallGraphSCC(llvm::scc_iterator<const CallGraph * > & parentIterator)262f4ef77cbSRiver Riddle CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
263f4ef77cbSRiver Riddle : parentIterator(parentIterator) {}
264f4ef77cbSRiver Riddle /// Return a range over the nodes within this SCC.
begin()265f4ef77cbSRiver Riddle std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
end()266f4ef77cbSRiver Riddle std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }
267f4ef77cbSRiver Riddle
268f4ef77cbSRiver Riddle /// Reset the nodes of this SCC with those provided.
reset(const std::vector<CallGraphNode * > & newNodes)269f4ef77cbSRiver Riddle void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
270f4ef77cbSRiver Riddle
271f4ef77cbSRiver Riddle /// Remove the given node from this SCC.
remove(CallGraphNode * node)272f4ef77cbSRiver Riddle void remove(CallGraphNode *node) {
273f4ef77cbSRiver Riddle auto it = llvm::find(nodes, node);
274f4ef77cbSRiver Riddle if (it != nodes.end()) {
275f4ef77cbSRiver Riddle nodes.erase(it);
276f4ef77cbSRiver Riddle parentIterator.ReplaceNode(node, nullptr);
277f4ef77cbSRiver Riddle }
278f4ef77cbSRiver Riddle }
279f4ef77cbSRiver Riddle
280f4ef77cbSRiver Riddle private:
281f4ef77cbSRiver Riddle std::vector<CallGraphNode *> nodes;
282f4ef77cbSRiver Riddle llvm::scc_iterator<const CallGraph *> &parentIterator;
283f4ef77cbSRiver Riddle };
284be0a7e9fSMehdi Amini } // namespace
285f4ef77cbSRiver Riddle
286a20d96e4SRiver Riddle /// Run a given transformation over the SCCs of the callgraph in a bottom up
287a20d96e4SRiver Riddle /// traversal.
runTransformOnCGSCCs(const CallGraph & cg,function_ref<LogicalResult (CallGraphSCC &)> sccTransformer)288d7eba200SRiver Riddle static LogicalResult runTransformOnCGSCCs(
289d7eba200SRiver Riddle const CallGraph &cg,
290d7eba200SRiver Riddle function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
291f4ef77cbSRiver Riddle llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
292f4ef77cbSRiver Riddle CallGraphSCC currentSCC(cgi);
2936b1cc3c6SRiver Riddle while (!cgi.isAtEnd()) {
2946b1cc3c6SRiver Riddle // Copy the current SCC and increment so that the transformer can modify the
2956b1cc3c6SRiver Riddle // SCC without invalidating our iterator.
296f4ef77cbSRiver Riddle currentSCC.reset(*cgi);
2976b1cc3c6SRiver Riddle ++cgi;
298d7eba200SRiver Riddle if (failed(sccTransformer(currentSCC)))
299d7eba200SRiver Riddle return failure();
3006b1cc3c6SRiver Riddle }
301d7eba200SRiver Riddle return success();
302a20d96e4SRiver Riddle }
303a20d96e4SRiver Riddle
304a20d96e4SRiver Riddle namespace {
305a20d96e4SRiver Riddle /// This struct represents a resolved call to a given callgraph node. Given that
306a20d96e4SRiver Riddle /// the call does not actually contain a direct reference to the
307a20d96e4SRiver Riddle /// Region(CallGraphNode) that it is dispatching to, we need to resolve them
308a20d96e4SRiver Riddle /// explicitly.
309a20d96e4SRiver Riddle struct ResolvedCall {
ResolvedCall__anondfe352480711::ResolvedCall3104be504a9SRiver Riddle ResolvedCall(CallOpInterface call, CallGraphNode *sourceNode,
3114be504a9SRiver Riddle CallGraphNode *targetNode)
3124be504a9SRiver Riddle : call(call), sourceNode(sourceNode), targetNode(targetNode) {}
313a20d96e4SRiver Riddle CallOpInterface call;
3144be504a9SRiver Riddle CallGraphNode *sourceNode, *targetNode;
315a20d96e4SRiver Riddle };
316be0a7e9fSMehdi Amini } // namespace
317a20d96e4SRiver Riddle
318a20d96e4SRiver Riddle /// Collect all of the callable operations within the given range of blocks. If
319a20d96e4SRiver Riddle /// `traverseNestedCGNodes` is true, this will also collect call operations
320a20d96e4SRiver Riddle /// inside of nested callgraph nodes.
collectCallOps(iterator_range<Region::iterator> blocks,CallGraphNode * sourceNode,CallGraph & cg,SymbolTableCollection & symbolTable,SmallVectorImpl<ResolvedCall> & calls,bool traverseNestedCGNodes)3214562e389SRiver Riddle static void collectCallOps(iterator_range<Region::iterator> blocks,
3224be504a9SRiver Riddle CallGraphNode *sourceNode, CallGraph &cg,
323a5ea6045SRiver Riddle SymbolTableCollection &symbolTable,
3244be504a9SRiver Riddle SmallVectorImpl<ResolvedCall> &calls,
325a20d96e4SRiver Riddle bool traverseNestedCGNodes) {
3264be504a9SRiver Riddle SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
3274be504a9SRiver Riddle auto addToWorklist = [&](CallGraphNode *node,
3284be504a9SRiver Riddle iterator_range<Region::iterator> blocks) {
329a20d96e4SRiver Riddle for (Block &block : blocks)
3304be504a9SRiver Riddle worklist.emplace_back(&block, node);
331a20d96e4SRiver Riddle };
332a20d96e4SRiver Riddle
3334be504a9SRiver Riddle addToWorklist(sourceNode, blocks);
334a20d96e4SRiver Riddle while (!worklist.empty()) {
3354be504a9SRiver Riddle Block *block;
3364be504a9SRiver Riddle std::tie(block, sourceNode) = worklist.pop_back_val();
3374be504a9SRiver Riddle
3384be504a9SRiver Riddle for (Operation &op : *block) {
339a20d96e4SRiver Riddle if (auto call = dyn_cast<CallOpInterface>(op)) {
3409db53a18SRiver Riddle // TODO: Support inlining nested call references.
3415c159b91SRiver Riddle CallInterfaceCallable callable = call.getCallableForCallee();
342c7748404SRiver Riddle if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
343c7748404SRiver Riddle if (!symRef.isa<FlatSymbolRefAttr>())
344c7748404SRiver Riddle continue;
345c7748404SRiver Riddle }
346c7748404SRiver Riddle
347a5ea6045SRiver Riddle CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
3484be504a9SRiver Riddle if (!targetNode->isExternal())
3494be504a9SRiver Riddle calls.emplace_back(call, sourceNode, targetNode);
350a20d96e4SRiver Riddle continue;
351a20d96e4SRiver Riddle }
352a20d96e4SRiver Riddle
353a20d96e4SRiver Riddle // If this is not a call, traverse the nested regions. If
354a20d96e4SRiver Riddle // `traverseNestedCGNodes` is false, then don't traverse nested call graph
355a20d96e4SRiver Riddle // regions.
3564be504a9SRiver Riddle for (auto &nestedRegion : op.getRegions()) {
3574be504a9SRiver Riddle CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
3584be504a9SRiver Riddle if (traverseNestedCGNodes || !nestedNode)
3594be504a9SRiver Riddle addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
3604be504a9SRiver Riddle }
361a20d96e4SRiver Riddle }
362a20d96e4SRiver Riddle }
363a20d96e4SRiver Riddle }
364a20d96e4SRiver Riddle
365a20d96e4SRiver Riddle //===----------------------------------------------------------------------===//
366a20d96e4SRiver Riddle // Inliner
367a20d96e4SRiver Riddle //===----------------------------------------------------------------------===//
368*c2ecf162SJaved Absar
369*c2ecf162SJaved Absar #ifndef NDEBUG
getNodeName(CallOpInterface op)370*c2ecf162SJaved Absar static std::string getNodeName(CallOpInterface op) {
371*c2ecf162SJaved Absar if (auto sym = op.getCallableForCallee().dyn_cast<SymbolRefAttr>())
372*c2ecf162SJaved Absar return debugString(op);
373*c2ecf162SJaved Absar return "_unnamed_callee_";
374*c2ecf162SJaved Absar }
375*c2ecf162SJaved Absar #endif
376*c2ecf162SJaved Absar
377*c2ecf162SJaved Absar /// Return true if the specified `inlineHistoryID` indicates an inline history
378*c2ecf162SJaved Absar /// that already includes `node`.
inlineHistoryIncludes(CallGraphNode * node,Optional<size_t> inlineHistoryID,MutableArrayRef<std::pair<CallGraphNode *,Optional<size_t>>> inlineHistory)379*c2ecf162SJaved Absar static bool inlineHistoryIncludes(
380*c2ecf162SJaved Absar CallGraphNode *node, Optional<size_t> inlineHistoryID,
381*c2ecf162SJaved Absar MutableArrayRef<std::pair<CallGraphNode *, Optional<size_t>>>
382*c2ecf162SJaved Absar inlineHistory) {
383*c2ecf162SJaved Absar while (inlineHistoryID.has_value()) {
384*c2ecf162SJaved Absar assert(inlineHistoryID.value() < inlineHistory.size() &&
385*c2ecf162SJaved Absar "Invalid inline history ID");
386*c2ecf162SJaved Absar if (inlineHistory[inlineHistoryID.value()].first == node)
387*c2ecf162SJaved Absar return true;
388*c2ecf162SJaved Absar inlineHistoryID = inlineHistory[inlineHistoryID.value()].second;
389*c2ecf162SJaved Absar }
390*c2ecf162SJaved Absar return false;
391*c2ecf162SJaved Absar }
392*c2ecf162SJaved Absar
393a20d96e4SRiver Riddle namespace {
394a20d96e4SRiver Riddle /// This class provides a specialization of the main inlining interface.
395a20d96e4SRiver Riddle struct Inliner : public InlinerInterface {
Inliner__anondfe352480911::Inliner396a5ea6045SRiver Riddle Inliner(MLIRContext *context, CallGraph &cg,
397a5ea6045SRiver Riddle SymbolTableCollection &symbolTable)
398a5ea6045SRiver Riddle : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}
399a20d96e4SRiver Riddle
400a20d96e4SRiver Riddle /// Process a set of blocks that have been inlined. This callback is invoked
401a20d96e4SRiver Riddle /// *before* inlined terminator operations have been processed.
4024562e389SRiver Riddle void
processInlinedBlocks__anondfe352480911::Inliner4034562e389SRiver Riddle processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
4044be504a9SRiver Riddle // Find the closest callgraph node from the first block.
4054be504a9SRiver Riddle CallGraphNode *node;
4064be504a9SRiver Riddle Region *region = inlinedBlocks.begin()->getParent();
4074be504a9SRiver Riddle while (!(node = cg.lookupNode(region))) {
4084be504a9SRiver Riddle region = region->getParentRegion();
4094be504a9SRiver Riddle assert(region && "expected valid parent node");
4104be504a9SRiver Riddle }
4114be504a9SRiver Riddle
412a5ea6045SRiver Riddle collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
4134be504a9SRiver Riddle /*traverseNestedCGNodes=*/true);
414a20d96e4SRiver Riddle }
415a20d96e4SRiver Riddle
416f4ef77cbSRiver Riddle /// Mark the given callgraph node for deletion.
markForDeletion__anondfe352480911::Inliner417f4ef77cbSRiver Riddle void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }
418f4ef77cbSRiver Riddle
419f4ef77cbSRiver Riddle /// This method properly disposes of callables that became dead during
420f4ef77cbSRiver Riddle /// inlining. This should not be called while iterating over the SCCs.
eraseDeadCallables__anondfe352480911::Inliner421f4ef77cbSRiver Riddle void eraseDeadCallables() {
422f4ef77cbSRiver Riddle for (CallGraphNode *node : deadNodes)
423f4ef77cbSRiver Riddle node->getCallableRegion()->getParentOp()->erase();
424f4ef77cbSRiver Riddle }
425f4ef77cbSRiver Riddle
426f4ef77cbSRiver Riddle /// The set of callables known to be dead.
427f4ef77cbSRiver Riddle SmallPtrSet<CallGraphNode *, 8> deadNodes;
428f4ef77cbSRiver Riddle
429a20d96e4SRiver Riddle /// The current set of call instructions to consider for inlining.
430a20d96e4SRiver Riddle SmallVector<ResolvedCall, 8> calls;
431a20d96e4SRiver Riddle
432a20d96e4SRiver Riddle /// The callgraph being operated on.
433a20d96e4SRiver Riddle CallGraph &cg;
434a5ea6045SRiver Riddle
435a5ea6045SRiver Riddle /// A symbol table to use when resolving call lookups.
436a5ea6045SRiver Riddle SymbolTableCollection &symbolTable;
437a20d96e4SRiver Riddle };
438a20d96e4SRiver Riddle } // namespace
439a20d96e4SRiver Riddle
440a20d96e4SRiver Riddle /// Returns true if the given call should be inlined.
shouldInline(ResolvedCall & resolvedCall)441a20d96e4SRiver Riddle static bool shouldInline(ResolvedCall &resolvedCall) {
442a20d96e4SRiver Riddle // Don't allow inlining terminator calls. We currently don't support this
443a20d96e4SRiver Riddle // case.
444fe7c0d90SRiver Riddle if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>())
445a20d96e4SRiver Riddle return false;
446a20d96e4SRiver Riddle
447a20d96e4SRiver Riddle // Don't allow inlining if the target is an ancestor of the call. This
448a20d96e4SRiver Riddle // prevents inlining recursively.
449a20d96e4SRiver Riddle if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
4500bf4a82aSChristian Sigg resolvedCall.call->getParentRegion()))
451a20d96e4SRiver Riddle return false;
452a20d96e4SRiver Riddle
453a20d96e4SRiver Riddle // Otherwise, inline.
454a20d96e4SRiver Riddle return true;
455a20d96e4SRiver Riddle }
456a20d96e4SRiver Riddle
4576b1cc3c6SRiver Riddle /// Attempt to inline calls within the given scc. This function returns
4586b1cc3c6SRiver Riddle /// success if any calls were inlined, failure otherwise.
inlineCallsInSCC(Inliner & inliner,CGUseList & useList,CallGraphSCC & currentSCC)459f4ef77cbSRiver Riddle static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
460f4ef77cbSRiver Riddle CallGraphSCC ¤tSCC) {
461a20d96e4SRiver Riddle CallGraph &cg = inliner.cg;
462a20d96e4SRiver Riddle auto &calls = inliner.calls;
463a20d96e4SRiver Riddle
464f4ef77cbSRiver Riddle // A set of dead nodes to remove after inlining.
4652f672e2fSAlex Zinenko llvm::SmallSetVector<CallGraphNode *, 1> deadNodes;
466f4ef77cbSRiver Riddle
467a20d96e4SRiver Riddle // Collect all of the direct calls within the nodes of the current SCC. We
468a20d96e4SRiver Riddle // don't traverse nested callgraph nodes, because they are handled separately
469a20d96e4SRiver Riddle // likely within a different SCC.
4704be504a9SRiver Riddle for (CallGraphNode *node : currentSCC) {
4714be504a9SRiver Riddle if (node->isExternal())
4724be504a9SRiver Riddle continue;
4734be504a9SRiver Riddle
474f4ef77cbSRiver Riddle // Don't collect calls if the node is already dead.
475a5ea6045SRiver Riddle if (useList.isDead(node)) {
4762f672e2fSAlex Zinenko deadNodes.insert(node);
477a5ea6045SRiver Riddle } else {
478a5ea6045SRiver Riddle collectCallOps(*node->getCallableRegion(), node, cg, inliner.symbolTable,
479a5ea6045SRiver Riddle calls, /*traverseNestedCGNodes=*/false);
480a5ea6045SRiver Riddle }
481a20d96e4SRiver Riddle }
4824be504a9SRiver Riddle
483*c2ecf162SJaved Absar // When inlining a callee produces new call sites, we want to keep track of
484*c2ecf162SJaved Absar // the fact that they were inlined from the callee. This allows us to avoid
485*c2ecf162SJaved Absar // infinite inlining.
486*c2ecf162SJaved Absar using InlineHistoryT = Optional<size_t>;
487*c2ecf162SJaved Absar SmallVector<std::pair<CallGraphNode *, InlineHistoryT>, 8> inlineHistory;
488*c2ecf162SJaved Absar std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
489*c2ecf162SJaved Absar
490*c2ecf162SJaved Absar LLVM_DEBUG({
491*c2ecf162SJaved Absar llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
492*c2ecf162SJaved Absar for (unsigned i = 0, e = calls.size(); i < e; ++i)
493*c2ecf162SJaved Absar llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n";
494*c2ecf162SJaved Absar llvm::dbgs() << "}\n";
495*c2ecf162SJaved Absar });
496*c2ecf162SJaved Absar
497a20d96e4SRiver Riddle // Try to inline each of the call operations. Don't cache the end iterator
498a20d96e4SRiver Riddle // here as more calls may be added during inlining.
4996b1cc3c6SRiver Riddle bool inlinedAnyCalls = false;
500*c2ecf162SJaved Absar for (unsigned i = 0; i < calls.size(); ++i) {
5012f672e2fSAlex Zinenko if (deadNodes.contains(calls[i].sourceNode))
5022f672e2fSAlex Zinenko continue;
5034f37450bSRiver Riddle ResolvedCall it = calls[i];
504*c2ecf162SJaved Absar
505*c2ecf162SJaved Absar InlineHistoryT inlineHistoryID = callHistory[i];
506*c2ecf162SJaved Absar bool inHistory =
507*c2ecf162SJaved Absar inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory);
508*c2ecf162SJaved Absar bool doInline = !inHistory && shouldInline(it);
50947593511SRahul Joshi CallOpInterface call = it.call;
510553f794bSSean Silva LLVM_DEBUG({
51122219cfcSSean Silva if (doInline)
512*c2ecf162SJaved Absar llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
51322219cfcSSean Silva else
514*c2ecf162SJaved Absar llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
515553f794bSSean Silva });
51622219cfcSSean Silva if (!doInline)
517a20d96e4SRiver Riddle continue;
518*c2ecf162SJaved Absar
519*c2ecf162SJaved Absar unsigned prevSize = calls.size();
5205830f71aSRiver Riddle Region *targetRegion = it.targetNode->getCallableRegion();
5214be504a9SRiver Riddle
5224be504a9SRiver Riddle // If this is the last call to the target node and the node is discardable,
5234be504a9SRiver Riddle // then inline it in-place and delete the node if successful.
5244be504a9SRiver Riddle bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);
5254be504a9SRiver Riddle
5265830f71aSRiver Riddle LogicalResult inlineResult = inlineCall(
5275830f71aSRiver Riddle inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
5284be504a9SRiver Riddle targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
52947593511SRahul Joshi if (failed(inlineResult)) {
53047593511SRahul Joshi LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
531a20d96e4SRiver Riddle continue;
53247593511SRahul Joshi }
5336b1cc3c6SRiver Riddle inlinedAnyCalls = true;
5344be504a9SRiver Riddle
535*c2ecf162SJaved Absar // Create a inline history entry for this inlined call, so that we remember
536*c2ecf162SJaved Absar // that new callsites came about due to inlining Callee.
537*c2ecf162SJaved Absar InlineHistoryT newInlineHistoryID{inlineHistory.size()};
538*c2ecf162SJaved Absar inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID));
539*c2ecf162SJaved Absar
540*c2ecf162SJaved Absar auto historyToString = [](InlineHistoryT h) {
541*c2ecf162SJaved Absar return h.has_value() ? std::to_string(h.value()) : "root";
542*c2ecf162SJaved Absar };
543*c2ecf162SJaved Absar (void)historyToString;
544*c2ecf162SJaved Absar LLVM_DEBUG(llvm::dbgs()
545*c2ecf162SJaved Absar << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
546*c2ecf162SJaved Absar << getNodeName(call) << ", " << historyToString(inlineHistoryID)
547*c2ecf162SJaved Absar << "]\n");
548*c2ecf162SJaved Absar
549*c2ecf162SJaved Absar for (unsigned k = prevSize; k != calls.size(); ++k) {
550*c2ecf162SJaved Absar callHistory.push_back(newInlineHistoryID);
551*c2ecf162SJaved Absar LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[i].call
552*c2ecf162SJaved Absar << "}\n with historyID = " << newInlineHistoryID
553*c2ecf162SJaved Absar << ", added due to inlining of\n call {" << call
554*c2ecf162SJaved Absar << "}\n with historyID = "
555*c2ecf162SJaved Absar << historyToString(inlineHistoryID) << "\n");
556*c2ecf162SJaved Absar }
557*c2ecf162SJaved Absar
5584be504a9SRiver Riddle // If the inlining was successful, Merge the new uses into the source node.
5594be504a9SRiver Riddle useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
5604be504a9SRiver Riddle useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);
5614be504a9SRiver Riddle
5624be504a9SRiver Riddle // then erase the call.
5634be504a9SRiver Riddle call.erase();
5644be504a9SRiver Riddle
5654be504a9SRiver Riddle // If we inlined in place, mark the node for deletion.
5664be504a9SRiver Riddle if (inlineInPlace) {
5674be504a9SRiver Riddle useList.eraseNode(it.targetNode);
5682f672e2fSAlex Zinenko deadNodes.insert(it.targetNode);
569a20d96e4SRiver Riddle }
5704be504a9SRiver Riddle }
5714be504a9SRiver Riddle
572f4ef77cbSRiver Riddle for (CallGraphNode *node : deadNodes) {
573f4ef77cbSRiver Riddle currentSCC.remove(node);
574f4ef77cbSRiver Riddle inliner.markForDeletion(node);
575f4ef77cbSRiver Riddle }
576a20d96e4SRiver Riddle calls.clear();
5776b1cc3c6SRiver Riddle return success(inlinedAnyCalls);
5786b1cc3c6SRiver Riddle }
5796b1cc3c6SRiver Riddle
580a20d96e4SRiver Riddle //===----------------------------------------------------------------------===//
581a20d96e4SRiver Riddle // InlinerPass
582a20d96e4SRiver Riddle //===----------------------------------------------------------------------===//
583a20d96e4SRiver Riddle
5840ba00878SRiver Riddle namespace {
585d7eba200SRiver Riddle class InlinerPass : public InlinerBase<InlinerPass> {
586d7eba200SRiver Riddle public:
587d7eba200SRiver Riddle InlinerPass();
588d7eba200SRiver Riddle InlinerPass(const InlinerPass &) = default;
589d7eba200SRiver Riddle InlinerPass(std::function<void(OpPassManager &)> defaultPipeline);
590d7eba200SRiver Riddle InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
591d7eba200SRiver Riddle llvm::StringMap<OpPassManager> opPipelines);
592400ad6f9SRiver Riddle void runOnOperation() override;
593400ad6f9SRiver Riddle
594d7eba200SRiver Riddle private:
595d7eba200SRiver Riddle /// Attempt to inline calls within the given scc, and run simplifications,
596d7eba200SRiver Riddle /// until a fixed point is reached. This allows for the inlining of newly
597d7eba200SRiver Riddle /// devirtualized calls. Returns failure if there was a fatal error during
598d7eba200SRiver Riddle /// inlining.
599d7eba200SRiver Riddle LogicalResult inlineSCC(Inliner &inliner, CGUseList &useList,
600d7eba200SRiver Riddle CallGraphSCC ¤tSCC, MLIRContext *context);
601d7eba200SRiver Riddle
602d7eba200SRiver Riddle /// Optimize the nodes within the given SCC with one of the held optimization
603d7eba200SRiver Riddle /// pass pipelines. Returns failure if an error occurred during the
604d7eba200SRiver Riddle /// optimization of the SCC, success otherwise.
605d7eba200SRiver Riddle LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
606d7eba200SRiver Riddle CallGraphSCC ¤tSCC, MLIRContext *context);
607d7eba200SRiver Riddle
608d7eba200SRiver Riddle /// Optimize the nodes within the given SCC in parallel. Returns failure if an
609d7eba200SRiver Riddle /// error occurred during the optimization of the SCC, success otherwise.
610d7eba200SRiver Riddle LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
611d7eba200SRiver Riddle MLIRContext *context);
612d7eba200SRiver Riddle
613d7eba200SRiver Riddle /// Optimize the given callable node with one of the pass managers provided
614d7eba200SRiver Riddle /// with `pipelines`, or the default pipeline. Returns failure if an error
615d7eba200SRiver Riddle /// occurred during the optimization of the callable, success otherwise.
616d7eba200SRiver Riddle LogicalResult optimizeCallable(CallGraphNode *node,
617d7eba200SRiver Riddle llvm::StringMap<OpPassManager> &pipelines);
618d7eba200SRiver Riddle
619d7eba200SRiver Riddle /// Attempt to initialize the options of this pass from the given string.
620d7eba200SRiver Riddle /// Derived classes may override this method to hook into the point at which
621d7eba200SRiver Riddle /// options are initialized, but should generally always invoke this base
622d7eba200SRiver Riddle /// class variant.
623d7eba200SRiver Riddle LogicalResult initializeOptions(StringRef options) override;
624d7eba200SRiver Riddle
625d7eba200SRiver Riddle /// An optional function that constructs a default optimization pipeline for
626d7eba200SRiver Riddle /// a given operation.
627d7eba200SRiver Riddle std::function<void(OpPassManager &)> defaultPipeline;
628d7eba200SRiver Riddle /// A map of operation names to pass pipelines to use when optimizing
629d7eba200SRiver Riddle /// callable operations of these types. This provides a specialized pipeline
630d7eba200SRiver Riddle /// instead of the default. The vector size is the number of threads used
631d7eba200SRiver Riddle /// during optimization.
632d7eba200SRiver Riddle SmallVector<llvm::StringMap<OpPassManager>, 8> opPipelines;
633400ad6f9SRiver Riddle };
634be0a7e9fSMehdi Amini } // namespace
635400ad6f9SRiver Riddle
InlinerPass()636d7eba200SRiver Riddle InlinerPass::InlinerPass() : InlinerPass(defaultInlinerOptPipeline) {}
InlinerPass(std::function<void (OpPassManager &)> defaultPipeline)637d7eba200SRiver Riddle InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline)
6381fc096afSMehdi Amini : defaultPipeline(std::move(defaultPipeline)) {
639d7eba200SRiver Riddle opPipelines.push_back({});
640d7eba200SRiver Riddle
641d7eba200SRiver Riddle // Initialize the pass options with the provided arguments.
642d7eba200SRiver Riddle if (defaultPipeline) {
643d7eba200SRiver Riddle OpPassManager fakePM("__mlir_fake_pm_op");
644d7eba200SRiver Riddle defaultPipeline(fakePM);
645d7eba200SRiver Riddle llvm::raw_string_ostream strStream(defaultPipelineStr);
646d7eba200SRiver Riddle fakePM.printAsTextualPipeline(strStream);
647d7eba200SRiver Riddle }
648d7eba200SRiver Riddle }
649d7eba200SRiver Riddle
InlinerPass(std::function<void (OpPassManager &)> defaultPipeline,llvm::StringMap<OpPassManager> opPipelines)650d7eba200SRiver Riddle InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
651d7eba200SRiver Riddle llvm::StringMap<OpPassManager> opPipelines)
652d7eba200SRiver Riddle : InlinerPass(std::move(defaultPipeline)) {
653d7eba200SRiver Riddle if (opPipelines.empty())
654d7eba200SRiver Riddle return;
655d7eba200SRiver Riddle
656d7eba200SRiver Riddle // Update the option for the op specific optimization pipelines.
6570d8df980SRiver Riddle for (auto &it : opPipelines)
6580d8df980SRiver Riddle opPipelineList.addValue(it.second);
659d7eba200SRiver Riddle this->opPipelines.emplace_back(std::move(opPipelines));
660d7eba200SRiver Riddle }
661d7eba200SRiver Riddle
runOnOperation()662400ad6f9SRiver Riddle void InlinerPass::runOnOperation() {
663a20d96e4SRiver Riddle CallGraph &cg = getAnalysis<CallGraph>();
6646b1cc3c6SRiver Riddle auto *context = &getContext();
6656b1cc3c6SRiver Riddle
666c7748404SRiver Riddle // The inliner should only be run on operations that define a symbol table,
667c7748404SRiver Riddle // as the callgraph will need to resolve references.
668c7748404SRiver Riddle Operation *op = getOperation();
669c7748404SRiver Riddle if (!op->hasTrait<OpTrait::SymbolTable>()) {
670c7748404SRiver Riddle op->emitOpError() << " was scheduled to run under the inliner, but does "
671c7748404SRiver Riddle "not define a symbol table";
672c7748404SRiver Riddle return signalPassFailure();
673c7748404SRiver Riddle }
674c7748404SRiver Riddle
675a20d96e4SRiver Riddle // Run the inline transform in post-order over the SCCs in the callgraph.
676a5ea6045SRiver Riddle SymbolTableCollection symbolTable;
677a5ea6045SRiver Riddle Inliner inliner(context, cg, symbolTable);
678a5ea6045SRiver Riddle CGUseList useList(getOperation(), cg, symbolTable);
679d7eba200SRiver Riddle LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
680d7eba200SRiver Riddle return inlineSCC(inliner, useList, scc, context);
681a20d96e4SRiver Riddle });
682d7eba200SRiver Riddle if (failed(result))
683d7eba200SRiver Riddle return signalPassFailure();
684f4ef77cbSRiver Riddle
685f4ef77cbSRiver Riddle // After inlining, make sure to erase any callables proven to be dead.
686f4ef77cbSRiver Riddle inliner.eraseDeadCallables();
6870ba00878SRiver Riddle }
688400ad6f9SRiver Riddle
inlineSCC(Inliner & inliner,CGUseList & useList,CallGraphSCC & currentSCC,MLIRContext * context)689d7eba200SRiver Riddle LogicalResult InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
690d7eba200SRiver Riddle CallGraphSCC ¤tSCC,
691d7eba200SRiver Riddle MLIRContext *context) {
692d7eba200SRiver Riddle // Continuously simplify and inline until we either reach a fixed point, or
693d7eba200SRiver Riddle // hit the maximum iteration count. Simplifying early helps to refine the cost
694d7eba200SRiver Riddle // model, and in future iterations may devirtualize new calls.
695400ad6f9SRiver Riddle unsigned iterationCount = 0;
696d7eba200SRiver Riddle do {
697d7eba200SRiver Riddle if (failed(optimizeSCC(inliner.cg, useList, currentSCC, context)))
698d7eba200SRiver Riddle return failure();
699d7eba200SRiver Riddle if (failed(inlineCallsInSCC(inliner, useList, currentSCC)))
700400ad6f9SRiver Riddle break;
701d7eba200SRiver Riddle } while (++iterationCount < maxInliningIterations);
702d7eba200SRiver Riddle return success();
703400ad6f9SRiver Riddle }
704d7eba200SRiver Riddle
optimizeSCC(CallGraph & cg,CGUseList & useList,CallGraphSCC & currentSCC,MLIRContext * context)705d7eba200SRiver Riddle LogicalResult InlinerPass::optimizeSCC(CallGraph &cg, CGUseList &useList,
706d7eba200SRiver Riddle CallGraphSCC ¤tSCC,
707d7eba200SRiver Riddle MLIRContext *context) {
708d7eba200SRiver Riddle // Collect the sets of nodes to simplify.
709d7eba200SRiver Riddle SmallVector<CallGraphNode *, 4> nodesToVisit;
710d7eba200SRiver Riddle for (auto *node : currentSCC) {
711d7eba200SRiver Riddle if (node->isExternal())
712d7eba200SRiver Riddle continue;
713d7eba200SRiver Riddle
714d7eba200SRiver Riddle // Don't simplify nodes with children. Nodes with children require special
715d7eba200SRiver Riddle // handling as we may remove the node during simplification. In the future,
716d7eba200SRiver Riddle // we should be able to handle this case with proper node deletion tracking.
717d7eba200SRiver Riddle if (node->hasChildren())
718d7eba200SRiver Riddle continue;
719d7eba200SRiver Riddle
720d7eba200SRiver Riddle // We also won't apply simplifications to nodes that can't have passes
721d7eba200SRiver Riddle // scheduled on them.
722d7eba200SRiver Riddle auto *region = node->getCallableRegion();
723fe7c0d90SRiver Riddle if (!region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
724d7eba200SRiver Riddle continue;
725d7eba200SRiver Riddle nodesToVisit.push_back(node);
726d7eba200SRiver Riddle }
727d7eba200SRiver Riddle if (nodesToVisit.empty())
728d7eba200SRiver Riddle return success();
729d7eba200SRiver Riddle
730d7eba200SRiver Riddle // Optimize each of the nodes within the SCC in parallel.
731d7eba200SRiver Riddle if (failed(optimizeSCCAsync(nodesToVisit, context)))
732d7eba200SRiver Riddle return failure();
733d7eba200SRiver Riddle
734d7eba200SRiver Riddle // Recompute the uses held by each of the nodes.
735d7eba200SRiver Riddle for (CallGraphNode *node : nodesToVisit)
736d7eba200SRiver Riddle useList.recomputeUses(node, cg);
737d7eba200SRiver Riddle return success();
738d7eba200SRiver Riddle }
739d7eba200SRiver Riddle
740d7eba200SRiver Riddle LogicalResult
optimizeSCCAsync(MutableArrayRef<CallGraphNode * > nodesToVisit,MLIRContext * ctx)741d7eba200SRiver Riddle InlinerPass::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
7426569cf2aSRiver Riddle MLIRContext *ctx) {
7438ff42766SStella Laurenzo // We must maintain a fixed pool of pass managers which is at least as large
7448ff42766SStella Laurenzo // as the maximum parallelism of the failableParallelForEach below.
7458ff42766SStella Laurenzo // Note: The number of pass managers here needs to remain constant
74616a50c9eSRiver Riddle // to prevent issues with pass instrumentations that rely on having the same
74716a50c9eSRiver Riddle // pass manager for the main thread.
74841a64338SMogball size_t numThreads = ctx->getNumThreads();
749abd3c6f2SRiver Riddle if (opPipelines.size() < numThreads) {
750d7eba200SRiver Riddle // Reserve before resizing so that we can use a reference to the first
751d7eba200SRiver Riddle // element.
752d7eba200SRiver Riddle opPipelines.reserve(numThreads);
753d7eba200SRiver Riddle opPipelines.resize(numThreads, opPipelines.front());
754d7eba200SRiver Riddle }
755d7eba200SRiver Riddle
756d7eba200SRiver Riddle // Ensure an analysis manager has been constructed for each of the nodes.
757d7eba200SRiver Riddle // This prevents thread races when running the nested pipelines.
758d7eba200SRiver Riddle for (CallGraphNode *node : nodesToVisit)
759d7eba200SRiver Riddle getAnalysisManager().nest(node->getCallableRegion()->getParentOp());
760d7eba200SRiver Riddle
7616569cf2aSRiver Riddle // An atomic failure variable for the async executors.
7626569cf2aSRiver Riddle std::vector<std::atomic<bool>> activePMs(opPipelines.size());
7636569cf2aSRiver Riddle std::fill(activePMs.begin(), activePMs.end(), false);
7646569cf2aSRiver Riddle return failableParallelForEach(ctx, nodesToVisit, [&](CallGraphNode *node) {
7656569cf2aSRiver Riddle // Find a pass manager for this operation.
7666569cf2aSRiver Riddle auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
7676569cf2aSRiver Riddle bool expectedInactive = false;
7686569cf2aSRiver Riddle return isActive.compare_exchange_strong(expectedInactive, true);
769d7eba200SRiver Riddle });
7708ff42766SStella Laurenzo assert(it != activePMs.end() &&
7718ff42766SStella Laurenzo "could not find inactive pass manager for thread");
7726569cf2aSRiver Riddle unsigned pmIndex = it - activePMs.begin();
7736569cf2aSRiver Riddle
7746569cf2aSRiver Riddle // Optimize this callable node.
7756569cf2aSRiver Riddle LogicalResult result = optimizeCallable(node, opPipelines[pmIndex]);
7766569cf2aSRiver Riddle
7776569cf2aSRiver Riddle // Reset the active bit for this pass manager.
7786569cf2aSRiver Riddle activePMs[pmIndex].store(false);
7796569cf2aSRiver Riddle return result;
7806569cf2aSRiver Riddle });
781d7eba200SRiver Riddle }
782d7eba200SRiver Riddle
783d7eba200SRiver Riddle LogicalResult
optimizeCallable(CallGraphNode * node,llvm::StringMap<OpPassManager> & pipelines)784d7eba200SRiver Riddle InlinerPass::optimizeCallable(CallGraphNode *node,
785d7eba200SRiver Riddle llvm::StringMap<OpPassManager> &pipelines) {
786d7eba200SRiver Riddle Operation *callable = node->getCallableRegion()->getParentOp();
787d7eba200SRiver Riddle StringRef opName = callable->getName().getStringRef();
788d7eba200SRiver Riddle auto pipelineIt = pipelines.find(opName);
789d7eba200SRiver Riddle if (pipelineIt == pipelines.end()) {
790d7eba200SRiver Riddle // If a pipeline didn't exist, use the default if possible.
791d7eba200SRiver Riddle if (!defaultPipeline)
792d7eba200SRiver Riddle return success();
793d7eba200SRiver Riddle
794d7eba200SRiver Riddle OpPassManager defaultPM(opName);
795d7eba200SRiver Riddle defaultPipeline(defaultPM);
796d7eba200SRiver Riddle pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
797d7eba200SRiver Riddle }
798d7eba200SRiver Riddle return runPipeline(pipelineIt->second, callable);
799d7eba200SRiver Riddle }
800d7eba200SRiver Riddle
initializeOptions(StringRef options)801d7eba200SRiver Riddle LogicalResult InlinerPass::initializeOptions(StringRef options) {
802d7eba200SRiver Riddle if (failed(Pass::initializeOptions(options)))
803d7eba200SRiver Riddle return failure();
804d7eba200SRiver Riddle
805d7eba200SRiver Riddle // Initialize the default pipeline builder to use the option string.
806c2fb9c29SRiver Riddle // TODO: Use a generic pass manager for default pipelines, and remove this.
807d7eba200SRiver Riddle if (!defaultPipelineStr.empty()) {
808d7eba200SRiver Riddle std::string defaultPipelineCopy = defaultPipelineStr;
809d7eba200SRiver Riddle defaultPipeline = [=](OpPassManager &pm) {
810e21adfa3SRiver Riddle (void)parsePassPipeline(defaultPipelineCopy, pm);
811d7eba200SRiver Riddle };
812d7eba200SRiver Riddle } else if (defaultPipelineStr.getNumOccurrences()) {
813d7eba200SRiver Riddle defaultPipeline = nullptr;
814d7eba200SRiver Riddle }
815d7eba200SRiver Riddle
816d7eba200SRiver Riddle // Initialize the op specific pass pipelines.
817d7eba200SRiver Riddle llvm::StringMap<OpPassManager> pipelines;
8180d8df980SRiver Riddle for (OpPassManager pipeline : opPipelineList)
8190d8df980SRiver Riddle if (!pipeline.empty())
820c2fb9c29SRiver Riddle pipelines.try_emplace(pipeline.getOpAnchorName(), pipeline);
821d7eba200SRiver Riddle opPipelines.assign({std::move(pipelines)});
822d7eba200SRiver Riddle
823d7eba200SRiver Riddle return success();
824400ad6f9SRiver Riddle }
8250ba00878SRiver Riddle
createInlinerPass()8263940b90dSSana Damani std::unique_ptr<Pass> mlir::createInlinerPass() {
8273940b90dSSana Damani return std::make_unique<InlinerPass>();
8283940b90dSSana Damani }
829d7eba200SRiver Riddle std::unique_ptr<Pass>
createInlinerPass(llvm::StringMap<OpPassManager> opPipelines)830d7eba200SRiver Riddle mlir::createInlinerPass(llvm::StringMap<OpPassManager> opPipelines) {
831d7eba200SRiver Riddle return std::make_unique<InlinerPass>(defaultInlinerOptPipeline,
832d7eba200SRiver Riddle std::move(opPipelines));
833d7eba200SRiver Riddle }
createInlinerPass(llvm::StringMap<OpPassManager> opPipelines,std::function<void (OpPassManager &)> defaultPipelineBuilder)83457bf8560SValentin Clement std::unique_ptr<Pass> mlir::createInlinerPass(
83557bf8560SValentin Clement llvm::StringMap<OpPassManager> opPipelines,
836d7eba200SRiver Riddle std::function<void(OpPassManager &)> defaultPipelineBuilder) {
837d7eba200SRiver Riddle return std::make_unique<InlinerPass>(std::move(defaultPipelineBuilder),
838d7eba200SRiver Riddle std::move(opPipelines));
839d7eba200SRiver Riddle }
840