1 //===- SampleContextTracker.cpp - Context-sensitive Profile Tracker -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the SampleContextTracker used by CSSPGO.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "llvm/Transforms/IPO/SampleContextTracker.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Instructions.h"
#include "llvm/ProfileData/SampleProf.h"
#include <map>
#include <queue>
#include <utility>
#include <vector>
22 
23 using namespace llvm;
24 using namespace sampleprof;
25 
26 #define DEBUG_TYPE "sample-context-tracker"
27 
28 namespace llvm {
29 
30 ContextTrieNode *ContextTrieNode::getChildContext(const LineLocation &CallSite,
31                                                   StringRef CalleeName) {
32   if (CalleeName.empty())
33     return getHottestChildContext(CallSite);
34 
35   uint32_t Hash = nodeHash(CalleeName, CallSite);
36   auto It = AllChildContext.find(Hash);
37   if (It != AllChildContext.end())
38     return &It->second;
39   return nullptr;
40 }
41 
42 ContextTrieNode *
43 ContextTrieNode::getHottestChildContext(const LineLocation &CallSite) {
44   // CSFDO-TODO: This could be slow, change AllChildContext so we can
45   // do point look up for child node by call site alone.
46   // Retrieve the child node with max count for indirect call
47   ContextTrieNode *ChildNodeRet = nullptr;
48   uint64_t MaxCalleeSamples = 0;
49   for (auto &It : AllChildContext) {
50     ContextTrieNode &ChildNode = It.second;
51     if (ChildNode.CallSiteLoc != CallSite)
52       continue;
53     FunctionSamples *Samples = ChildNode.getFunctionSamples();
54     if (!Samples)
55       continue;
56     if (Samples->getTotalSamples() > MaxCalleeSamples) {
57       ChildNodeRet = &ChildNode;
58       MaxCalleeSamples = Samples->getTotalSamples();
59     }
60   }
61 
62   return ChildNodeRet;
63 }
64 
65 ContextTrieNode &ContextTrieNode::moveToChildContext(
66     const LineLocation &CallSite, ContextTrieNode &&NodeToMove,
67     StringRef ContextStrToRemove, bool DeleteNode) {
68   uint32_t Hash = nodeHash(NodeToMove.getFuncName(), CallSite);
69   assert(!AllChildContext.count(Hash) && "Node to remove must exist");
70   LineLocation OldCallSite = NodeToMove.CallSiteLoc;
71   ContextTrieNode &OldParentContext = *NodeToMove.getParentContext();
72   AllChildContext[Hash] = NodeToMove;
73   ContextTrieNode &NewNode = AllChildContext[Hash];
74   NewNode.CallSiteLoc = CallSite;
75 
76   // Walk through nodes in the moved the subtree, and update
77   // FunctionSamples' context as for the context promotion.
78   // We also need to set new parant link for all children.
79   std::queue<ContextTrieNode *> NodeToUpdate;
80   NewNode.setParentContext(this);
81   NodeToUpdate.push(&NewNode);
82 
83   while (!NodeToUpdate.empty()) {
84     ContextTrieNode *Node = NodeToUpdate.front();
85     NodeToUpdate.pop();
86     FunctionSamples *FSamples = Node->getFunctionSamples();
87 
88     if (FSamples) {
89       FSamples->getContext().promoteOnPath(ContextStrToRemove);
90       FSamples->getContext().setState(SyntheticContext);
91       LLVM_DEBUG(dbgs() << "  Context promoted to: " << FSamples->getContext()
92                         << "\n");
93     }
94 
95     for (auto &It : Node->getAllChildContext()) {
96       ContextTrieNode *ChildNode = &It.second;
97       ChildNode->setParentContext(Node);
98       NodeToUpdate.push(ChildNode);
99     }
100   }
101 
102   // Original context no longer needed, destroy if requested.
103   if (DeleteNode)
104     OldParentContext.removeChildContext(OldCallSite, NewNode.getFuncName());
105 
106   return NewNode;
107 }
108 
109 void ContextTrieNode::removeChildContext(const LineLocation &CallSite,
110                                          StringRef CalleeName) {
111   uint32_t Hash = nodeHash(CalleeName, CallSite);
112   // Note this essentially calls dtor and destroys that child context
113   AllChildContext.erase(Hash);
114 }
115 
116 std::map<uint32_t, ContextTrieNode> &ContextTrieNode::getAllChildContext() {
117   return AllChildContext;
118 }
119 
120 StringRef ContextTrieNode::getFuncName() const { return FuncName; }
121 
122 FunctionSamples *ContextTrieNode::getFunctionSamples() const {
123   return FuncSamples;
124 }
125 
126 void ContextTrieNode::setFunctionSamples(FunctionSamples *FSamples) {
127   FuncSamples = FSamples;
128 }
129 
130 LineLocation ContextTrieNode::getCallSiteLoc() const { return CallSiteLoc; }
131 
132 ContextTrieNode *ContextTrieNode::getParentContext() const {
133   return ParentContext;
134 }
135 
136 void ContextTrieNode::setParentContext(ContextTrieNode *Parent) {
137   ParentContext = Parent;
138 }
139 
140 void ContextTrieNode::dump() {
141   dbgs() << "Node: " << FuncName << "\n"
142          << "  Callsite: " << CallSiteLoc << "\n"
143          << "  Children:\n";
144 
145   for (auto &It : AllChildContext) {
146     dbgs() << "    Node: " << It.second.getFuncName() << "\n";
147   }
148 }
149 
150 uint32_t ContextTrieNode::nodeHash(StringRef ChildName,
151                                    const LineLocation &Callsite) {
152   // We still use child's name for child hash, this is
153   // because for children of root node, we don't have
154   // different line/discriminator, and we'll rely on name
155   // to differentiate children.
156   uint32_t NameHash = std::hash<std::string>{}(ChildName.str());
157   uint32_t LocId = (Callsite.LineOffset << 16) | Callsite.Discriminator;
158   return NameHash + (LocId << 5) + LocId;
159 }
160 
161 ContextTrieNode *ContextTrieNode::getOrCreateChildContext(
162     const LineLocation &CallSite, StringRef CalleeName, bool AllowCreate) {
163   uint32_t Hash = nodeHash(CalleeName, CallSite);
164   auto It = AllChildContext.find(Hash);
165   if (It != AllChildContext.end()) {
166     assert(It->second.getFuncName() == CalleeName &&
167            "Hash collision for child context node");
168     return &It->second;
169   }
170 
171   if (!AllowCreate)
172     return nullptr;
173 
174   AllChildContext[Hash] = ContextTrieNode(this, CalleeName, nullptr, CallSite);
175   return &AllChildContext[Hash];
176 }
177 
178 // Profiler tracker than manages profiles and its associated context
179 SampleContextTracker::SampleContextTracker(
180     StringMap<FunctionSamples> &Profiles) {
181   for (auto &FuncSample : Profiles) {
182     FunctionSamples *FSamples = &FuncSample.second;
183     SampleContext Context(FuncSample.first(), RawContext);
184     LLVM_DEBUG(dbgs() << "Tracking Context for function: " << Context << "\n");
185     if (!Context.isBaseContext())
186       FuncToCtxtProfileSet[Context.getNameWithoutContext()].insert(FSamples);
187     ContextTrieNode *NewNode = getOrCreateContextPath(Context, true);
188     assert(!NewNode->getFunctionSamples() &&
189            "New node can't have sample profile");
190     NewNode->setFunctionSamples(FSamples);
191   }
192 }
193 
194 FunctionSamples *
195 SampleContextTracker::getCalleeContextSamplesFor(const CallBase &Inst,
196                                                  StringRef CalleeName) {
197   LLVM_DEBUG(dbgs() << "Getting callee context for instr: " << Inst << "\n");
198   DILocation *DIL = Inst.getDebugLoc();
199   if (!DIL)
200     return nullptr;
201 
202   CalleeName = FunctionSamples::getCanonicalFnName(CalleeName);
203 
204   // For indirect call, CalleeName will be empty, in which case the context
205   // profile for callee with largest total samples will be returned.
206   ContextTrieNode *CalleeContext = getCalleeContextFor(DIL, CalleeName);
207   if (CalleeContext) {
208     FunctionSamples *FSamples = CalleeContext->getFunctionSamples();
209     LLVM_DEBUG(if (FSamples) {
210       dbgs() << "  Callee context found: " << FSamples->getContext() << "\n";
211     });
212     return FSamples;
213   }
214 
215   return nullptr;
216 }
217 
218 std::vector<const FunctionSamples *>
219 SampleContextTracker::getIndirectCalleeContextSamplesFor(
220     const DILocation *DIL) {
221   std::vector<const FunctionSamples *> R;
222   if (!DIL)
223     return R;
224 
225   ContextTrieNode *CallerNode = getContextFor(DIL);
226   LineLocation CallSite = FunctionSamples::getCallSiteIdentifier(DIL);
227   for (auto &It : CallerNode->getAllChildContext()) {
228     ContextTrieNode &ChildNode = It.second;
229     if (ChildNode.getCallSiteLoc() != CallSite)
230       continue;
231     if (FunctionSamples *CalleeSamples = ChildNode.getFunctionSamples())
232       R.push_back(CalleeSamples);
233   }
234 
235   return R;
236 }
237 
238 FunctionSamples *
239 SampleContextTracker::getContextSamplesFor(const DILocation *DIL) {
240   assert(DIL && "Expect non-null location");
241 
242   ContextTrieNode *ContextNode = getContextFor(DIL);
243   if (!ContextNode)
244     return nullptr;
245 
246   // We may have inlined callees during pre-LTO compilation, in which case
247   // we need to rely on the inline stack from !dbg to mark context profile
248   // as inlined, instead of `MarkContextSamplesInlined` during inlining.
249   // Sample profile loader walks through all instructions to get profile,
250   // which calls this function. So once that is done, all previously inlined
251   // context profile should be marked properly.
252   FunctionSamples *Samples = ContextNode->getFunctionSamples();
253   if (Samples && ContextNode->getParentContext() != &RootContext)
254     Samples->getContext().setState(InlinedContext);
255 
256   return Samples;
257 }
258 
259 FunctionSamples *
260 SampleContextTracker::getContextSamplesFor(const SampleContext &Context) {
261   ContextTrieNode *Node = getContextFor(Context);
262   if (!Node)
263     return nullptr;
264 
265   return Node->getFunctionSamples();
266 }
267 
268 SampleContextTracker::ContextSamplesTy &
269 SampleContextTracker::getAllContextSamplesFor(const Function &Func) {
270   StringRef CanonName = FunctionSamples::getCanonicalFnName(Func);
271   return FuncToCtxtProfileSet[CanonName];
272 }
273 
274 SampleContextTracker::ContextSamplesTy &
275 SampleContextTracker::getAllContextSamplesFor(StringRef Name) {
276   return FuncToCtxtProfileSet[Name];
277 }
278 
279 FunctionSamples *SampleContextTracker::getBaseSamplesFor(const Function &Func,
280                                                          bool MergeContext) {
281   StringRef CanonName = FunctionSamples::getCanonicalFnName(Func);
282   return getBaseSamplesFor(CanonName, MergeContext);
283 }
284 
285 FunctionSamples *SampleContextTracker::getBaseSamplesFor(StringRef Name,
286                                                          bool MergeContext) {
287   LLVM_DEBUG(dbgs() << "Getting base profile for function: " << Name << "\n");
288   // Base profile is top-level node (child of root node), so try to retrieve
289   // existing top-level node for given function first. If it exists, it could be
290   // that we've merged base profile before, or there's actually context-less
291   // profile from the input (e.g. due to unreliable stack walking).
292   ContextTrieNode *Node = getTopLevelContextNode(Name);
293   if (MergeContext) {
294     LLVM_DEBUG(dbgs() << "  Merging context profile into base profile: " << Name
295                       << "\n");
296 
297     // We have profile for function under different contexts,
298     // create synthetic base profile and merge context profiles
299     // into base profile.
300     for (auto *CSamples : FuncToCtxtProfileSet[Name]) {
301       SampleContext &Context = CSamples->getContext();
302       ContextTrieNode *FromNode = getContextFor(Context);
303       if (FromNode == Node)
304         continue;
305 
306       // Skip inlined context profile and also don't re-merge any context
307       if (Context.hasState(InlinedContext) || Context.hasState(MergedContext))
308         continue;
309 
310       ContextTrieNode &ToNode = promoteMergeContextSamplesTree(*FromNode);
311       assert((!Node || Node == &ToNode) && "Expect only one base profile");
312       Node = &ToNode;
313     }
314   }
315 
316   // Still no profile even after merge/promotion (if allowed)
317   if (!Node)
318     return nullptr;
319 
320   return Node->getFunctionSamples();
321 }
322 
323 void SampleContextTracker::markContextSamplesInlined(
324     const FunctionSamples *InlinedSamples) {
325   assert(InlinedSamples && "Expect non-null inlined samples");
326   LLVM_DEBUG(dbgs() << "Marking context profile as inlined: "
327                     << InlinedSamples->getContext() << "\n");
328   InlinedSamples->getContext().setState(InlinedContext);
329 }
330 
331 void SampleContextTracker::promoteMergeContextSamplesTree(
332     const Instruction &Inst, StringRef CalleeName) {
333   LLVM_DEBUG(dbgs() << "Promoting and merging context tree for instr: \n"
334                     << Inst << "\n");
335   // Get the caller context for the call instruction, we don't use callee
336   // name from call because there can be context from indirect calls too.
337   DILocation *DIL = Inst.getDebugLoc();
338   ContextTrieNode *CallerNode = getContextFor(DIL);
339   if (!CallerNode)
340     return;
341 
342   // Get the context that needs to be promoted
343   LineLocation CallSite = FunctionSamples::getCallSiteIdentifier(DIL);
344   // For indirect call, CalleeName will be empty, in which case we need to
345   // promote all non-inlined child context profiles.
346   if (CalleeName.empty()) {
347     for (auto &It : CallerNode->getAllChildContext()) {
348       ContextTrieNode *NodeToPromo = &It.second;
349       if (CallSite != NodeToPromo->getCallSiteLoc())
350         continue;
351       FunctionSamples *FromSamples = NodeToPromo->getFunctionSamples();
352       if (FromSamples && FromSamples->getContext().hasState(InlinedContext))
353         continue;
354       promoteMergeContextSamplesTree(*NodeToPromo);
355     }
356     return;
357   }
358 
359   // Get the context for the given callee that needs to be promoted
360   ContextTrieNode *NodeToPromo =
361       CallerNode->getChildContext(CallSite, CalleeName);
362   if (!NodeToPromo)
363     return;
364 
365   promoteMergeContextSamplesTree(*NodeToPromo);
366 }
367 
368 ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree(
369     ContextTrieNode &NodeToPromo) {
370   // Promote the input node to be directly under root. This can happen
371   // when we decided to not inline a function under context represented
372   // by the input node. The promote and merge is then needed to reflect
373   // the context profile in the base (context-less) profile.
374   FunctionSamples *FromSamples = NodeToPromo.getFunctionSamples();
375   assert(FromSamples && "Shouldn't promote a context without profile");
376   LLVM_DEBUG(dbgs() << "  Found context tree root to promote: "
377                     << FromSamples->getContext() << "\n");
378 
379   assert(!FromSamples->getContext().hasState(InlinedContext) &&
380          "Shouldn't promote inlined context profile");
381   StringRef ContextStrToRemove = FromSamples->getContext().getCallingContext();
382   return promoteMergeContextSamplesTree(NodeToPromo, RootContext,
383                                         ContextStrToRemove);
384 }
385 
386 void SampleContextTracker::dump() {
387   dbgs() << "Context Profile Tree:\n";
388   std::queue<ContextTrieNode *> NodeQueue;
389   NodeQueue.push(&RootContext);
390 
391   while (!NodeQueue.empty()) {
392     ContextTrieNode *Node = NodeQueue.front();
393     NodeQueue.pop();
394     Node->dump();
395 
396     for (auto &It : Node->getAllChildContext()) {
397       ContextTrieNode *ChildNode = &It.second;
398       NodeQueue.push(ChildNode);
399     }
400   }
401 }
402 
403 ContextTrieNode *
404 SampleContextTracker::getContextFor(const SampleContext &Context) {
405   return getOrCreateContextPath(Context, false);
406 }
407 
408 ContextTrieNode *
409 SampleContextTracker::getCalleeContextFor(const DILocation *DIL,
410                                           StringRef CalleeName) {
411   assert(DIL && "Expect non-null location");
412 
413   ContextTrieNode *CallContext = getContextFor(DIL);
414   if (!CallContext)
415     return nullptr;
416 
417   // When CalleeName is empty, the child context profile with max
418   // total samples will be returned.
419   return CallContext->getChildContext(
420       FunctionSamples::getCallSiteIdentifier(DIL), CalleeName);
421 }
422 
423 ContextTrieNode *SampleContextTracker::getContextFor(const DILocation *DIL) {
424   assert(DIL && "Expect non-null location");
425   SmallVector<std::pair<LineLocation, StringRef>, 10> S;
426 
427   // Use C++ linkage name if possible.
428   const DILocation *PrevDIL = DIL;
429   for (DIL = DIL->getInlinedAt(); DIL; DIL = DIL->getInlinedAt()) {
430     StringRef Name = PrevDIL->getScope()->getSubprogram()->getLinkageName();
431     if (Name.empty())
432       Name = PrevDIL->getScope()->getSubprogram()->getName();
433     S.push_back(
434         std::make_pair(FunctionSamples::getCallSiteIdentifier(DIL),
435                        PrevDIL->getScope()->getSubprogram()->getLinkageName()));
436     PrevDIL = DIL;
437   }
438 
439   // Push root node, note that root node like main may only
440   // a name, but not linkage name.
441   StringRef RootName = PrevDIL->getScope()->getSubprogram()->getLinkageName();
442   if (RootName.empty())
443     RootName = PrevDIL->getScope()->getSubprogram()->getName();
444   S.push_back(std::make_pair(LineLocation(0, 0), RootName));
445 
446   ContextTrieNode *ContextNode = &RootContext;
447   int I = S.size();
448   while (--I >= 0 && ContextNode) {
449     LineLocation &CallSite = S[I].first;
450     StringRef &CalleeName = S[I].second;
451     ContextNode = ContextNode->getChildContext(CallSite, CalleeName);
452   }
453 
454   if (I < 0)
455     return ContextNode;
456 
457   return nullptr;
458 }
459 
460 ContextTrieNode *
461 SampleContextTracker::getOrCreateContextPath(const SampleContext &Context,
462                                              bool AllowCreate) {
463   ContextTrieNode *ContextNode = &RootContext;
464   StringRef ContextRemain = Context;
465   StringRef ChildContext;
466   StringRef CalleeName;
467   LineLocation CallSiteLoc(0, 0);
468 
469   while (ContextNode && !ContextRemain.empty()) {
470     auto ContextSplit = SampleContext::splitContextString(ContextRemain);
471     ChildContext = ContextSplit.first;
472     ContextRemain = ContextSplit.second;
473     LineLocation NextCallSiteLoc(0, 0);
474     SampleContext::decodeContextString(ChildContext, CalleeName,
475                                        NextCallSiteLoc);
476 
477     // Create child node at parent line/disc location
478     if (AllowCreate) {
479       ContextNode =
480           ContextNode->getOrCreateChildContext(CallSiteLoc, CalleeName);
481     } else {
482       ContextNode = ContextNode->getChildContext(CallSiteLoc, CalleeName);
483     }
484     CallSiteLoc = NextCallSiteLoc;
485   }
486 
487   assert((!AllowCreate || ContextNode) &&
488          "Node must exist if creation is allowed");
489   return ContextNode;
490 }
491 
492 ContextTrieNode *SampleContextTracker::getTopLevelContextNode(StringRef FName) {
493   return RootContext.getChildContext(LineLocation(0, 0), FName);
494 }
495 
496 ContextTrieNode &SampleContextTracker::addTopLevelContextNode(StringRef FName) {
497   assert(!getTopLevelContextNode(FName) && "Node to add must not exist");
498   return *RootContext.getOrCreateChildContext(LineLocation(0, 0), FName);
499 }
500 
501 void SampleContextTracker::mergeContextNode(ContextTrieNode &FromNode,
502                                             ContextTrieNode &ToNode,
503                                             StringRef ContextStrToRemove) {
504   FunctionSamples *FromSamples = FromNode.getFunctionSamples();
505   FunctionSamples *ToSamples = ToNode.getFunctionSamples();
506   if (FromSamples && ToSamples) {
507     // Merge/duplicate FromSamples into ToSamples
508     ToSamples->merge(*FromSamples);
509     ToSamples->getContext().setState(SyntheticContext);
510     FromSamples->getContext().setState(MergedContext);
511   } else if (FromSamples) {
512     // Transfer FromSamples from FromNode to ToNode
513     ToNode.setFunctionSamples(FromSamples);
514     FromSamples->getContext().setState(SyntheticContext);
515     FromSamples->getContext().promoteOnPath(ContextStrToRemove);
516     FromNode.setFunctionSamples(nullptr);
517   }
518 }
519 
520 ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree(
521     ContextTrieNode &FromNode, ContextTrieNode &ToNodeParent,
522     StringRef ContextStrToRemove) {
523   assert(!ContextStrToRemove.empty() && "Context to remove can't be empty");
524 
525   // Ignore call site location if destination is top level under root
526   LineLocation NewCallSiteLoc = LineLocation(0, 0);
527   LineLocation OldCallSiteLoc = FromNode.getCallSiteLoc();
528   ContextTrieNode &FromNodeParent = *FromNode.getParentContext();
529   ContextTrieNode *ToNode = nullptr;
530   bool MoveToRoot = (&ToNodeParent == &RootContext);
531   if (!MoveToRoot) {
532     NewCallSiteLoc = OldCallSiteLoc;
533   }
534 
535   // Locate destination node, create/move if not existing
536   ToNode = ToNodeParent.getChildContext(NewCallSiteLoc, FromNode.getFuncName());
537   if (!ToNode) {
538     // Do not delete node to move from its parent here because
539     // caller is iterating over children of that parent node.
540     ToNode = &ToNodeParent.moveToChildContext(
541         NewCallSiteLoc, std::move(FromNode), ContextStrToRemove, false);
542   } else {
543     // Destination node exists, merge samples for the context tree
544     mergeContextNode(FromNode, *ToNode, ContextStrToRemove);
545     LLVM_DEBUG(dbgs() << "  Context promoted and merged to: "
546                       << ToNode->getFunctionSamples()->getContext() << "\n");
547 
548     // Recursively promote and merge children
549     for (auto &It : FromNode.getAllChildContext()) {
550       ContextTrieNode &FromChildNode = It.second;
551       promoteMergeContextSamplesTree(FromChildNode, *ToNode,
552                                      ContextStrToRemove);
553     }
554 
555     // Remove children once they're all merged
556     FromNode.getAllChildContext().clear();
557   }
558 
559   // For root of subtree, remove itself from old parent too
560   if (MoveToRoot)
561     FromNodeParent.removeChildContext(OldCallSiteLoc, ToNode->getFuncName());
562 
563   return *ToNode;
564 }
565 
566 // Replace call graph edges with dynamic call edges from the profile.
567 void SampleContextTracker::addCallGraphEdges(CallGraph &CG,
568                                              StringMap<Function *> &SymbolMap) {
569   // Add profile call edges to the call graph.
570   std::queue<ContextTrieNode *> NodeQueue;
571   NodeQueue.push(&RootContext);
572   while (!NodeQueue.empty()) {
573     ContextTrieNode *Node = NodeQueue.front();
574     NodeQueue.pop();
575     Function *F = SymbolMap.lookup(Node->getFuncName());
576     for (auto &I : Node->getAllChildContext()) {
577       ContextTrieNode *ChildNode = &I.second;
578       NodeQueue.push(ChildNode);
579       if (F && !F->isDeclaration()) {
580         Function *Callee = SymbolMap.lookup(ChildNode->getFuncName());
581         if (Callee && !Callee->isDeclaration())
582           CG[F]->addCalledFunction(nullptr, CG[Callee]);
583       }
584     }
585   }
586 }
587 } // namespace llvm
588