1 //===- Inliner.cpp - Pass to inline function calls ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a basic inlining algorithm that operates bottom up over
// the Strongly Connected Components (SCCs) of the CallGraph. This enables a
// more incremental propagation of inlining decisions from the leaves to the
// roots of the callgraph.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "PassDetail.h"
17 #include "mlir/Analysis/CallGraph.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Interfaces/SideEffects.h"
20 #include "mlir/Transforms/InliningUtils.h"
21 #include "mlir/Transforms/Passes.h"
22 #include "llvm/ADT/SCCIterator.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/Parallel.h"
25 
26 #define DEBUG_TYPE "inlining"
27 
28 using namespace mlir;
29 
30 //===----------------------------------------------------------------------===//
31 // Symbol Use Tracking
32 //===----------------------------------------------------------------------===//
33 
34 /// Walk all of the used symbol callgraph nodes referenced with the given op.
35 static void walkReferencedSymbolNodes(
36     Operation *op, CallGraph &cg,
37     DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
38     function_ref<void(CallGraphNode *, Operation *)> callback) {
39   auto symbolUses = SymbolTable::getSymbolUses(op);
40   assert(symbolUses && "expected uses to be valid");
41 
42   Operation *symbolTableOp = op->getParentOp();
43   for (const SymbolTable::SymbolUse &use : *symbolUses) {
44     auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
45     CallGraphNode *&node = refIt.first->second;
46 
47     // If this is the first instance of this reference, try to resolve a
48     // callgraph node for it.
49     if (refIt.second) {
50       auto *symbolOp = SymbolTable::lookupNearestSymbolFrom(symbolTableOp,
51                                                             use.getSymbolRef());
52       auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
53       if (!callableOp)
54         continue;
55       node = cg.lookupNode(callableOp.getCallableRegion());
56     }
57     if (node)
58       callback(node, use.getUser());
59   }
60 }
61 
62 //===----------------------------------------------------------------------===//
63 // CGUseList
64 
namespace {
/// This struct tracks the uses of callgraph nodes that can be dropped when
/// use_empty. It directly tracks and manages a use-list for all of the
/// call-graph nodes. This is necessary because many callgraph nodes are
/// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
/// class.
struct CGUseList {
  /// This struct tracks the uses of callgraph nodes within a specific
  /// operation.
  struct CGUser {
    /// Any nodes referenced in the top-level attribute list of this user. We
    /// use a set here because the number of references does not matter.
    DenseSet<CallGraphNode *> topLevelUses;

    /// Uses of nodes referenced by nested operations, keyed by the referenced
    /// node with an explicit reference count.
    DenseMap<CallGraphNode *, int> innerUses;
  };

  /// Build use-list information for all callgraph nodes nested within 'op'.
  CGUseList(Operation *op, CallGraph &cg);

  /// Drop uses of nodes referred to by the given call operation that resides
  /// within 'userNode'.
  void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);

  /// Remove the given node (and, recursively, its child nodes) from the use
  /// list.
  void eraseNode(CallGraphNode *node);

  /// Returns true if the given callgraph node has no uses and can be pruned.
  bool isDead(CallGraphNode *node) const;

  /// Returns true if the given callgraph node has a single use and can be
  /// discarded.
  bool hasOneUseAndDiscardable(CallGraphNode *node) const;

  /// Recompute the uses held by the given callgraph node.
  void recomputeUses(CallGraphNode *node, CallGraph &cg);

  /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy
  /// of 'lhs' into 'rhs'.
  void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);

private:
  /// Decrement the uses of discardable nodes referenced by the given user.
  void decrementDiscardableUses(CGUser &uses);

  /// A mapping between a discardable callgraph node (that is a symbol) and the
  /// number of uses for this node.
  DenseMap<CallGraphNode *, int> discardableSymNodeUses;
  /// A mapping between a callgraph node and the symbol callgraph nodes that it
  /// uses.
  DenseMap<CallGraphNode *, CGUser> nodeUses;
};
} // end anonymous namespace
118 
119 CGUseList::CGUseList(Operation *op, CallGraph &cg) {
120   /// A set of callgraph nodes that are always known to be live during inlining.
121   DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;
122 
123   // Walk each of the symbol tables looking for discardable callgraph nodes.
124   auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
125     for (Block &block : symbolTableOp->getRegion(0)) {
126       for (Operation &op : block) {
127         // If this is a callgraph operation, check to see if it is discardable.
128         if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
129           if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
130             SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
131             if (symbol && (allUsesVisible || symbol.isPrivate()) &&
132                 symbol.canDiscardOnUseEmpty()) {
133               discardableSymNodeUses.try_emplace(node, 0);
134             }
135             continue;
136           }
137         }
138         // Otherwise, check for any referenced nodes. These will be always-live.
139         walkReferencedSymbolNodes(&op, cg, alwaysLiveNodes,
140                                   [](CallGraphNode *, Operation *) {});
141       }
142     }
143   };
144   SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
145                                 walkFn);
146 
147   // Drop the use information for any discardable nodes that are always live.
148   for (auto &it : alwaysLiveNodes)
149     discardableSymNodeUses.erase(it.second);
150 
151   // Compute the uses for each of the callable nodes in the graph.
152   for (CallGraphNode *node : cg)
153     recomputeUses(node, cg);
154 }
155 
156 void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
157                              CallGraph &cg) {
158   auto &userRefs = nodeUses[userNode].innerUses;
159   auto walkFn = [&](CallGraphNode *node, Operation *user) {
160     auto parentIt = userRefs.find(node);
161     if (parentIt == userRefs.end())
162       return;
163     --parentIt->second;
164     --discardableSymNodeUses[node];
165   };
166   DenseMap<Attribute, CallGraphNode *> resolvedRefs;
167   walkReferencedSymbolNodes(callOp, cg, resolvedRefs, walkFn);
168 }
169 
170 void CGUseList::eraseNode(CallGraphNode *node) {
171   // Drop all child nodes.
172   for (auto &edge : *node)
173     if (edge.isChild())
174       eraseNode(edge.getTarget());
175 
176   // Drop the uses held by this node and erase it.
177   auto useIt = nodeUses.find(node);
178   assert(useIt != nodeUses.end() && "expected node to be valid");
179   decrementDiscardableUses(useIt->getSecond());
180   nodeUses.erase(useIt);
181   discardableSymNodeUses.erase(node);
182 }
183 
184 bool CGUseList::isDead(CallGraphNode *node) const {
185   // If the parent operation isn't a symbol, simply check normal SSA deadness.
186   Operation *nodeOp = node->getCallableRegion()->getParentOp();
187   if (!isa<SymbolOpInterface>(nodeOp))
188     return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->use_empty();
189 
190   // Otherwise, check the number of symbol uses.
191   auto symbolIt = discardableSymNodeUses.find(node);
192   return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
193 }
194 
195 bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
196   // If this isn't a symbol node, check for side-effects and SSA use count.
197   Operation *nodeOp = node->getCallableRegion()->getParentOp();
198   if (!isa<SymbolOpInterface>(nodeOp))
199     return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->hasOneUse();
200 
201   // Otherwise, check the number of symbol uses.
202   auto symbolIt = discardableSymNodeUses.find(node);
203   return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
204 }
205 
206 void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
207   Operation *parentOp = node->getCallableRegion()->getParentOp();
208   CGUser &uses = nodeUses[node];
209   decrementDiscardableUses(uses);
210 
211   // Collect the new discardable uses within this node.
212   uses = CGUser();
213   DenseMap<Attribute, CallGraphNode *> resolvedRefs;
214   auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
215     auto discardSymIt = discardableSymNodeUses.find(refNode);
216     if (discardSymIt == discardableSymNodeUses.end())
217       return;
218 
219     if (user != parentOp)
220       ++uses.innerUses[refNode];
221     else if (!uses.topLevelUses.insert(refNode).second)
222       return;
223     ++discardSymIt->second;
224   };
225   walkReferencedSymbolNodes(parentOp, cg, resolvedRefs, walkFn);
226 }
227 
228 void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
229   auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
230   for (auto &useIt : lhsUses.innerUses) {
231     rhsUses.innerUses[useIt.first] += useIt.second;
232     discardableSymNodeUses[useIt.first] += useIt.second;
233   }
234 }
235 
236 void CGUseList::decrementDiscardableUses(CGUser &uses) {
237   for (CallGraphNode *node : uses.topLevelUses)
238     --discardableSymNodeUses[node];
239   for (auto &it : uses.innerUses)
240     discardableSymNodeUses[it.first] -= it.second;
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // CallGraph traversal
245 //===----------------------------------------------------------------------===//
246 
247 /// Run a given transformation over the SCCs of the callgraph in a bottom up
248 /// traversal.
249 static void runTransformOnCGSCCs(
250     const CallGraph &cg,
251     function_ref<void(MutableArrayRef<CallGraphNode *>)> sccTransformer) {
252   std::vector<CallGraphNode *> currentSCCVec;
253   auto cgi = llvm::scc_begin(&cg);
254   while (!cgi.isAtEnd()) {
255     // Copy the current SCC and increment so that the transformer can modify the
256     // SCC without invalidating our iterator.
257     currentSCCVec = *cgi;
258     ++cgi;
259     sccTransformer(currentSCCVec);
260   }
261 }
262 
namespace {
/// This struct represents a resolved call to a given callgraph node. Given that
/// the call does not actually contain a direct reference to the
/// Region(CallGraphNode) that it is dispatching to, we need to resolve them
/// explicitly.
struct ResolvedCall {
  ResolvedCall(CallOpInterface call, CallGraphNode *sourceNode,
               CallGraphNode *targetNode)
      : call(call), sourceNode(sourceNode), targetNode(targetNode) {}
  /// The call operation being considered for inlining.
  CallOpInterface call;
  /// The node containing the call, and the node the call dispatches to.
  CallGraphNode *sourceNode, *targetNode;
};
} // end anonymous namespace
276 
277 /// Collect all of the callable operations within the given range of blocks. If
278 /// `traverseNestedCGNodes` is true, this will also collect call operations
279 /// inside of nested callgraph nodes.
280 static void collectCallOps(iterator_range<Region::iterator> blocks,
281                            CallGraphNode *sourceNode, CallGraph &cg,
282                            SmallVectorImpl<ResolvedCall> &calls,
283                            bool traverseNestedCGNodes) {
284   SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
285   auto addToWorklist = [&](CallGraphNode *node,
286                            iterator_range<Region::iterator> blocks) {
287     for (Block &block : blocks)
288       worklist.emplace_back(&block, node);
289   };
290 
291   addToWorklist(sourceNode, blocks);
292   while (!worklist.empty()) {
293     Block *block;
294     std::tie(block, sourceNode) = worklist.pop_back_val();
295 
296     for (Operation &op : *block) {
297       if (auto call = dyn_cast<CallOpInterface>(op)) {
298         // TODO(riverriddle) Support inlining nested call references.
299         CallInterfaceCallable callable = call.getCallableForCallee();
300         if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
301           if (!symRef.isa<FlatSymbolRefAttr>())
302             continue;
303         }
304 
305         CallGraphNode *targetNode = cg.resolveCallable(call);
306         if (!targetNode->isExternal())
307           calls.emplace_back(call, sourceNode, targetNode);
308         continue;
309       }
310 
311       // If this is not a call, traverse the nested regions. If
312       // `traverseNestedCGNodes` is false, then don't traverse nested call graph
313       // regions.
314       for (auto &nestedRegion : op.getRegions()) {
315         CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
316         if (traverseNestedCGNodes || !nestedNode)
317           addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
318       }
319     }
320   }
321 }
322 
323 //===----------------------------------------------------------------------===//
324 // Inliner
325 //===----------------------------------------------------------------------===//
326 namespace {
327 /// This class provides a specialization of the main inlining interface.
328 struct Inliner : public InlinerInterface {
329   Inliner(MLIRContext *context, CallGraph &cg)
330       : InlinerInterface(context), cg(cg) {}
331 
332   /// Process a set of blocks that have been inlined. This callback is invoked
333   /// *before* inlined terminator operations have been processed.
334   void
335   processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
336     // Find the closest callgraph node from the first block.
337     CallGraphNode *node;
338     Region *region = inlinedBlocks.begin()->getParent();
339     while (!(node = cg.lookupNode(region))) {
340       region = region->getParentRegion();
341       assert(region && "expected valid parent node");
342     }
343 
344     collectCallOps(inlinedBlocks, node, cg, calls,
345                    /*traverseNestedCGNodes=*/true);
346   }
347 
348   /// The current set of call instructions to consider for inlining.
349   SmallVector<ResolvedCall, 8> calls;
350 
351   /// The callgraph being operated on.
352   CallGraph &cg;
353 };
354 } // namespace
355 
356 /// Returns true if the given call should be inlined.
357 static bool shouldInline(ResolvedCall &resolvedCall) {
358   // Don't allow inlining terminator calls. We currently don't support this
359   // case.
360   if (resolvedCall.call.getOperation()->isKnownTerminator())
361     return false;
362 
363   // Don't allow inlining if the target is an ancestor of the call. This
364   // prevents inlining recursively.
365   if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
366           resolvedCall.call.getParentRegion()))
367     return false;
368 
369   // Otherwise, inline.
370   return true;
371 }
372 
373 /// Delete the given node and remove it from the current scc and the callgraph.
374 static void deleteNode(CallGraphNode *node, CGUseList &useList, CallGraph &cg,
375                        MutableArrayRef<CallGraphNode *> currentSCC) {
376   // Erase the parent operation and remove it from the various lists.
377   node->getCallableRegion()->getParentOp()->erase();
378   cg.eraseNode(node);
379 
380   // Replace this node in the currentSCC with the external node.
381   auto it = llvm::find(currentSCC, node);
382   if (it != currentSCC.end())
383     *it = cg.getExternalNode();
384 }
385 
/// Attempt to inline calls within the given scc. This function returns
/// success if any calls were inlined, failure otherwise.
static LogicalResult
inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
                 MutableArrayRef<CallGraphNode *> currentSCC) {
  CallGraph &cg = inliner.cg;
  auto &calls = inliner.calls;

  // Collect all of the direct calls within the nodes of the current SCC. We
  // don't traverse nested callgraph nodes, because they are handled separately
  // likely within a different SCC.
  for (CallGraphNode *node : currentSCC) {
    if (node->isExternal())
      continue;

    // If this node is dead, just delete it now.
    if (useList.isDead(node))
      deleteNode(node, useList, cg, currentSCC);
    else
      collectCallOps(*node->getCallableRegion(), node, cg, calls,
                     /*traverseNestedCGNodes=*/false);
  }
  if (calls.empty())
    return failure();

  // A set of dead nodes to remove after inlining.
  SmallVector<CallGraphNode *, 1> deadNodes;

  // Try to inline each of the call operations. Don't cache the end iterator
  // here as more calls may be added during inlining.
  bool inlinedAnyCalls = false;
  for (unsigned i = 0; i != calls.size(); ++i) {
    // Take a copy of the entry: inlining may append to 'calls' (via
    // Inliner::processInlinedBlocks), which can reallocate the vector and
    // invalidate references into it.
    ResolvedCall it = calls[i];
    bool doInline = shouldInline(it);
    LLVM_DEBUG({
      if (doInline)
        llvm::dbgs() << "* Inlining call: ";
      else
        llvm::dbgs() << "* Not inlining call: ";
      it.call.dump();
    });
    if (!doInline)
      continue;
    CallOpInterface call = it.call;
    Region *targetRegion = it.targetNode->getCallableRegion();

    // If this is the last call to the target node and the node is discardable,
    // then inline it in-place and delete the node if successful.
    bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);

    LogicalResult inlineResult = inlineCall(
        inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
        targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
    // A failed inline leaves the call untouched; simply move on to the next.
    if (failed(inlineResult))
      continue;
    inlinedAnyCalls = true;

    // If the inlining was successful, merge the new uses into the source node.
    useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
    useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);

    // then erase the call.
    call.erase();

    // If we inlined in place, mark the node for deletion.
    if (inlineInPlace) {
      useList.eraseNode(it.targetNode);
      deadNodes.push_back(it.targetNode);
    }
  }

  // Remove nodes emptied by in-place inlining, reset the call worklist, and
  // report whether any call was inlined.
  for (CallGraphNode *node : deadNodes)
    deleteNode(node, useList, cg, currentSCC);
  calls.clear();
  return success(inlinedAnyCalls);
}
462 
/// Canonicalize the nodes within the given SCC with the given set of
/// canonicalization patterns.
static void canonicalizeSCC(CallGraph &cg, CGUseList &useList,
                            MutableArrayRef<CallGraphNode *> currentSCC,
                            MLIRContext *context,
                            const OwningRewritePatternList &canonPatterns) {
  // Collect the sets of nodes to canonicalize.
  SmallVector<CallGraphNode *, 4> nodesToCanonicalize;
  for (auto *node : currentSCC) {
    // Don't canonicalize the external node, it has no valid callable region.
    if (node->isExternal())
      continue;

    // Don't canonicalize nodes with children. Nodes with children
    // require special handling as we may remove the node during
    // canonicalization. In the future, we should be able to handle this
    // case with proper node deletion tracking.
    if (node->hasChildren())
      continue;

    // We also won't apply canonicalizations for nodes that are not
    // isolated. This avoids potentially mutating the regions of nodes defined
    // above, this is also a stipulation of the 'applyPatternsAndFoldGreedily'
    // driver.
    auto *region = node->getCallableRegion();
    if (!region->getParentOp()->isKnownIsolatedFromAbove())
      continue;
    nodesToCanonicalize.push_back(node);
  }
  if (nodesToCanonicalize.empty())
    return;

  // Canonicalize each of the nodes within the SCC in parallel.
  // NOTE: This is simple now, because we don't enable canonicalizing nodes
  // within children. When we remove this restriction, this logic will need to
  // be reworked.
  ParallelDiagnosticHandler canonicalizationHandler(context);
  llvm::parallel::for_each_n(
      llvm::parallel::par, /*Begin=*/size_t(0),
      /*End=*/nodesToCanonicalize.size(), [&](size_t index) {
        // Set the order for this thread so that diagnostics will be properly
        // ordered.
        canonicalizationHandler.setOrderIDForThread(index);

        // Apply the canonicalization patterns to this region.
        auto *node = nodesToCanonicalize[index];
        applyPatternsAndFoldGreedily(*node->getCallableRegion(), canonPatterns);

        // Make sure to reset the order ID for the diagnostic handler, as this
        // thread may be used in a different context.
        canonicalizationHandler.eraseOrderIDForThread();
      });

  // Recompute the uses held by each of the nodes. Canonicalization may have
  // added or removed symbol references, so the cached discardable-use counts
  // must be rebuilt for the touched nodes.
  for (CallGraphNode *node : nodesToCanonicalize)
    useList.recomputeUses(node, cg);
}
520 
521 //===----------------------------------------------------------------------===//
522 // InlinerPass
523 //===----------------------------------------------------------------------===//
524 
namespace {
/// The inliner pass: runs the inlining transform in post-order over the SCCs
/// of the callgraph of the operation it is scheduled on. The options
/// referenced below (`disableCanonicalization`, `maxInliningIterations`) are
/// presumably declared by the generated InlinerBase — confirm in PassDetail.h.
struct InlinerPass : public InlinerBase<InlinerPass> {
  void runOnOperation() override;

  /// Attempt to inline calls within the given scc, and run canonicalizations
  /// with the given patterns, until a fixed point is reached. This allows for
  /// the inlining of newly devirtualized calls.
  void inlineSCC(Inliner &inliner, CGUseList &useList,
                 MutableArrayRef<CallGraphNode *> currentSCC,
                 MLIRContext *context,
                 const OwningRewritePatternList &canonPatterns);
};
} // end anonymous namespace
538 
539 void InlinerPass::runOnOperation() {
540   CallGraph &cg = getAnalysis<CallGraph>();
541   auto *context = &getContext();
542 
543   // The inliner should only be run on operations that define a symbol table,
544   // as the callgraph will need to resolve references.
545   Operation *op = getOperation();
546   if (!op->hasTrait<OpTrait::SymbolTable>()) {
547     op->emitOpError() << " was scheduled to run under the inliner, but does "
548                          "not define a symbol table";
549     return signalPassFailure();
550   }
551 
552   // Collect a set of canonicalization patterns to use when simplifying
553   // callable regions within an SCC.
554   OwningRewritePatternList canonPatterns;
555   for (auto *op : context->getRegisteredOperations())
556     op->getCanonicalizationPatterns(canonPatterns, context);
557 
558   // Run the inline transform in post-order over the SCCs in the callgraph.
559   Inliner inliner(context, cg);
560   CGUseList useList(getOperation(), cg);
561   runTransformOnCGSCCs(cg, [&](MutableArrayRef<CallGraphNode *> scc) {
562     inlineSCC(inliner, useList, scc, context, canonPatterns);
563   });
564 }
565 
566 void InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
567                             MutableArrayRef<CallGraphNode *> currentSCC,
568                             MLIRContext *context,
569                             const OwningRewritePatternList &canonPatterns) {
570   // If we successfully inlined any calls, run some simplifications on the
571   // nodes of the scc. Continue attempting to inline until we reach a fixed
572   // point, or a maximum iteration count. We canonicalize here as it may
573   // devirtualize new calls, as well as give us a better cost model.
574   unsigned iterationCount = 0;
575   while (succeeded(inlineCallsInSCC(inliner, useList, currentSCC))) {
576     // If we aren't allowing simplifications or the max iteration count was
577     // reached, then bail out early.
578     if (disableCanonicalization || ++iterationCount >= maxInliningIterations)
579       break;
580     canonicalizeSCC(inliner.cg, useList, currentSCC, context, canonPatterns);
581   }
582 }
583 
/// Create an instance of the inliner pass.
std::unique_ptr<Pass> mlir::createInlinerPass() {
  return std::make_unique<InlinerPass>();
}
587