1 //===- CallPrinter.cpp - DOT printer for call graph -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines '-dot-callgraph', which emits a <module-id>.callgraph.dot
// file (or <prefix>.callgraph.dot when -callgraph-dot-filename-prefix is set)
// containing the call graph of a module.
11 //
// There is also a pass ('-view-callgraph') that displays the call graph
// directly in a graph viewer instead of writing a file.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "llvm/Analysis/CallPrinter.h"
17 #include "llvm/ADT/DenseMap.h"
18 #include "llvm/ADT/SmallSet.h"
19 #include "llvm/Analysis/BlockFrequencyInfo.h"
20 #include "llvm/Analysis/CallGraph.h"
21 #include "llvm/Analysis/HeatUtils.h"
22 #include "llvm/IR/Instructions.h"
23 #include "llvm/InitializePasses.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/DOTGraphTraits.h"
26 #include "llvm/Support/GraphWriter.h"
27 
28 using namespace llvm;
29 
// Forward declaration of llvm::GraphTraits, which this file specializes for
// CallGraphDOTInfo *.
namespace llvm {
template <typename GraphType> struct GraphTraits;
} // end namespace llvm
33 
34 // This option shows static (relative) call counts.
35 // FIXME:
36 // Need to show real counts when profile data is available
37 static cl::opt<bool> ShowHeatColors("callgraph-heat-colors", cl::init(false),
38                                     cl::Hidden,
39                                     cl::desc("Show heat colors in call-graph"));
40 
41 static cl::opt<bool>
42     ShowEdgeWeight("callgraph-show-weights", cl::init(false), cl::Hidden,
43                        cl::desc("Show edges labeled with weights"));
44 
45 static cl::opt<bool>
46     CallMultiGraph("callgraph-multigraph", cl::init(false), cl::Hidden,
47             cl::desc("Show call-multigraph (do not remove parallel edges)"));
48 
49 static cl::opt<std::string> CallGraphDotFilenamePrefix(
50     "callgraph-dot-filename-prefix", cl::Hidden,
51     cl::desc("The prefix used for the CallGraph dot file names."));
52 
53 namespace llvm {
54 
55 class CallGraphDOTInfo {
56 private:
57   Module *M;
58   CallGraph *CG;
59   DenseMap<const Function *, uint64_t> Freq;
60   uint64_t MaxFreq;
61 
62 public:
63   std::function<BlockFrequencyInfo *(Function &)> LookupBFI;
64 
65   CallGraphDOTInfo(Module *M, CallGraph *CG,
66                    function_ref<BlockFrequencyInfo *(Function &)> LookupBFI)
67       : M(M), CG(CG), LookupBFI(LookupBFI) {
68     MaxFreq = 0;
69 
70     for (Function &F : M->getFunctionList()) {
71       uint64_t localSumFreq = 0;
72       SmallSet<Function *, 16> Callers;
73       for (User *U : F.users())
74         if (isa<CallInst>(U))
75           Callers.insert(cast<Instruction>(U)->getFunction());
76       for (Function *Caller : Callers)
77         localSumFreq += getNumOfCalls(*Caller, F);
78       if (localSumFreq >= MaxFreq)
79         MaxFreq = localSumFreq;
80       Freq[&F] = localSumFreq;
81     }
82     if (!CallMultiGraph)
83       removeParallelEdges();
84   }
85 
86   Module *getModule() const { return M; }
87 
88   CallGraph *getCallGraph() const { return CG; }
89 
90   uint64_t getFreq(const Function *F) { return Freq[F]; }
91 
92   uint64_t getMaxFreq() { return MaxFreq; }
93 
94 private:
95   void removeParallelEdges() {
96     for (auto &I : (*CG)) {
97       CallGraphNode *Node = I.second.get();
98 
99       bool FoundParallelEdge = true;
100       while (FoundParallelEdge) {
101         SmallSet<Function *, 16> Visited;
102         FoundParallelEdge = false;
103         for (auto CI = Node->begin(), CE = Node->end(); CI != CE; CI++) {
104           if (!(Visited.insert(CI->second->getFunction())).second) {
105             FoundParallelEdge = true;
106             Node->removeCallEdge(CI);
107             break;
108           }
109         }
110       }
111     }
112   }
113 };
114 
115 template <>
116 struct GraphTraits<CallGraphDOTInfo *>
117     : public GraphTraits<const CallGraphNode *> {
118   static NodeRef getEntryNode(CallGraphDOTInfo *CGInfo) {
119     // Start at the external node!
120     return CGInfo->getCallGraph()->getExternalCallingNode();
121   }
122 
123   typedef std::pair<const Function *const, std::unique_ptr<CallGraphNode>>
124       PairTy;
125   static const CallGraphNode *CGGetValuePtr(const PairTy &P) {
126     return P.second.get();
127   }
128 
129   // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
130   typedef mapped_iterator<CallGraph::const_iterator, decltype(&CGGetValuePtr)>
131       nodes_iterator;
132 
133   static nodes_iterator nodes_begin(CallGraphDOTInfo *CGInfo) {
134     return nodes_iterator(CGInfo->getCallGraph()->begin(), &CGGetValuePtr);
135   }
136   static nodes_iterator nodes_end(CallGraphDOTInfo *CGInfo) {
137     return nodes_iterator(CGInfo->getCallGraph()->end(), &CGGetValuePtr);
138   }
139 };
140 
141 template <>
142 struct DOTGraphTraits<CallGraphDOTInfo *> : public DefaultDOTGraphTraits {
143 
144   DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {}
145 
146   static std::string getGraphName(CallGraphDOTInfo *CGInfo) {
147     return "Call graph: " +
148            std::string(CGInfo->getModule()->getModuleIdentifier());
149   }
150 
151   static bool isNodeHidden(const CallGraphNode *Node,
152                            const CallGraphDOTInfo *CGInfo) {
153     if (CallMultiGraph || Node->getFunction())
154       return false;
155     return true;
156   }
157 
158   std::string getNodeLabel(const CallGraphNode *Node,
159                            CallGraphDOTInfo *CGInfo) {
160     if (Node == CGInfo->getCallGraph()->getExternalCallingNode())
161       return "external caller";
162     if (Node == CGInfo->getCallGraph()->getCallsExternalNode())
163       return "external callee";
164 
165     if (Function *Func = Node->getFunction())
166       return std::string(Func->getName());
167     return "external node";
168   }
169   static const CallGraphNode *CGGetValuePtr(CallGraphNode::CallRecord P) {
170     return P.second;
171   }
172 
173   // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
174   typedef mapped_iterator<CallGraphNode::const_iterator,
175                           decltype(&CGGetValuePtr)>
176       nodes_iterator;
177 
178   std::string getEdgeAttributes(const CallGraphNode *Node, nodes_iterator I,
179                                 CallGraphDOTInfo *CGInfo) {
180     if (!ShowEdgeWeight)
181       return "";
182 
183     Function *Caller = Node->getFunction();
184     if (Caller == nullptr || Caller->isDeclaration())
185       return "";
186 
187     Function *Callee = (*I)->getFunction();
188     if (Callee == nullptr)
189       return "";
190 
191     uint64_t Counter = getNumOfCalls(*Caller, *Callee);
192     double Width =
193         1 + 2 * (double(Counter) / CGInfo->getMaxFreq());
194     std::string Attrs = "label=\"" + std::to_string(Counter) +
195                         "\" penwidth=" + std::to_string(Width);
196     return Attrs;
197   }
198 
199   std::string getNodeAttributes(const CallGraphNode *Node,
200                                 CallGraphDOTInfo *CGInfo) {
201     Function *F = Node->getFunction();
202     if (F == nullptr)
203       return "";
204     std::string attrs;
205     if (ShowHeatColors) {
206       uint64_t freq = CGInfo->getFreq(F);
207       std::string color = getHeatColor(freq, CGInfo->getMaxFreq());
208       std::string edgeColor = (freq <= (CGInfo->getMaxFreq() / 2))
209                                   ? getHeatColor(0)
210                                   : getHeatColor(1);
211       attrs = "color=\"" + edgeColor + "ff\", style=filled, fillcolor=\"" +
212               color + "80\"";
213     }
214     return attrs;
215   }
216 };
217 
218 } // end llvm namespace
219 
220 namespace {
221 // Viewer
222 class CallGraphViewer : public ModulePass {
223 public:
224   static char ID;
225   CallGraphViewer() : ModulePass(ID) {}
226 
227   void getAnalysisUsage(AnalysisUsage &AU) const override;
228   bool runOnModule(Module &M) override;
229 };
230 
231 void CallGraphViewer::getAnalysisUsage(AnalysisUsage &AU) const {
232   ModulePass::getAnalysisUsage(AU);
233   AU.addRequired<BlockFrequencyInfoWrapperPass>();
234   AU.setPreservesAll();
235 }
236 
237 bool CallGraphViewer::runOnModule(Module &M) {
238   auto LookupBFI = [this](Function &F) {
239     return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
240   };
241 
242   CallGraph CG(M);
243   CallGraphDOTInfo CFGInfo(&M, &CG, LookupBFI);
244 
245   std::string Title =
246       DOTGraphTraits<CallGraphDOTInfo *>::getGraphName(&CFGInfo);
247   ViewGraph(&CFGInfo, "callgraph", true, Title);
248 
249   return false;
250 }
251 
252 // DOT Printer
253 
254 class CallGraphDOTPrinter : public ModulePass {
255 public:
256   static char ID;
257   CallGraphDOTPrinter() : ModulePass(ID) {}
258 
259   void getAnalysisUsage(AnalysisUsage &AU) const override;
260   bool runOnModule(Module &M) override;
261 };
262 
263 void CallGraphDOTPrinter::getAnalysisUsage(AnalysisUsage &AU) const {
264   ModulePass::getAnalysisUsage(AU);
265   AU.addRequired<BlockFrequencyInfoWrapperPass>();
266   AU.setPreservesAll();
267 }
268 
269 bool CallGraphDOTPrinter::runOnModule(Module &M) {
270   auto LookupBFI = [this](Function &F) {
271     return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
272   };
273 
274   std::string Filename;
275   if (!CallGraphDotFilenamePrefix.empty())
276     Filename = (CallGraphDotFilenamePrefix + ".callgraph.dot");
277   else
278     Filename = (std::string(M.getModuleIdentifier()) + ".callgraph.dot");
279   errs() << "Writing '" << Filename << "'...";
280 
281   std::error_code EC;
282   raw_fd_ostream File(Filename, EC, sys::fs::OF_Text);
283 
284   CallGraph CG(M);
285   CallGraphDOTInfo CFGInfo(&M, &CG, LookupBFI);
286 
287   if (!EC)
288     WriteGraph(File, &CFGInfo);
289   else
290     errs() << "  error opening file for writing!";
291   errs() << "\n";
292 
293   return false;
294 }
295 
296 } // end anonymous namespace
297 
// Pass ID definitions and legacy pass registration for the two passes above.
char CallGraphViewer::ID = 0;
// Registers CallGraphViewer under the command-line name 'view-callgraph'.
// NOTE(review): the two trailing flags are presumably CFGOnly and is_analysis,
// both false here; confirm against the INITIALIZE_PASS definition.
INITIALIZE_PASS(CallGraphViewer, "view-callgraph", "View call graph", false,
                false)

char CallGraphDOTPrinter::ID = 0;
// Registers CallGraphDOTPrinter under the command-line name 'dot-callgraph'.
INITIALIZE_PASS(CallGraphDOTPrinter, "dot-callgraph",
                "Print call graph to 'dot' file", false, false)
305 
// Factory functions exported from this file so that
// "include/llvm/LinkAllPasses.h" can reference them; otherwise the passes
// could be discarded by link-time optimization.
309 
310 ModulePass *llvm::createCallGraphViewerPass() { return new CallGraphViewer(); }
311 
312 ModulePass *llvm::createCallGraphDOTPrinterPass() {
313   return new CallGraphDOTPrinter();
314 }
315