1 //===--- SyntheticCountsUtils.cpp - synthetic counts propagation utils ---===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file defines utilities for propagating synthetic counts.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Analysis/SyntheticCountsUtils.h"
15 #include "llvm/ADT/DenseSet.h"
16 #include "llvm/ADT/SCCIterator.h"
17 #include "llvm/ADT/SmallPtrSet.h"
18 #include "llvm/Analysis/CallGraph.h"
19 #include "llvm/IR/CallSite.h"
20 #include "llvm/IR/Function.h"
21 #include "llvm/IR/InstIterator.h"
22 #include "llvm/IR/Instructions.h"
23 
24 using namespace llvm;
25 
26 // Given a set of functions in an SCC, propagate entry counts to functions
27 // called by the SCC.
28 static void
29 propagateFromSCC(const SmallPtrSetImpl<Function *> &SCCFunctions,
30                  function_ref<Scaled64(CallSite CS)> GetCallSiteRelFreq,
31                  function_ref<uint64_t(Function *F)> GetCount,
32                  function_ref<void(Function *F, uint64_t)> AddToCount) {
33 
34   SmallVector<CallSite, 16> CallSites;
35 
36   // Gather all callsites in the SCC.
37   auto GatherCallSites = [&]() {
38     for (auto *F : SCCFunctions) {
39       assert(F && !F->isDeclaration());
40       for (auto &I : instructions(F)) {
41         if (auto CS = CallSite(&I)) {
42           CallSites.push_back(CS);
43         }
44       }
45     }
46   };
47 
48   GatherCallSites();
49 
50   // Partition callsites so that the callsites that call functions in the same
51   // SCC come first.
52   auto Mid = partition(CallSites, [&](CallSite &CS) {
53     auto *Callee = CS.getCalledFunction();
54     if (Callee)
55       return SCCFunctions.count(Callee);
56     // FIXME: Use the !callees metadata to propagate counts through indirect
57     // calls.
58     return 0U;
59   });
60 
61   // For functions in the same SCC, update the counts in two steps:
62   // 1. Compute the additional count for each function by propagating the counts
63   // along all incoming edges to the function that originate from the same SCC
64   // and summing them up.
65   // 2. Add the additional counts to the functions in the SCC.
66   // This ensures that the order of
67   // traversal of functions within the SCC doesn't change the final result.
68 
69   DenseMap<Function *, uint64_t> AdditionalCounts;
70   for (auto It = CallSites.begin(); It != Mid; It++) {
71     auto &CS = *It;
72     auto RelFreq = GetCallSiteRelFreq(CS);
73     Function *Callee = CS.getCalledFunction();
74     Function *Caller = CS.getCaller();
75     RelFreq *= Scaled64(GetCount(Caller), 0);
76     uint64_t AdditionalCount = RelFreq.toInt<uint64_t>();
77     AdditionalCounts[Callee] += AdditionalCount;
78   }
79 
80   // Update the counts for the functions in the SCC.
81   for (auto &Entry : AdditionalCounts)
82     AddToCount(Entry.first, Entry.second);
83 
84   // Now update the counts for functions not in SCC.
85   for (auto It = Mid; It != CallSites.end(); It++) {
86     auto &CS = *It;
87     auto Weight = GetCallSiteRelFreq(CS);
88     Function *Callee = CS.getCalledFunction();
89     Function *Caller = CS.getCaller();
90     Weight *= Scaled64(GetCount(Caller), 0);
91     AddToCount(Callee, Weight.toInt<uint64_t>());
92   }
93 }
94 
95 /// Propgate synthetic entry counts on a callgraph.
96 ///
97 /// This performs a reverse post-order traversal of the callgraph SCC. For each
98 /// SCC, it first propagates the entry counts to the functions within the SCC
99 /// through call edges and updates them in one shot. Then the entry counts are
100 /// propagated to functions outside the SCC.
101 void llvm::propagateSyntheticCounts(
102     const CallGraph &CG, function_ref<Scaled64(CallSite CS)> GetCallSiteRelFreq,
103     function_ref<uint64_t(Function *F)> GetCount,
104     function_ref<void(Function *F, uint64_t)> AddToCount) {
105 
106   SmallVector<SmallPtrSet<Function *, 8>, 16> SCCs;
107   for (auto I = scc_begin(&CG); !I.isAtEnd(); ++I) {
108     auto SCC = *I;
109 
110     SmallPtrSet<Function *, 8> SCCFunctions;
111     for (auto *Node : SCC) {
112       Function *F = Node->getFunction();
113       if (F && !F->isDeclaration()) {
114         SCCFunctions.insert(F);
115       }
116     }
117     SCCs.push_back(SCCFunctions);
118   }
119 
120   for (auto &SCCFunctions : reverse(SCCs))
121     propagateFromSCC(SCCFunctions, GetCallSiteRelFreq, GetCount, AddToCount);
122 }
123