1 //===--- SyntheticCountsUtils.cpp - synthetic counts propagation utils ---===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // This file defines utilities for propagating synthetic counts. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/Analysis/SyntheticCountsUtils.h" 15 #include "llvm/ADT/DenseSet.h" 16 #include "llvm/ADT/SCCIterator.h" 17 #include "llvm/ADT/SmallPtrSet.h" 18 #include "llvm/Analysis/CallGraph.h" 19 #include "llvm/IR/CallSite.h" 20 #include "llvm/IR/Function.h" 21 #include "llvm/IR/InstIterator.h" 22 #include "llvm/IR/Instructions.h" 23 24 using namespace llvm; 25 26 // Given a set of functions in an SCC, propagate entry counts to functions 27 // called by the SCC. 28 static void 29 propagateFromSCC(const SmallPtrSetImpl<Function *> &SCCFunctions, 30 function_ref<Scaled64(CallSite CS)> GetCallSiteRelFreq, 31 function_ref<uint64_t(Function *F)> GetCount, 32 function_ref<void(Function *F, uint64_t)> AddToCount) { 33 34 SmallVector<CallSite, 16> CallSites; 35 36 // Gather all callsites in the SCC. 37 auto GatherCallSites = [&]() { 38 for (auto *F : SCCFunctions) { 39 assert(F && !F->isDeclaration()); 40 for (auto &I : instructions(F)) { 41 if (auto CS = CallSite(&I)) { 42 CallSites.push_back(CS); 43 } 44 } 45 } 46 }; 47 48 GatherCallSites(); 49 50 // Partition callsites so that the callsites that call functions in the same 51 // SCC come first. 52 auto Mid = partition(CallSites, [&](CallSite &CS) { 53 auto *Callee = CS.getCalledFunction(); 54 if (Callee) 55 return SCCFunctions.count(Callee); 56 // FIXME: Use the !callees metadata to propagate counts through indirect 57 // calls. 58 return 0U; 59 }); 60 61 // For functions in the same SCC, update the counts in two steps: 62 // 1. Compute the additional count for each function by propagating the counts 63 // along all incoming edges to the function that originate from the same SCC 64 // and summing them up. 65 // 2. Add the additional counts to the functions in the SCC. 66 // This ensures that the order of 67 // traversal of functions within the SCC doesn't change the final result. 68 69 DenseMap<Function *, uint64_t> AdditionalCounts; 70 for (auto It = CallSites.begin(); It != Mid; It++) { 71 auto &CS = *It; 72 auto RelFreq = GetCallSiteRelFreq(CS); 73 Function *Callee = CS.getCalledFunction(); 74 Function *Caller = CS.getCaller(); 75 RelFreq *= Scaled64(GetCount(Caller), 0); 76 uint64_t AdditionalCount = RelFreq.toInt<uint64_t>(); 77 AdditionalCounts[Callee] += AdditionalCount; 78 } 79 80 // Update the counts for the functions in the SCC. 81 for (auto &Entry : AdditionalCounts) 82 AddToCount(Entry.first, Entry.second); 83 84 // Now update the counts for functions not in SCC. 85 for (auto It = Mid; It != CallSites.end(); It++) { 86 auto &CS = *It; 87 auto Weight = GetCallSiteRelFreq(CS); 88 Function *Callee = CS.getCalledFunction(); 89 Function *Caller = CS.getCaller(); 90 Weight *= Scaled64(GetCount(Caller), 0); 91 AddToCount(Callee, Weight.toInt<uint64_t>()); 92 } 93 } 94 95 /// Propgate synthetic entry counts on a callgraph. 96 /// 97 /// This performs a reverse post-order traversal of the callgraph SCC. For each 98 /// SCC, it first propagates the entry counts to the functions within the SCC 99 /// through call edges and updates them in one shot. Then the entry counts are 100 /// propagated to functions outside the SCC. 101 void llvm::propagateSyntheticCounts( 102 const CallGraph &CG, function_ref<Scaled64(CallSite CS)> GetCallSiteRelFreq, 103 function_ref<uint64_t(Function *F)> GetCount, 104 function_ref<void(Function *F, uint64_t)> AddToCount) { 105 106 SmallVector<SmallPtrSet<Function *, 8>, 16> SCCs; 107 for (auto I = scc_begin(&CG); !I.isAtEnd(); ++I) { 108 auto SCC = *I; 109 110 SmallPtrSet<Function *, 8> SCCFunctions; 111 for (auto *Node : SCC) { 112 Function *F = Node->getFunction(); 113 if (F && !F->isDeclaration()) { 114 SCCFunctions.insert(F); 115 } 116 } 117 SCCs.push_back(SCCFunctions); 118 } 119 120 for (auto &SCCFunctions : reverse(SCCs)) 121 propagateFromSCC(SCCFunctions, GetCallSiteRelFreq, GetCount, AddToCount); 122 } 123