//===- Inliner.cpp - Pass to inline function calls ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a basic inlining algorithm that operates bottom up over
// the Strongly Connected Components (SCCs) of the CallGraph. This enables a
// more incremental propagation of inlining decisions from the leaves to the
// roots of the callgraph.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/Threading.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/DebugStringHelper.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "inlining"

using namespace mlir;

/// This function implements the default inliner optimization pipeline.
static void defaultInlinerOptPipeline(OpPassManager &pm) {
  pm.addPass(createCanonicalizerPass());
}
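
// Clients can supply a different pipeline when constructing the pass. A
// minimal sketch (the exact pass list is illustrative, not prescriptive):
//
//   createInlinerPass(/*opPipelines=*/{}, [](OpPassManager &pm) {
//     pm.addPass(createCanonicalizerPass());
//     pm.addPass(createCSEPass());
//   });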

//===----------------------------------------------------------------------===//
// Symbol Use Tracking
//===----------------------------------------------------------------------===//

/// Walk all of the symbol callgraph nodes referenced by the given operation.
static void walkReferencedSymbolNodes(
    Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
    DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
    function_ref<void(CallGraphNode *, Operation *)> callback) {
  auto symbolUses = SymbolTable::getSymbolUses(op);
  assert(symbolUses && "expected uses to be valid");

  Operation *symbolTableOp = op->getParentOp();
  for (const SymbolTable::SymbolUse &use : *symbolUses) {
    auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
    CallGraphNode *&node = refIt.first->second;

    // If this is the first instance of this reference, try to resolve a
    // callgraph node for it.
    if (refIt.second) {
      auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
                                                           use.getSymbolRef());
      auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
      if (!callableOp)
        continue;
      node = cg.lookupNode(callableOp.getCallableRegion());
    }
    if (node)
      callback(node, use.getUser());
  }
}

//===----------------------------------------------------------------------===//
// CGUseList
//===----------------------------------------------------------------------===//

namespace {
/// This struct tracks the uses of callgraph nodes that can be dropped when
/// they are use_empty. It directly tracks and manages a use-list for all of
/// the call-graph nodes. This is necessary because many callgraph nodes are
/// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
/// class.
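///
/// For example, a private callee reached only through symbol references such
/// as `call @foo() : () -> ()` has no SSA users; its use count must be
/// maintained by hand so the callee can be erased once the last call to it
/// has been inlined away.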
struct CGUseList {
  /// This struct tracks the uses of callgraph nodes within a specific
  /// operation.
  struct CGUser {
    /// Any nodes referenced in the top-level attribute list of this user. We
    /// use a set here because the number of references does not matter.
    DenseSet<CallGraphNode *> topLevelUses;

    /// Uses of nodes referenced by nested operations.
    DenseMap<CallGraphNode *, int> innerUses;
  };

  CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);

  /// Drop uses of nodes referred to by the given call operation that resides
  /// within 'userNode'.
  void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);

  /// Remove the given node from the use list.
  void eraseNode(CallGraphNode *node);

  /// Returns true if the given callgraph node has no uses and can be pruned.
  bool isDead(CallGraphNode *node) const;

  /// Returns true if the given callgraph node has a single use and can be
  /// discarded.
  bool hasOneUseAndDiscardable(CallGraphNode *node) const;

  /// Recompute the uses held by the given callgraph node.
  void recomputeUses(CallGraphNode *node, CallGraph &cg);

  /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy
  /// of 'lhs' into 'rhs'.
  void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);

private:
  /// Decrement the uses of discardable nodes referenced by the given user.
  void decrementDiscardableUses(CGUser &uses);

  /// A mapping between a discardable callgraph node (that is a symbol) and the
  /// number of uses for this node.
  DenseMap<CallGraphNode *, int> discardableSymNodeUses;

  /// A mapping between a callgraph node and the symbol callgraph nodes that it
  /// uses.
  DenseMap<CallGraphNode *, CGUser> nodeUses;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // namespace

CGUseList::CGUseList(Operation *op, CallGraph &cg,
                     SymbolTableCollection &symbolTable)
    : symbolTable(symbolTable) {
  /// A set of callgraph nodes that are always known to be live during
  /// inlining.
  DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;

  // Walk each of the symbol tables looking for discardable callgraph nodes.
  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
    for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
      // If this is a callgraph operation, check to see if it is discardable.
      if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
        if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
          SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
          if (symbol && (allUsesVisible || symbol.isPrivate()) &&
              symbol.canDiscardOnUseEmpty()) {
            discardableSymNodeUses.try_emplace(node, 0);
          }
          continue;
        }
      }
      // Otherwise, check for any referenced nodes. These will be always-live.
      walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
                                [](CallGraphNode *, Operation *) {});
    }
  };
  SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
                                walkFn);

  // Drop the use information for any discardable nodes that are always live.
  for (auto &it : alwaysLiveNodes)
    discardableSymNodeUses.erase(it.second);

  // Compute the uses for each of the callable nodes in the graph.
  for (CallGraphNode *node : cg)
    recomputeUses(node, cg);
}

void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
                             CallGraph &cg) {
  auto &userRefs = nodeUses[userNode].innerUses;
  auto walkFn = [&](CallGraphNode *node, Operation *user) {
    auto parentIt = userRefs.find(node);
    if (parentIt == userRefs.end())
      return;
    --parentIt->second;
    --discardableSymNodeUses[node];
  };
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
}

void CGUseList::eraseNode(CallGraphNode *node) {
  // Drop all child nodes.
  for (auto &edge : *node)
    if (edge.isChild())
      eraseNode(edge.getTarget());

  // Drop the uses held by this node and erase it.
  auto useIt = nodeUses.find(node);
  assert(useIt != nodeUses.end() && "expected node to be valid");
  decrementDiscardableUses(useIt->getSecond());
  nodeUses.erase(useIt);
  discardableSymNodeUses.erase(node);
}

bool CGUseList::isDead(CallGraphNode *node) const {
  // If the parent operation isn't a symbol, simply check normal SSA deadness.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!isa<SymbolOpInterface>(nodeOp))
    return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->use_empty();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
}

bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
  // If this isn't a symbol node, check for side-effects and SSA use count.
  Operation *nodeOp = node->getCallableRegion()->getParentOp();
  if (!isa<SymbolOpInterface>(nodeOp))
    return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->hasOneUse();

  // Otherwise, check the number of symbol uses.
  auto symbolIt = discardableSymNodeUses.find(node);
  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
}

void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
  Operation *parentOp = node->getCallableRegion()->getParentOp();
  CGUser &uses = nodeUses[node];
  decrementDiscardableUses(uses);

  // Collect the new discardable uses within this node.
  uses = CGUser();
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
    auto discardSymIt = discardableSymNodeUses.find(refNode);
    if (discardSymIt == discardableSymNodeUses.end())
      return;

    if (user != parentOp)
      ++uses.innerUses[refNode];
    else if (!uses.topLevelUses.insert(refNode).second)
      return;
    ++discardSymIt->second;
  };
  walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
}

void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
  auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
  for (auto &useIt : lhsUses.innerUses) {
    rhsUses.innerUses[useIt.first] += useIt.second;
    discardableSymNodeUses[useIt.first] += useIt.second;
  }
}

void CGUseList::decrementDiscardableUses(CGUser &uses) {
  for (CallGraphNode *node : uses.topLevelUses)
    --discardableSymNodeUses[node];
  for (auto &it : uses.innerUses)
    discardableSymNodeUses[it.first] -= it.second;
}

//===----------------------------------------------------------------------===//
// CallGraph traversal
//===----------------------------------------------------------------------===//

namespace {
/// This class represents a specific callgraph SCC.
class CallGraphSCC {
public:
  CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
      : parentIterator(parentIterator) {}
  /// Return a range over the nodes within this SCC.
  std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
  std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }

  /// Reset the nodes of this SCC with those provided.
  void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }

  /// Remove the given node from this SCC.
  void remove(CallGraphNode *node) {
    auto it = llvm::find(nodes, node);
    if (it != nodes.end()) {
      nodes.erase(it);
      parentIterator.ReplaceNode(node, nullptr);
    }
  }

private:
  std::vector<CallGraphNode *> nodes;
  llvm::scc_iterator<const CallGraph *> &parentIterator;
};
} // namespace

/// Run a given transformation over the SCCs of the callgraph in a bottom up
/// traversal.
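/// Note that `llvm::scc_begin` yields SCCs in post-order: callees are visited
/// before their callers, so inlining decisions propagate from the leaves of
/// the callgraph toward its roots.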
static LogicalResult runTransformOnCGSCCs(
    const CallGraph &cg,
    function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
  llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
  CallGraphSCC currentSCC(cgi);
  while (!cgi.isAtEnd()) {
    // Copy the current SCC and increment so that the transformer can modify
    // the SCC without invalidating our iterator.
    currentSCC.reset(*cgi);
    ++cgi;
    if (failed(sccTransformer(currentSCC)))
      return failure();
  }
  return success();
}

namespace {
/// This struct represents a resolved call to a given callgraph node. Given
/// that the call does not actually contain a direct reference to the
/// Region (CallGraphNode) that it is dispatching to, the callee needs to be
/// resolved explicitly.
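///
/// For example, a call through a flat symbol reference such as `@callee` is
/// resolved through the symbol table to the region that defines the callee,
/// and that region's callgraph node is recorded as `targetNode` below.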
struct ResolvedCall {
  ResolvedCall(CallOpInterface call, CallGraphNode *sourceNode,
               CallGraphNode *targetNode)
      : call(call), sourceNode(sourceNode), targetNode(targetNode) {}
  CallOpInterface call;
  CallGraphNode *sourceNode, *targetNode;
};
} // namespace

/// Collect all of the call operations within the given range of blocks. If
/// `traverseNestedCGNodes` is true, this will also collect call operations
/// inside of nested callgraph nodes.
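///
/// The traversal below uses an explicit worklist of (block, enclosing node)
/// pairs instead of recursion, so that each discovered call is attributed to
/// the nearest enclosing callgraph node.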
static void collectCallOps(iterator_range<Region::iterator> blocks,
                           CallGraphNode *sourceNode, CallGraph &cg,
                           SymbolTableCollection &symbolTable,
                           SmallVectorImpl<ResolvedCall> &calls,
                           bool traverseNestedCGNodes) {
  SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
  auto addToWorklist = [&](CallGraphNode *node,
                           iterator_range<Region::iterator> blocks) {
    for (Block &block : blocks)
      worklist.emplace_back(&block, node);
  };

  addToWorklist(sourceNode, blocks);
  while (!worklist.empty()) {
    Block *block;
    std::tie(block, sourceNode) = worklist.pop_back_val();

    for (Operation &op : *block) {
      if (auto call = dyn_cast<CallOpInterface>(op)) {
        // TODO: Support inlining nested call references.
        CallInterfaceCallable callable = call.getCallableForCallee();
        if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
          if (!symRef.isa<FlatSymbolRefAttr>())
            continue;
        }

        CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
        if (!targetNode->isExternal())
          calls.emplace_back(call, sourceNode, targetNode);
        continue;
      }

      // If this is not a call, traverse the nested regions. If
      // `traverseNestedCGNodes` is false, regions that form their own
      // callgraph nodes are skipped here, as they are handled separately.
      for (auto &nestedRegion : op.getRegions()) {
        CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
        if (traverseNestedCGNodes || !nestedNode)
          addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// Inliner
//===----------------------------------------------------------------------===//

#ifndef NDEBUG
static std::string getNodeName(CallOpInterface op) {
  if (auto sym = op.getCallableForCallee().dyn_cast<SymbolRefAttr>())
    return debugString(op);
  return "_unnamed_callee_";
}
#endif

/// Return true if the specified `inlineHistoryID` indicates an inline history
/// that already includes `node`.
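///
/// The history entries form chains of (node, parent ID) pairs that terminate
/// in an empty Optional (the "root"). If `node` already appears somewhere
/// along the chain for a call, that call was (transitively) produced by
/// inlining `node`, and inlining it again would never terminate.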
static bool inlineHistoryIncludes(
    CallGraphNode *node, Optional<size_t> inlineHistoryID,
    MutableArrayRef<std::pair<CallGraphNode *, Optional<size_t>>>
        inlineHistory) {
  while (inlineHistoryID.has_value()) {
    assert(inlineHistoryID.value() < inlineHistory.size() &&
           "Invalid inline history ID");
    if (inlineHistory[inlineHistoryID.value()].first == node)
      return true;
    inlineHistoryID = inlineHistory[inlineHistoryID.value()].second;
  }
  return false;
}

namespace {
/// This class provides a specialization of the main inlining interface.
struct Inliner : public InlinerInterface {
  Inliner(MLIRContext *context, CallGraph &cg,
          SymbolTableCollection &symbolTable)
      : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}

  /// Process a set of blocks that have been inlined. This callback is invoked
  /// *before* inlined terminator operations have been processed.
  void
  processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
    // Find the closest callgraph node from the first block.
    CallGraphNode *node;
    Region *region = inlinedBlocks.begin()->getParent();
    while (!(node = cg.lookupNode(region))) {
      region = region->getParentRegion();
      assert(region && "expected valid parent node");
    }

    collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
                   /*traverseNestedCGNodes=*/true);
  }

  /// Mark the given callgraph node for deletion.
  void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }

  /// This method properly disposes of callables that became dead during
  /// inlining. This should not be called while iterating over the SCCs.
  void eraseDeadCallables() {
    for (CallGraphNode *node : deadNodes)
      node->getCallableRegion()->getParentOp()->erase();
  }

  /// The set of callables known to be dead.
  SmallPtrSet<CallGraphNode *, 8> deadNodes;

  /// The current set of call instructions to consider for inlining.
  SmallVector<ResolvedCall, 8> calls;

  /// The callgraph being operated on.
  CallGraph &cg;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // namespace

/// Returns true if the given call should be inlined.
static bool shouldInline(ResolvedCall &resolvedCall) {
  // Don't allow inlining terminator calls. We currently don't support this
  // case.
  if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>())
    return false;

  // Don't allow inlining if the target is an ancestor of the call. This
  // prevents inlining recursive calls into themselves.
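  // For example, a directly recursive function contains its own call sites,
  // so inlining them would expand the function into itself indefinitely.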
  if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
          resolvedCall.call->getParentRegion()))
    return false;

  // Otherwise, inline.
  return true;
}

/// Attempt to inline calls within the given scc. This function returns
/// success if any calls were inlined, failure otherwise.
static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
                                      CallGraphSCC &currentSCC) {
  CallGraph &cg = inliner.cg;
  auto &calls = inliner.calls;

  // A set of dead nodes to remove after inlining.
  llvm::SmallSetVector<CallGraphNode *, 1> deadNodes;

  // Collect all of the direct calls within the nodes of the current SCC. We
  // don't traverse nested callgraph nodes, because they are handled
  // separately, likely within a different SCC.
  for (CallGraphNode *node : currentSCC) {
    if (node->isExternal())
      continue;

    // Don't collect calls if the node is already dead.
    if (useList.isDead(node)) {
      deadNodes.insert(node);
    } else {
      collectCallOps(*node->getCallableRegion(), node, cg, inliner.symbolTable,
                     calls, /*traverseNestedCGNodes=*/false);
    }
  }

  // When inlining a callee produces new call sites, we want to keep track of
  // the fact that they were inlined from the callee. This allows us to avoid
  // infinite inlining.
  using InlineHistoryT = Optional<size_t>;
  SmallVector<std::pair<CallGraphNode *, InlineHistoryT>, 8> inlineHistory;
  std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
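  // Note: `callHistory` runs parallel to `calls`; entry i records which
  // inlining step, if any, introduced calls[i]. Calls collected directly from
  // the SCC above start with an empty (root) history.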

  LLVM_DEBUG({
    llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
    for (unsigned i = 0, e = calls.size(); i < e; ++i)
      llvm::dbgs() << "  " << i << ". " << calls[i].call << ",\n";
    llvm::dbgs() << "}\n";
  });

  // Try to inline each of the call operations. Don't cache the end iterator
  // here as more calls may be added during inlining.
  bool inlinedAnyCalls = false;
  for (unsigned i = 0; i < calls.size(); ++i) {
    if (deadNodes.contains(calls[i].sourceNode))
      continue;
    ResolvedCall it = calls[i];

    InlineHistoryT inlineHistoryID = callHistory[i];
    bool inHistory =
        inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory);
    bool doInline = !inHistory && shouldInline(it);
    CallOpInterface call = it.call;
    LLVM_DEBUG({
      if (doInline)
        llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
      else
        llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
    });
    if (!doInline)
      continue;

    unsigned prevSize = calls.size();
    Region *targetRegion = it.targetNode->getCallableRegion();

    // If this is the last call to the target node and the node is discardable,
    // then inline it in-place and delete the node if successful.
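    // In-place inlining moves the callee region into the call site rather
    // than cloning it; this is only sound because no other uses of the
    // callee remain.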
    bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);

    LogicalResult inlineResult = inlineCall(
        inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
        targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
    if (failed(inlineResult)) {
      LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
      continue;
    }
    inlinedAnyCalls = true;

    // Create an inline history entry for this inlined call, so that we
    // remember that new callsites came about due to inlining the callee.
    InlineHistoryT newInlineHistoryID{inlineHistory.size()};
    inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID));

    auto historyToString = [](InlineHistoryT h) {
      return h.has_value() ? std::to_string(h.value()) : "root";
    };
    (void)historyToString;
    LLVM_DEBUG(llvm::dbgs()
               << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
               << getNodeName(call) << ", " << historyToString(inlineHistoryID)
               << "]\n");

    for (unsigned k = prevSize; k != calls.size(); ++k) {
      callHistory.push_back(newInlineHistoryID);
      LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[k].call
                              << "}\n with historyID = " << newInlineHistoryID
                              << ", added due to inlining of\n call {" << call
                              << "}\n with historyID = "
                              << historyToString(inlineHistoryID) << "\n");
    }

    // If the inlining was successful, merge the new uses into the source
    // node.
    useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
    useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);

    // Then erase the call.
    call.erase();

    // If we inlined in place, mark the node for deletion.
    if (inlineInPlace) {
      useList.eraseNode(it.targetNode);
      deadNodes.insert(it.targetNode);
    }
  }

  for (CallGraphNode *node : deadNodes) {
    currentSCC.remove(node);
    inliner.markForDeletion(node);
  }
  calls.clear();
  return success(inlinedAnyCalls);
}

//===----------------------------------------------------------------------===//
// InlinerPass
//===----------------------------------------------------------------------===//

namespace {
class InlinerPass : public InlinerBase<InlinerPass> {
public:
  InlinerPass();
  InlinerPass(const InlinerPass &) = default;
  InlinerPass(std::function<void(OpPassManager &)> defaultPipeline);
  InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
              llvm::StringMap<OpPassManager> opPipelines);
  void runOnOperation() override;

private:
  /// Attempt to inline calls within the given scc, and run simplifications,
  /// until a fixed point is reached. This allows for the inlining of newly
  /// devirtualized calls. Returns failure if there was a fatal error during
  /// inlining.
  LogicalResult inlineSCC(Inliner &inliner, CGUseList &useList,
                          CallGraphSCC &currentSCC, MLIRContext *context);

  /// Optimize the nodes within the given SCC with one of the held optimization
  /// pass pipelines. Returns failure if an error occurred during the
  /// optimization of the SCC, success otherwise.
  LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
                            CallGraphSCC &currentSCC, MLIRContext *context);

  /// Optimize the nodes within the given SCC in parallel. Returns failure if
  /// an error occurred during the optimization of the SCC, success otherwise.
  LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
                                 MLIRContext *context);

  /// Optimize the given callable node with one of the pass managers provided
  /// with `pipelines`, or the default pipeline. Returns failure if an error
  /// occurred during the optimization of the callable, success otherwise.
  LogicalResult optimizeCallable(CallGraphNode *node,
                                 llvm::StringMap<OpPassManager> &pipelines);

  /// Attempt to initialize the options of this pass from the given string.
  /// Derived classes may override this method to hook into the point at which
  /// options are initialized, but should generally always invoke this base
  /// class variant.
  LogicalResult initializeOptions(StringRef options) override;

  /// An optional function that constructs a default optimization pipeline for
  /// a given operation.
  std::function<void(OpPassManager &)> defaultPipeline;
  /// A map of operation names to pass pipelines to use when optimizing
  /// callable operations of these types. This provides a specialized pipeline
  /// instead of the default. The vector size is the number of threads used
  /// during optimization.
  SmallVector<llvm::StringMap<OpPassManager>, 8> opPipelines;
};
} // namespace

InlinerPass::InlinerPass() : InlinerPass(defaultInlinerOptPipeline) {}
InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline)
    : defaultPipeline(std::move(defaultPipeline)) {
  opPipelines.push_back({});

  // Initialize the pass options with the provided arguments. Note that the
  // parameter was moved from above, so the member must be consulted here.
  if (this->defaultPipeline) {
    OpPassManager fakePM("__mlir_fake_pm_op");
    this->defaultPipeline(fakePM);
    llvm::raw_string_ostream strStream(defaultPipelineStr);
    fakePM.printAsTextualPipeline(strStream);
  }
}

InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
                         llvm::StringMap<OpPassManager> opPipelines)
    : InlinerPass(std::move(defaultPipeline)) {
  if (opPipelines.empty())
    return;

  // Update the option for the op specific optimization pipelines.
  for (auto &it : opPipelines)
    opPipelineList.addValue(it.second);
  this->opPipelines.emplace_back(std::move(opPipelines));
}

void InlinerPass::runOnOperation() {
  CallGraph &cg = getAnalysis<CallGraph>();
  auto *context = &getContext();

  // The inliner should only be run on operations that define a symbol table,
  // as the callgraph will need to resolve references.
  Operation *op = getOperation();
  if (!op->hasTrait<OpTrait::SymbolTable>()) {
    op->emitOpError() << " was scheduled to run under the inliner, but does "
                         "not define a symbol table";
    return signalPassFailure();
  }

  // Run the inline transform in post-order over the SCCs in the callgraph.
  SymbolTableCollection symbolTable;
  Inliner inliner(context, cg, symbolTable);
  CGUseList useList(getOperation(), cg, symbolTable);
  LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
    return inlineSCC(inliner, useList, scc, context);
  });
  if (failed(result))
    return signalPassFailure();

  // After inlining, make sure to erase any callables proven to be dead.
  inliner.eraseDeadCallables();
}

LogicalResult InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
                                     CallGraphSCC &currentSCC,
                                     MLIRContext *context) {
  // Continuously simplify and inline until we either reach a fixed point, or
  // hit the maximum iteration count. Simplifying early helps to refine the
  // cost model, and in future iterations may devirtualize new calls.
  unsigned iterationCount = 0;
  do {
    if (failed(optimizeSCC(inliner.cg, useList, currentSCC, context)))
      return failure();
    if (failed(inlineCallsInSCC(inliner, useList, currentSCC)))
      break;
  } while (++iterationCount < maxInliningIterations);
  return success();
}

LogicalResult InlinerPass::optimizeSCC(CallGraph &cg, CGUseList &useList,
                                       CallGraphSCC &currentSCC,
                                       MLIRContext *context) {
  // Collect the sets of nodes to simplify.
  SmallVector<CallGraphNode *, 4> nodesToVisit;
  for (auto *node : currentSCC) {
    if (node->isExternal())
      continue;

    // Don't simplify nodes with children. Nodes with children require special
    // handling, as we may remove the node during simplification. In the
    // future, we should be able to handle this case with proper node deletion
    // tracking.
    if (node->hasChildren())
      continue;

    // We also won't apply simplifications to nodes that can't have passes
    // scheduled on them.
    auto *region = node->getCallableRegion();
    if (!region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
      continue;
    nodesToVisit.push_back(node);
  }
  if (nodesToVisit.empty())
    return success();

  // Optimize each of the nodes within the SCC in parallel.
  if (failed(optimizeSCCAsync(nodesToVisit, context)))
    return failure();

  // Recompute the uses held by each of the nodes.
  for (CallGraphNode *node : nodesToVisit)
    useList.recomputeUses(node, cg);
  return success();
}

LogicalResult
InlinerPass::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
                              MLIRContext *ctx) {
  // We must maintain a fixed pool of pass managers which is at least as large
  // as the maximum parallelism of the failableParallelForEach below.
  // Note: The number of pass managers here needs to remain constant
  // to prevent issues with pass instrumentations that rely on having the same
  // pass manager for the main thread.
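  // Each async worker below claims exclusive use of one pass manager from the
  // pool with a compare-and-swap on `activePMs`, and releases it once its
  // node has been optimized.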
  size_t numThreads = ctx->getNumThreads();
  if (opPipelines.size() < numThreads) {
    // Reserve before resizing so that we can use a reference to the first
    // element.
    opPipelines.reserve(numThreads);
    opPipelines.resize(numThreads, opPipelines.front());
  }

  // Ensure an analysis manager has been constructed for each of the nodes.
  // This prevents thread races when running the nested pipelines.
  for (CallGraphNode *node : nodesToVisit)
    getAnalysisManager().nest(node->getCallableRegion()->getParentOp());

  // A pool of "active" flags, one per pass manager, used by the async
  // executors to claim a pass manager for exclusive use.
  std::vector<std::atomic<bool>> activePMs(opPipelines.size());
  std::fill(activePMs.begin(), activePMs.end(), false);
  return failableParallelForEach(ctx, nodesToVisit, [&](CallGraphNode *node) {
    // Find a pass manager for this operation.
    auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
      bool expectedInactive = false;
      return isActive.compare_exchange_strong(expectedInactive, true);
    });
    assert(it != activePMs.end() &&
           "could not find inactive pass manager for thread");
    unsigned pmIndex = it - activePMs.begin();

    // Optimize this callable node.
    LogicalResult result = optimizeCallable(node, opPipelines[pmIndex]);

    // Reset the active bit for this pass manager.
    activePMs[pmIndex].store(false);
    return result;
  });
}

LogicalResult
InlinerPass::optimizeCallable(CallGraphNode *node,
                              llvm::StringMap<OpPassManager> &pipelines) {
  Operation *callable = node->getCallableRegion()->getParentOp();
  StringRef opName = callable->getName().getStringRef();
  auto pipelineIt = pipelines.find(opName);
  if (pipelineIt == pipelines.end()) {
    // If a pipeline didn't exist, use the default if possible.
    if (!defaultPipeline)
      return success();

    OpPassManager defaultPM(opName);
    defaultPipeline(defaultPM);
    pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
  }
  return runPipeline(pipelineIt->second, callable);
}

LogicalResult InlinerPass::initializeOptions(StringRef options) {
  if (failed(Pass::initializeOptions(options)))
    return failure();

  // Initialize the default pipeline builder to use the option string.
  // TODO: Use a generic pass manager for default pipelines, and remove this.
  if (!defaultPipelineStr.empty()) {
    std::string defaultPipelineCopy = defaultPipelineStr;
    defaultPipeline = [=](OpPassManager &pm) {
      (void)parsePassPipeline(defaultPipelineCopy, pm);
    };
  } else if (defaultPipelineStr.getNumOccurrences()) {
    defaultPipeline = nullptr;
  }

  // Initialize the op specific pass pipelines.
  llvm::StringMap<OpPassManager> pipelines;
  for (OpPassManager pipeline : opPipelineList)
    if (!pipeline.empty())
      pipelines.try_emplace(pipeline.getOpAnchorName(), pipeline);
  opPipelines.assign({std::move(pipelines)});

  return success();
}

std::unique_ptr<Pass> mlir::createInlinerPass() {
  return std::make_unique<InlinerPass>();
}
std::unique_ptr<Pass>
mlir::createInlinerPass(llvm::StringMap<OpPassManager> opPipelines) {
  return std::make_unique<InlinerPass>(defaultInlinerOptPipeline,
                                       std::move(opPipelines));
}
std::unique_ptr<Pass> mlir::createInlinerPass(
    llvm::StringMap<OpPassManager> opPipelines,
    std::function<void(OpPassManager &)> defaultPipelineBuilder) {
  return std::make_unique<InlinerPass>(std::move(defaultPipelineBuilder),
                                       std::move(opPipelines));
}
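
// Example usage (a sketch; the surrounding setup is assumed, not shown):
//
//   PassManager pm(context);
//   pm.addPass(createInlinerPass());
//   if (failed(pm.run(module)))
//     /* handle the error */;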