1 //===- SampleContextTracker.cpp - Context-sensitive Profile Tracker -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the SampleContextTracker used by CSSPGO.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "llvm/Transforms/IPO/SampleContextTracker.h"
14 #include "llvm/ADT/StringMap.h"
15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/IR/DebugInfoMetadata.h"
17 #include "llvm/IR/Instructions.h"
18 #include "llvm/ProfileData/SampleProf.h"
19 #include <map>
20 #include <queue>
21 #include <vector>
22 
23 using namespace llvm;
24 using namespace sampleprof;
25 
26 #define DEBUG_TYPE "sample-context-tracker"
27 
28 namespace llvm {
29 
30 ContextTrieNode *ContextTrieNode::getChildContext(const LineLocation &CallSite,
31                                                   StringRef CalleeName) {
32   if (CalleeName.empty())
33     return getHottestChildContext(CallSite);
34 
35   uint32_t Hash = nodeHash(CalleeName, CallSite);
36   auto It = AllChildContext.find(Hash);
37   if (It != AllChildContext.end())
38     return &It->second;
39   return nullptr;
40 }
41 
42 ContextTrieNode *
43 ContextTrieNode::getHottestChildContext(const LineLocation &CallSite) {
44   // CSFDO-TODO: This could be slow, change AllChildContext so we can
45   // do point look up for child node by call site alone.
46   // Retrieve the child node with max count for indirect call
47   ContextTrieNode *ChildNodeRet = nullptr;
48   uint64_t MaxCalleeSamples = 0;
49   for (auto &It : AllChildContext) {
50     ContextTrieNode &ChildNode = It.second;
51     if (ChildNode.CallSiteLoc != CallSite)
52       continue;
53     FunctionSamples *Samples = ChildNode.getFunctionSamples();
54     if (!Samples)
55       continue;
56     if (Samples->getTotalSamples() > MaxCalleeSamples) {
57       ChildNodeRet = &ChildNode;
58       MaxCalleeSamples = Samples->getTotalSamples();
59     }
60   }
61 
62   return ChildNodeRet;
63 }
64 
65 ContextTrieNode &ContextTrieNode::moveToChildContext(
66     const LineLocation &CallSite, ContextTrieNode &&NodeToMove,
67     StringRef ContextStrToRemove, bool DeleteNode) {
68   uint32_t Hash = nodeHash(NodeToMove.getFuncName(), CallSite);
69   assert(!AllChildContext.count(Hash) && "Node to remove must exist");
70   LineLocation OldCallSite = NodeToMove.CallSiteLoc;
71   ContextTrieNode &OldParentContext = *NodeToMove.getParentContext();
72   AllChildContext[Hash] = NodeToMove;
73   ContextTrieNode &NewNode = AllChildContext[Hash];
74   NewNode.CallSiteLoc = CallSite;
75 
76   // Walk through nodes in the moved the subtree, and update
77   // FunctionSamples' context as for the context promotion.
78   // We also need to set new parant link for all children.
79   std::queue<ContextTrieNode *> NodeToUpdate;
80   NewNode.setParentContext(this);
81   NodeToUpdate.push(&NewNode);
82 
83   while (!NodeToUpdate.empty()) {
84     ContextTrieNode *Node = NodeToUpdate.front();
85     NodeToUpdate.pop();
86     FunctionSamples *FSamples = Node->getFunctionSamples();
87 
88     if (FSamples) {
89       FSamples->getContext().promoteOnPath(ContextStrToRemove);
90       FSamples->getContext().setState(SyntheticContext);
91       LLVM_DEBUG(dbgs() << "  Context promoted to: " << FSamples->getContext()
92                         << "\n");
93     }
94 
95     for (auto &It : Node->getAllChildContext()) {
96       ContextTrieNode *ChildNode = &It.second;
97       ChildNode->setParentContext(Node);
98       NodeToUpdate.push(ChildNode);
99     }
100   }
101 
102   // Original context no longer needed, destroy if requested.
103   if (DeleteNode)
104     OldParentContext.removeChildContext(OldCallSite, NewNode.getFuncName());
105 
106   return NewNode;
107 }
108 
109 void ContextTrieNode::removeChildContext(const LineLocation &CallSite,
110                                          StringRef CalleeName) {
111   uint32_t Hash = nodeHash(CalleeName, CallSite);
112   // Note this essentially calls dtor and destroys that child context
113   AllChildContext.erase(Hash);
114 }
115 
116 std::map<uint32_t, ContextTrieNode> &ContextTrieNode::getAllChildContext() {
117   return AllChildContext;
118 }
119 
120 StringRef ContextTrieNode::getFuncName() const { return FuncName; }
121 
122 FunctionSamples *ContextTrieNode::getFunctionSamples() const {
123   return FuncSamples;
124 }
125 
126 void ContextTrieNode::setFunctionSamples(FunctionSamples *FSamples) {
127   FuncSamples = FSamples;
128 }
129 
130 uint32_t ContextTrieNode::getFunctionSize() const { return FuncSize; }
131 
132 void ContextTrieNode::setFunctionSize(uint32_t FSize) { FuncSize = FSize; }
133 
134 LineLocation ContextTrieNode::getCallSiteLoc() const { return CallSiteLoc; }
135 
136 ContextTrieNode *ContextTrieNode::getParentContext() const {
137   return ParentContext;
138 }
139 
140 void ContextTrieNode::setParentContext(ContextTrieNode *Parent) {
141   ParentContext = Parent;
142 }
143 
144 void ContextTrieNode::dumpNode() {
145   dbgs() << "Node: " << FuncName << "\n"
146          << "  Callsite: " << CallSiteLoc << "\n"
147          << "  Size: " << FuncSize << "\n"
148          << "  Children:\n";
149 
150   for (auto &It : AllChildContext) {
151     dbgs() << "    Node: " << It.second.getFuncName() << "\n";
152   }
153 }
154 
155 void ContextTrieNode::dumpTree() {
156   dbgs() << "Context Profile Tree:\n";
157   std::queue<ContextTrieNode *> NodeQueue;
158   NodeQueue.push(this);
159 
160   while (!NodeQueue.empty()) {
161     ContextTrieNode *Node = NodeQueue.front();
162     NodeQueue.pop();
163     Node->dumpNode();
164 
165     for (auto &It : Node->getAllChildContext()) {
166       ContextTrieNode *ChildNode = &It.second;
167       NodeQueue.push(ChildNode);
168     }
169   }
170 }
171 
172 uint32_t ContextTrieNode::nodeHash(StringRef ChildName,
173                                    const LineLocation &Callsite) {
174   // We still use child's name for child hash, this is
175   // because for children of root node, we don't have
176   // different line/discriminator, and we'll rely on name
177   // to differentiate children.
178   uint32_t NameHash = std::hash<std::string>{}(ChildName.str());
179   uint32_t LocId = (Callsite.LineOffset << 16) | Callsite.Discriminator;
180   return NameHash + (LocId << 5) + LocId;
181 }
182 
183 ContextTrieNode *ContextTrieNode::getOrCreateChildContext(
184     const LineLocation &CallSite, StringRef CalleeName, bool AllowCreate) {
185   uint32_t Hash = nodeHash(CalleeName, CallSite);
186   auto It = AllChildContext.find(Hash);
187   if (It != AllChildContext.end()) {
188     assert(It->second.getFuncName() == CalleeName &&
189            "Hash collision for child context node");
190     return &It->second;
191   }
192 
193   if (!AllowCreate)
194     return nullptr;
195 
196   AllChildContext[Hash] =
197       ContextTrieNode(this, CalleeName, nullptr, 0, CallSite);
198   return &AllChildContext[Hash];
199 }
200 
201 // Profiler tracker than manages profiles and its associated context
202 SampleContextTracker::SampleContextTracker(
203     StringMap<FunctionSamples> &Profiles) {
204   for (auto &FuncSample : Profiles) {
205     FunctionSamples *FSamples = &FuncSample.second;
206     SampleContext Context(FuncSample.first(), RawContext);
207     LLVM_DEBUG(dbgs() << "Tracking Context for function: " << Context << "\n");
208     if (!Context.isBaseContext())
209       FuncToCtxtProfiles[Context.getNameWithoutContext()].push_back(FSamples);
210     ContextTrieNode *NewNode = getOrCreateContextPath(Context, true);
211     assert(!NewNode->getFunctionSamples() &&
212            "New node can't have sample profile");
213     NewNode->setFunctionSamples(FSamples);
214   }
215 }
216 
217 FunctionSamples *
218 SampleContextTracker::getCalleeContextSamplesFor(const CallBase &Inst,
219                                                  StringRef CalleeName) {
220   LLVM_DEBUG(dbgs() << "Getting callee context for instr: " << Inst << "\n");
221   DILocation *DIL = Inst.getDebugLoc();
222   if (!DIL)
223     return nullptr;
224 
225   CalleeName = FunctionSamples::getCanonicalFnName(CalleeName);
226 
227   // For indirect call, CalleeName will be empty, in which case the context
228   // profile for callee with largest total samples will be returned.
229   ContextTrieNode *CalleeContext = getCalleeContextFor(DIL, CalleeName);
230   if (CalleeContext) {
231     FunctionSamples *FSamples = CalleeContext->getFunctionSamples();
232     LLVM_DEBUG(if (FSamples) {
233       dbgs() << "  Callee context found: " << FSamples->getContext() << "\n";
234     });
235     return FSamples;
236   }
237 
238   return nullptr;
239 }
240 
241 std::vector<const FunctionSamples *>
242 SampleContextTracker::getIndirectCalleeContextSamplesFor(
243     const DILocation *DIL) {
244   std::vector<const FunctionSamples *> R;
245   if (!DIL)
246     return R;
247 
248   ContextTrieNode *CallerNode = getContextFor(DIL);
249   LineLocation CallSite = FunctionSamples::getCallSiteIdentifier(DIL);
250   for (auto &It : CallerNode->getAllChildContext()) {
251     ContextTrieNode &ChildNode = It.second;
252     if (ChildNode.getCallSiteLoc() != CallSite)
253       continue;
254     if (FunctionSamples *CalleeSamples = ChildNode.getFunctionSamples())
255       R.push_back(CalleeSamples);
256   }
257 
258   return R;
259 }
260 
261 FunctionSamples *
262 SampleContextTracker::getContextSamplesFor(const DILocation *DIL) {
263   assert(DIL && "Expect non-null location");
264 
265   ContextTrieNode *ContextNode = getContextFor(DIL);
266   if (!ContextNode)
267     return nullptr;
268 
269   // We may have inlined callees during pre-LTO compilation, in which case
270   // we need to rely on the inline stack from !dbg to mark context profile
271   // as inlined, instead of `MarkContextSamplesInlined` during inlining.
272   // Sample profile loader walks through all instructions to get profile,
273   // which calls this function. So once that is done, all previously inlined
274   // context profile should be marked properly.
275   FunctionSamples *Samples = ContextNode->getFunctionSamples();
276   if (Samples && ContextNode->getParentContext() != &RootContext)
277     Samples->getContext().setState(InlinedContext);
278 
279   return Samples;
280 }
281 
282 FunctionSamples *
283 SampleContextTracker::getContextSamplesFor(const SampleContext &Context) {
284   ContextTrieNode *Node = getContextFor(Context);
285   if (!Node)
286     return nullptr;
287 
288   return Node->getFunctionSamples();
289 }
290 
291 SampleContextTracker::ContextSamplesTy &
292 SampleContextTracker::getAllContextSamplesFor(const Function &Func) {
293   StringRef CanonName = FunctionSamples::getCanonicalFnName(Func);
294   return FuncToCtxtProfiles[CanonName];
295 }
296 
297 SampleContextTracker::ContextSamplesTy &
298 SampleContextTracker::getAllContextSamplesFor(StringRef Name) {
299   return FuncToCtxtProfiles[Name];
300 }
301 
302 FunctionSamples *SampleContextTracker::getBaseSamplesFor(const Function &Func,
303                                                          bool MergeContext) {
304   StringRef CanonName = FunctionSamples::getCanonicalFnName(Func);
305   return getBaseSamplesFor(CanonName, MergeContext);
306 }
307 
308 FunctionSamples *SampleContextTracker::getBaseSamplesFor(StringRef Name,
309                                                          bool MergeContext) {
310   LLVM_DEBUG(dbgs() << "Getting base profile for function: " << Name << "\n");
311   // Base profile is top-level node (child of root node), so try to retrieve
312   // existing top-level node for given function first. If it exists, it could be
313   // that we've merged base profile before, or there's actually context-less
314   // profile from the input (e.g. due to unreliable stack walking).
315   ContextTrieNode *Node = getTopLevelContextNode(Name);
316   if (MergeContext) {
317     LLVM_DEBUG(dbgs() << "  Merging context profile into base profile: " << Name
318                       << "\n");
319 
320     // We have profile for function under different contexts,
321     // create synthetic base profile and merge context profiles
322     // into base profile.
323     for (auto *CSamples : FuncToCtxtProfiles[Name]) {
324       SampleContext &Context = CSamples->getContext();
325       ContextTrieNode *FromNode = getContextFor(Context);
326       if (FromNode == Node)
327         continue;
328 
329       // Skip inlined context profile and also don't re-merge any context
330       if (Context.hasState(InlinedContext) || Context.hasState(MergedContext))
331         continue;
332 
333       ContextTrieNode &ToNode = promoteMergeContextSamplesTree(*FromNode);
334       assert((!Node || Node == &ToNode) && "Expect only one base profile");
335       Node = &ToNode;
336     }
337   }
338 
339   // Still no profile even after merge/promotion (if allowed)
340   if (!Node)
341     return nullptr;
342 
343   return Node->getFunctionSamples();
344 }
345 
346 void SampleContextTracker::markContextSamplesInlined(
347     const FunctionSamples *InlinedSamples) {
348   assert(InlinedSamples && "Expect non-null inlined samples");
349   LLVM_DEBUG(dbgs() << "Marking context profile as inlined: "
350                     << InlinedSamples->getContext() << "\n");
351   InlinedSamples->getContext().setState(InlinedContext);
352 }
353 
354 ContextTrieNode &SampleContextTracker::getRootContext() { return RootContext; }
355 
356 void SampleContextTracker::promoteMergeContextSamplesTree(
357     const Instruction &Inst, StringRef CalleeName) {
358   LLVM_DEBUG(dbgs() << "Promoting and merging context tree for instr: \n"
359                     << Inst << "\n");
360   // Get the caller context for the call instruction, we don't use callee
361   // name from call because there can be context from indirect calls too.
362   DILocation *DIL = Inst.getDebugLoc();
363   ContextTrieNode *CallerNode = getContextFor(DIL);
364   if (!CallerNode)
365     return;
366 
367   // Get the context that needs to be promoted
368   LineLocation CallSite = FunctionSamples::getCallSiteIdentifier(DIL);
369   // For indirect call, CalleeName will be empty, in which case we need to
370   // promote all non-inlined child context profiles.
371   if (CalleeName.empty()) {
372     for (auto &It : CallerNode->getAllChildContext()) {
373       ContextTrieNode *NodeToPromo = &It.second;
374       if (CallSite != NodeToPromo->getCallSiteLoc())
375         continue;
376       FunctionSamples *FromSamples = NodeToPromo->getFunctionSamples();
377       if (FromSamples && FromSamples->getContext().hasState(InlinedContext))
378         continue;
379       promoteMergeContextSamplesTree(*NodeToPromo);
380     }
381     return;
382   }
383 
384   // Get the context for the given callee that needs to be promoted
385   ContextTrieNode *NodeToPromo =
386       CallerNode->getChildContext(CallSite, CalleeName);
387   if (!NodeToPromo)
388     return;
389 
390   promoteMergeContextSamplesTree(*NodeToPromo);
391 }
392 
393 ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree(
394     ContextTrieNode &NodeToPromo) {
395   // Promote the input node to be directly under root. This can happen
396   // when we decided to not inline a function under context represented
397   // by the input node. The promote and merge is then needed to reflect
398   // the context profile in the base (context-less) profile.
399   FunctionSamples *FromSamples = NodeToPromo.getFunctionSamples();
400   assert(FromSamples && "Shouldn't promote a context without profile");
401   LLVM_DEBUG(dbgs() << "  Found context tree root to promote: "
402                     << FromSamples->getContext() << "\n");
403 
404   assert(!FromSamples->getContext().hasState(InlinedContext) &&
405          "Shouldn't promote inlined context profile");
406   StringRef ContextStrToRemove = FromSamples->getContext().getCallingContext();
407   return promoteMergeContextSamplesTree(NodeToPromo, RootContext,
408                                         ContextStrToRemove);
409 }
410 
411 void SampleContextTracker::dump() { RootContext.dumpTree(); }
412 
413 ContextTrieNode *
414 SampleContextTracker::getContextFor(const SampleContext &Context) {
415   return getOrCreateContextPath(Context, false);
416 }
417 
418 ContextTrieNode *
419 SampleContextTracker::getCalleeContextFor(const DILocation *DIL,
420                                           StringRef CalleeName) {
421   assert(DIL && "Expect non-null location");
422 
423   ContextTrieNode *CallContext = getContextFor(DIL);
424   if (!CallContext)
425     return nullptr;
426 
427   // When CalleeName is empty, the child context profile with max
428   // total samples will be returned.
429   return CallContext->getChildContext(
430       FunctionSamples::getCallSiteIdentifier(DIL), CalleeName);
431 }
432 
433 ContextTrieNode *SampleContextTracker::getContextFor(const DILocation *DIL) {
434   assert(DIL && "Expect non-null location");
435   SmallVector<std::pair<LineLocation, StringRef>, 10> S;
436 
437   // Use C++ linkage name if possible.
438   const DILocation *PrevDIL = DIL;
439   for (DIL = DIL->getInlinedAt(); DIL; DIL = DIL->getInlinedAt()) {
440     StringRef Name = PrevDIL->getScope()->getSubprogram()->getLinkageName();
441     if (Name.empty())
442       Name = PrevDIL->getScope()->getSubprogram()->getName();
443     S.push_back(
444         std::make_pair(FunctionSamples::getCallSiteIdentifier(DIL), Name));
445     PrevDIL = DIL;
446   }
447 
448   // Push root node, note that root node like main may only
449   // a name, but not linkage name.
450   StringRef RootName = PrevDIL->getScope()->getSubprogram()->getLinkageName();
451   if (RootName.empty())
452     RootName = PrevDIL->getScope()->getSubprogram()->getName();
453   S.push_back(std::make_pair(LineLocation(0, 0), RootName));
454 
455   ContextTrieNode *ContextNode = &RootContext;
456   int I = S.size();
457   while (--I >= 0 && ContextNode) {
458     LineLocation &CallSite = S[I].first;
459     StringRef &CalleeName = S[I].second;
460     ContextNode = ContextNode->getChildContext(CallSite, CalleeName);
461   }
462 
463   if (I < 0)
464     return ContextNode;
465 
466   return nullptr;
467 }
468 
469 ContextTrieNode *
470 SampleContextTracker::getOrCreateContextPath(const SampleContext &Context,
471                                              bool AllowCreate) {
472   ContextTrieNode *ContextNode = &RootContext;
473   StringRef ContextRemain = Context;
474   StringRef ChildContext;
475   StringRef CalleeName;
476   LineLocation CallSiteLoc(0, 0);
477 
478   while (ContextNode && !ContextRemain.empty()) {
479     auto ContextSplit = SampleContext::splitContextString(ContextRemain);
480     ChildContext = ContextSplit.first;
481     ContextRemain = ContextSplit.second;
482     LineLocation NextCallSiteLoc(0, 0);
483     SampleContext::decodeContextString(ChildContext, CalleeName,
484                                        NextCallSiteLoc);
485 
486     // Create child node at parent line/disc location
487     if (AllowCreate) {
488       ContextNode =
489           ContextNode->getOrCreateChildContext(CallSiteLoc, CalleeName);
490     } else {
491       ContextNode = ContextNode->getChildContext(CallSiteLoc, CalleeName);
492     }
493     CallSiteLoc = NextCallSiteLoc;
494   }
495 
496   assert((!AllowCreate || ContextNode) &&
497          "Node must exist if creation is allowed");
498   return ContextNode;
499 }
500 
501 ContextTrieNode *SampleContextTracker::getTopLevelContextNode(StringRef FName) {
502   assert(!FName.empty() && "Top level node query must provide valid name");
503   return RootContext.getChildContext(LineLocation(0, 0), FName);
504 }
505 
506 ContextTrieNode &SampleContextTracker::addTopLevelContextNode(StringRef FName) {
507   assert(!getTopLevelContextNode(FName) && "Node to add must not exist");
508   return *RootContext.getOrCreateChildContext(LineLocation(0, 0), FName);
509 }
510 
511 void SampleContextTracker::mergeContextNode(ContextTrieNode &FromNode,
512                                             ContextTrieNode &ToNode,
513                                             StringRef ContextStrToRemove) {
514   FunctionSamples *FromSamples = FromNode.getFunctionSamples();
515   FunctionSamples *ToSamples = ToNode.getFunctionSamples();
516   if (FromSamples && ToSamples) {
517     // Merge/duplicate FromSamples into ToSamples
518     ToSamples->merge(*FromSamples);
519     ToSamples->getContext().setState(SyntheticContext);
520     FromSamples->getContext().setState(MergedContext);
521   } else if (FromSamples) {
522     // Transfer FromSamples from FromNode to ToNode
523     ToNode.setFunctionSamples(FromSamples);
524     FromSamples->getContext().setState(SyntheticContext);
525     FromSamples->getContext().promoteOnPath(ContextStrToRemove);
526     FromNode.setFunctionSamples(nullptr);
527   }
528 }
529 
/// Promote \p FromNode's subtree to become a child of \p ToNodeParent,
/// merging with an existing destination node if one is already there.
///
/// When \p ToNodeParent is the root, the destination call site is 0:0.
/// Otherwise FromNode's original call site is kept. If no destination node
/// exists, FromNode is moved there wholesale (moveToChildContext, without
/// deleting the source). If one exists, FromNode's profile is merged into it
/// (mergeContextNode), each child is handled recursively under the
/// destination, and FromNode's children are cleared afterwards.
/// \p ContextStrToRemove is passed through to strip the promoted calling
/// context from moved/transferred profiles.
///
/// NOTE(review): when \p ToNodeParent is the root, FromNode is erased from
/// its old parent's child map before returning, so callers must not be
/// iterating that map with a live iterator to FromNode.
///
/// \returns the destination node.
ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree(
    ContextTrieNode &FromNode, ContextTrieNode &ToNodeParent,
    StringRef ContextStrToRemove) {
  assert(!ContextStrToRemove.empty() && "Context to remove can't be empty");

  // Ignore call site location if destination is top level under root
  LineLocation NewCallSiteLoc = LineLocation(0, 0);
  // Captured up front: FromNode may be moved from below.
  LineLocation OldCallSiteLoc = FromNode.getCallSiteLoc();
  ContextTrieNode &FromNodeParent = *FromNode.getParentContext();
  ContextTrieNode *ToNode = nullptr;
  bool MoveToRoot = (&ToNodeParent == &RootContext);
  if (!MoveToRoot) {
    NewCallSiteLoc = OldCallSiteLoc;
  }

  // Locate destination node, create/move if not existing
  ToNode = ToNodeParent.getChildContext(NewCallSiteLoc, FromNode.getFuncName());
  if (!ToNode) {
    // Do not delete node to move from its parent here because
    // caller is iterating over children of that parent node.
    ToNode = &ToNodeParent.moveToChildContext(
        NewCallSiteLoc, std::move(FromNode), ContextStrToRemove, false);
  } else {
    // Destination node exists, merge samples for the context tree
    mergeContextNode(FromNode, *ToNode, ContextStrToRemove);
    LLVM_DEBUG({
      if (ToNode->getFunctionSamples())
        dbgs() << "  Context promoted and merged to: "
               << ToNode->getFunctionSamples()->getContext() << "\n";
    });

    // Recursively promote and merge children. The recursive calls never
    // target the root, so they do not erase entries from the map iterated
    // here.
    for (auto &It : FromNode.getAllChildContext()) {
      ContextTrieNode &FromChildNode = It.second;
      promoteMergeContextSamplesTree(FromChildNode, *ToNode,
                                     ContextStrToRemove);
    }

    // Remove children once they're all merged
    FromNode.getAllChildContext().clear();
  }

  // For root of subtree, remove itself from old parent too
  if (MoveToRoot)
    FromNodeParent.removeChildContext(OldCallSiteLoc, ToNode->getFuncName());

  return *ToNode;
}
578 } // namespace llvm
579