1 //===- SampleContextTracker.cpp - Context-sensitive Profile Tracker -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the SampleContextTracker used by CSSPGO. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "llvm/Transforms/IPO/SampleContextTracker.h" 14 #include "llvm/ADT/StringMap.h" 15 #include "llvm/ADT/StringRef.h" 16 #include "llvm/IR/DebugInfoMetadata.h" 17 #include "llvm/IR/Instructions.h" 18 #include "llvm/ProfileData/SampleProf.h" 19 #include <map> 20 #include <queue> 21 #include <vector> 22 23 using namespace llvm; 24 using namespace sampleprof; 25 26 #define DEBUG_TYPE "sample-context-tracker" 27 28 namespace llvm { 29 30 ContextTrieNode *ContextTrieNode::getChildContext(const LineLocation &CallSite, 31 StringRef CalleeName) { 32 if (CalleeName.empty()) 33 return getHottestChildContext(CallSite); 34 35 uint32_t Hash = nodeHash(CalleeName, CallSite); 36 auto It = AllChildContext.find(Hash); 37 if (It != AllChildContext.end()) 38 return &It->second; 39 return nullptr; 40 } 41 42 ContextTrieNode * 43 ContextTrieNode::getHottestChildContext(const LineLocation &CallSite) { 44 // CSFDO-TODO: This could be slow, change AllChildContext so we can 45 // do point look up for child node by call site alone. 46 // Retrieve the child node with max count for indirect call 47 ContextTrieNode *ChildNodeRet = nullptr; 48 uint64_t MaxCalleeSamples = 0; 49 for (auto &It : AllChildContext) { 50 ContextTrieNode &ChildNode = It.second; 51 if (ChildNode.CallSiteLoc != CallSite) 52 continue; 53 FunctionSamples *Samples = ChildNode.getFunctionSamples(); 54 if (!Samples) 55 continue; 56 if (Samples->getTotalSamples() > MaxCalleeSamples) { 57 ChildNodeRet = &ChildNode; 58 MaxCalleeSamples = Samples->getTotalSamples(); 59 } 60 } 61 62 return ChildNodeRet; 63 } 64 65 ContextTrieNode &ContextTrieNode::moveToChildContext( 66 const LineLocation &CallSite, ContextTrieNode &&NodeToMove, 67 StringRef ContextStrToRemove, bool DeleteNode) { 68 uint32_t Hash = nodeHash(NodeToMove.getFuncName(), CallSite); 69 assert(!AllChildContext.count(Hash) && "Node to remove must exist"); 70 LineLocation OldCallSite = NodeToMove.CallSiteLoc; 71 ContextTrieNode &OldParentContext = *NodeToMove.getParentContext(); 72 AllChildContext[Hash] = NodeToMove; 73 ContextTrieNode &NewNode = AllChildContext[Hash]; 74 NewNode.CallSiteLoc = CallSite; 75 76 // Walk through nodes in the moved the subtree, and update 77 // FunctionSamples' context as for the context promotion. 78 // We also need to set new parant link for all children. 79 std::queue<ContextTrieNode *> NodeToUpdate; 80 NewNode.setParentContext(this); 81 NodeToUpdate.push(&NewNode); 82 83 while (!NodeToUpdate.empty()) { 84 ContextTrieNode *Node = NodeToUpdate.front(); 85 NodeToUpdate.pop(); 86 FunctionSamples *FSamples = Node->getFunctionSamples(); 87 88 if (FSamples) { 89 FSamples->getContext().promoteOnPath(ContextStrToRemove); 90 FSamples->getContext().setState(SyntheticContext); 91 LLVM_DEBUG(dbgs() << " Context promoted to: " << FSamples->getContext() 92 << "\n"); 93 } 94 95 for (auto &It : Node->getAllChildContext()) { 96 ContextTrieNode *ChildNode = &It.second; 97 ChildNode->setParentContext(Node); 98 NodeToUpdate.push(ChildNode); 99 } 100 } 101 102 // Original context no longer needed, destroy if requested. 103 if (DeleteNode) 104 OldParentContext.removeChildContext(OldCallSite, NewNode.getFuncName()); 105 106 return NewNode; 107 } 108 109 void ContextTrieNode::removeChildContext(const LineLocation &CallSite, 110 StringRef CalleeName) { 111 uint32_t Hash = nodeHash(CalleeName, CallSite); 112 // Note this essentially calls dtor and destroys that child context 113 AllChildContext.erase(Hash); 114 } 115 116 std::map<uint32_t, ContextTrieNode> &ContextTrieNode::getAllChildContext() { 117 return AllChildContext; 118 } 119 120 StringRef ContextTrieNode::getFuncName() const { return FuncName; } 121 122 FunctionSamples *ContextTrieNode::getFunctionSamples() const { 123 return FuncSamples; 124 } 125 126 void ContextTrieNode::setFunctionSamples(FunctionSamples *FSamples) { 127 FuncSamples = FSamples; 128 } 129 130 uint32_t ContextTrieNode::getFunctionSize() const { return FuncSize; } 131 132 void ContextTrieNode::setFunctionSize(uint32_t FSize) { FuncSize = FSize; } 133 134 LineLocation ContextTrieNode::getCallSiteLoc() const { return CallSiteLoc; } 135 136 ContextTrieNode *ContextTrieNode::getParentContext() const { 137 return ParentContext; 138 } 139 140 void ContextTrieNode::setParentContext(ContextTrieNode *Parent) { 141 ParentContext = Parent; 142 } 143 144 void ContextTrieNode::dumpNode() { 145 dbgs() << "Node: " << FuncName << "\n" 146 << " Callsite: " << CallSiteLoc << "\n" 147 << " Size: " << FuncSize << "\n" 148 << " Children:\n"; 149 150 for (auto &It : AllChildContext) { 151 dbgs() << " Node: " << It.second.getFuncName() << "\n"; 152 } 153 } 154 155 void ContextTrieNode::dumpTree() { 156 dbgs() << "Context Profile Tree:\n"; 157 std::queue<ContextTrieNode *> NodeQueue; 158 NodeQueue.push(this); 159 160 while (!NodeQueue.empty()) { 161 ContextTrieNode *Node = NodeQueue.front(); 162 NodeQueue.pop(); 163 Node->dumpNode(); 164 165 for (auto &It : Node->getAllChildContext()) { 166 ContextTrieNode *ChildNode = &It.second; 167 NodeQueue.push(ChildNode); 168 } 169 } 170 } 171 172 uint32_t ContextTrieNode::nodeHash(StringRef ChildName, 173 const LineLocation &Callsite) { 174 // We still use child's name for child hash, this is 175 // because for children of root node, we don't have 176 // different line/discriminator, and we'll rely on name 177 // to differentiate children. 178 uint32_t NameHash = std::hash<std::string>{}(ChildName.str()); 179 uint32_t LocId = (Callsite.LineOffset << 16) | Callsite.Discriminator; 180 return NameHash + (LocId << 5) + LocId; 181 } 182 183 ContextTrieNode *ContextTrieNode::getOrCreateChildContext( 184 const LineLocation &CallSite, StringRef CalleeName, bool AllowCreate) { 185 uint32_t Hash = nodeHash(CalleeName, CallSite); 186 auto It = AllChildContext.find(Hash); 187 if (It != AllChildContext.end()) { 188 assert(It->second.getFuncName() == CalleeName && 189 "Hash collision for child context node"); 190 return &It->second; 191 } 192 193 if (!AllowCreate) 194 return nullptr; 195 196 AllChildContext[Hash] = 197 ContextTrieNode(this, CalleeName, nullptr, 0, CallSite); 198 return &AllChildContext[Hash]; 199 } 200 201 // Profiler tracker than manages profiles and its associated context 202 SampleContextTracker::SampleContextTracker( 203 StringMap<FunctionSamples> &Profiles) { 204 for (auto &FuncSample : Profiles) { 205 FunctionSamples *FSamples = &FuncSample.second; 206 SampleContext Context(FuncSample.first(), RawContext); 207 LLVM_DEBUG(dbgs() << "Tracking Context for function: " << Context << "\n"); 208 if (!Context.isBaseContext()) 209 FuncToCtxtProfiles[Context.getNameWithoutContext()].push_back(FSamples); 210 ContextTrieNode *NewNode = getOrCreateContextPath(Context, true); 211 assert(!NewNode->getFunctionSamples() && 212 "New node can't have sample profile"); 213 NewNode->setFunctionSamples(FSamples); 214 } 215 } 216 217 FunctionSamples * 218 SampleContextTracker::getCalleeContextSamplesFor(const CallBase &Inst, 219 StringRef CalleeName) { 220 LLVM_DEBUG(dbgs() << "Getting callee context for instr: " << Inst << "\n"); 221 DILocation *DIL = Inst.getDebugLoc(); 222 if (!DIL) 223 return nullptr; 224 225 CalleeName = FunctionSamples::getCanonicalFnName(CalleeName); 226 227 // For indirect call, CalleeName will be empty, in which case the context 228 // profile for callee with largest total samples will be returned. 229 ContextTrieNode *CalleeContext = getCalleeContextFor(DIL, CalleeName); 230 if (CalleeContext) { 231 FunctionSamples *FSamples = CalleeContext->getFunctionSamples(); 232 LLVM_DEBUG(if (FSamples) { 233 dbgs() << " Callee context found: " << FSamples->getContext() << "\n"; 234 }); 235 return FSamples; 236 } 237 238 return nullptr; 239 } 240 241 std::vector<const FunctionSamples *> 242 SampleContextTracker::getIndirectCalleeContextSamplesFor( 243 const DILocation *DIL) { 244 std::vector<const FunctionSamples *> R; 245 if (!DIL) 246 return R; 247 248 ContextTrieNode *CallerNode = getContextFor(DIL); 249 LineLocation CallSite = FunctionSamples::getCallSiteIdentifier(DIL); 250 for (auto &It : CallerNode->getAllChildContext()) { 251 ContextTrieNode &ChildNode = It.second; 252 if (ChildNode.getCallSiteLoc() != CallSite) 253 continue; 254 if (FunctionSamples *CalleeSamples = ChildNode.getFunctionSamples()) 255 R.push_back(CalleeSamples); 256 } 257 258 return R; 259 } 260 261 FunctionSamples * 262 SampleContextTracker::getContextSamplesFor(const DILocation *DIL) { 263 assert(DIL && "Expect non-null location"); 264 265 ContextTrieNode *ContextNode = getContextFor(DIL); 266 if (!ContextNode) 267 return nullptr; 268 269 // We may have inlined callees during pre-LTO compilation, in which case 270 // we need to rely on the inline stack from !dbg to mark context profile 271 // as inlined, instead of `MarkContextSamplesInlined` during inlining. 272 // Sample profile loader walks through all instructions to get profile, 273 // which calls this function. So once that is done, all previously inlined 274 // context profile should be marked properly. 275 FunctionSamples *Samples = ContextNode->getFunctionSamples(); 276 if (Samples && ContextNode->getParentContext() != &RootContext) 277 Samples->getContext().setState(InlinedContext); 278 279 return Samples; 280 } 281 282 FunctionSamples * 283 SampleContextTracker::getContextSamplesFor(const SampleContext &Context) { 284 ContextTrieNode *Node = getContextFor(Context); 285 if (!Node) 286 return nullptr; 287 288 return Node->getFunctionSamples(); 289 } 290 291 SampleContextTracker::ContextSamplesTy & 292 SampleContextTracker::getAllContextSamplesFor(const Function &Func) { 293 StringRef CanonName = FunctionSamples::getCanonicalFnName(Func); 294 return FuncToCtxtProfiles[CanonName]; 295 } 296 297 SampleContextTracker::ContextSamplesTy & 298 SampleContextTracker::getAllContextSamplesFor(StringRef Name) { 299 return FuncToCtxtProfiles[Name]; 300 } 301 302 FunctionSamples *SampleContextTracker::getBaseSamplesFor(const Function &Func, 303 bool MergeContext) { 304 StringRef CanonName = FunctionSamples::getCanonicalFnName(Func); 305 return getBaseSamplesFor(CanonName, MergeContext); 306 } 307 308 FunctionSamples *SampleContextTracker::getBaseSamplesFor(StringRef Name, 309 bool MergeContext) { 310 LLVM_DEBUG(dbgs() << "Getting base profile for function: " << Name << "\n"); 311 // Base profile is top-level node (child of root node), so try to retrieve 312 // existing top-level node for given function first. If it exists, it could be 313 // that we've merged base profile before, or there's actually context-less 314 // profile from the input (e.g. due to unreliable stack walking). 315 ContextTrieNode *Node = getTopLevelContextNode(Name); 316 if (MergeContext) { 317 LLVM_DEBUG(dbgs() << " Merging context profile into base profile: " << Name 318 << "\n"); 319 320 // We have profile for function under different contexts, 321 // create synthetic base profile and merge context profiles 322 // into base profile. 323 for (auto *CSamples : FuncToCtxtProfiles[Name]) { 324 SampleContext &Context = CSamples->getContext(); 325 ContextTrieNode *FromNode = getContextFor(Context); 326 if (FromNode == Node) 327 continue; 328 329 // Skip inlined context profile and also don't re-merge any context 330 if (Context.hasState(InlinedContext) || Context.hasState(MergedContext)) 331 continue; 332 333 ContextTrieNode &ToNode = promoteMergeContextSamplesTree(*FromNode); 334 assert((!Node || Node == &ToNode) && "Expect only one base profile"); 335 Node = &ToNode; 336 } 337 } 338 339 // Still no profile even after merge/promotion (if allowed) 340 if (!Node) 341 return nullptr; 342 343 return Node->getFunctionSamples(); 344 } 345 346 void SampleContextTracker::markContextSamplesInlined( 347 const FunctionSamples *InlinedSamples) { 348 assert(InlinedSamples && "Expect non-null inlined samples"); 349 LLVM_DEBUG(dbgs() << "Marking context profile as inlined: " 350 << InlinedSamples->getContext() << "\n"); 351 InlinedSamples->getContext().setState(InlinedContext); 352 } 353 354 ContextTrieNode &SampleContextTracker::getRootContext() { return RootContext; } 355 356 void SampleContextTracker::promoteMergeContextSamplesTree( 357 const Instruction &Inst, StringRef CalleeName) { 358 LLVM_DEBUG(dbgs() << "Promoting and merging context tree for instr: \n" 359 << Inst << "\n"); 360 // Get the caller context for the call instruction, we don't use callee 361 // name from call because there can be context from indirect calls too. 362 DILocation *DIL = Inst.getDebugLoc(); 363 ContextTrieNode *CallerNode = getContextFor(DIL); 364 if (!CallerNode) 365 return; 366 367 // Get the context that needs to be promoted 368 LineLocation CallSite = FunctionSamples::getCallSiteIdentifier(DIL); 369 // For indirect call, CalleeName will be empty, in which case we need to 370 // promote all non-inlined child context profiles. 371 if (CalleeName.empty()) { 372 for (auto &It : CallerNode->getAllChildContext()) { 373 ContextTrieNode *NodeToPromo = &It.second; 374 if (CallSite != NodeToPromo->getCallSiteLoc()) 375 continue; 376 FunctionSamples *FromSamples = NodeToPromo->getFunctionSamples(); 377 if (FromSamples && FromSamples->getContext().hasState(InlinedContext)) 378 continue; 379 promoteMergeContextSamplesTree(*NodeToPromo); 380 } 381 return; 382 } 383 384 // Get the context for the given callee that needs to be promoted 385 ContextTrieNode *NodeToPromo = 386 CallerNode->getChildContext(CallSite, CalleeName); 387 if (!NodeToPromo) 388 return; 389 390 promoteMergeContextSamplesTree(*NodeToPromo); 391 } 392 393 ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree( 394 ContextTrieNode &NodeToPromo) { 395 // Promote the input node to be directly under root. This can happen 396 // when we decided to not inline a function under context represented 397 // by the input node. The promote and merge is then needed to reflect 398 // the context profile in the base (context-less) profile. 399 FunctionSamples *FromSamples = NodeToPromo.getFunctionSamples(); 400 assert(FromSamples && "Shouldn't promote a context without profile"); 401 LLVM_DEBUG(dbgs() << " Found context tree root to promote: " 402 << FromSamples->getContext() << "\n"); 403 404 assert(!FromSamples->getContext().hasState(InlinedContext) && 405 "Shouldn't promote inlined context profile"); 406 StringRef ContextStrToRemove = FromSamples->getContext().getCallingContext(); 407 return promoteMergeContextSamplesTree(NodeToPromo, RootContext, 408 ContextStrToRemove); 409 } 410 411 void SampleContextTracker::dump() { RootContext.dumpTree(); } 412 413 ContextTrieNode * 414 SampleContextTracker::getContextFor(const SampleContext &Context) { 415 return getOrCreateContextPath(Context, false); 416 } 417 418 ContextTrieNode * 419 SampleContextTracker::getCalleeContextFor(const DILocation *DIL, 420 StringRef CalleeName) { 421 assert(DIL && "Expect non-null location"); 422 423 ContextTrieNode *CallContext = getContextFor(DIL); 424 if (!CallContext) 425 return nullptr; 426 427 // When CalleeName is empty, the child context profile with max 428 // total samples will be returned. 429 return CallContext->getChildContext( 430 FunctionSamples::getCallSiteIdentifier(DIL), CalleeName); 431 } 432 433 ContextTrieNode *SampleContextTracker::getContextFor(const DILocation *DIL) { 434 assert(DIL && "Expect non-null location"); 435 SmallVector<std::pair<LineLocation, StringRef>, 10> S; 436 437 // Use C++ linkage name if possible. 438 const DILocation *PrevDIL = DIL; 439 for (DIL = DIL->getInlinedAt(); DIL; DIL = DIL->getInlinedAt()) { 440 StringRef Name = PrevDIL->getScope()->getSubprogram()->getLinkageName(); 441 if (Name.empty()) 442 Name = PrevDIL->getScope()->getSubprogram()->getName(); 443 S.push_back( 444 std::make_pair(FunctionSamples::getCallSiteIdentifier(DIL), Name)); 445 PrevDIL = DIL; 446 } 447 448 // Push root node, note that root node like main may only 449 // a name, but not linkage name. 450 StringRef RootName = PrevDIL->getScope()->getSubprogram()->getLinkageName(); 451 if (RootName.empty()) 452 RootName = PrevDIL->getScope()->getSubprogram()->getName(); 453 S.push_back(std::make_pair(LineLocation(0, 0), RootName)); 454 455 ContextTrieNode *ContextNode = &RootContext; 456 int I = S.size(); 457 while (--I >= 0 && ContextNode) { 458 LineLocation &CallSite = S[I].first; 459 StringRef &CalleeName = S[I].second; 460 ContextNode = ContextNode->getChildContext(CallSite, CalleeName); 461 } 462 463 if (I < 0) 464 return ContextNode; 465 466 return nullptr; 467 } 468 469 ContextTrieNode * 470 SampleContextTracker::getOrCreateContextPath(const SampleContext &Context, 471 bool AllowCreate) { 472 ContextTrieNode *ContextNode = &RootContext; 473 StringRef ContextRemain = Context; 474 StringRef ChildContext; 475 StringRef CalleeName; 476 LineLocation CallSiteLoc(0, 0); 477 478 while (ContextNode && !ContextRemain.empty()) { 479 auto ContextSplit = SampleContext::splitContextString(ContextRemain); 480 ChildContext = ContextSplit.first; 481 ContextRemain = ContextSplit.second; 482 LineLocation NextCallSiteLoc(0, 0); 483 SampleContext::decodeContextString(ChildContext, CalleeName, 484 NextCallSiteLoc); 485 486 // Create child node at parent line/disc location 487 if (AllowCreate) { 488 ContextNode = 489 ContextNode->getOrCreateChildContext(CallSiteLoc, CalleeName); 490 } else { 491 ContextNode = ContextNode->getChildContext(CallSiteLoc, CalleeName); 492 } 493 CallSiteLoc = NextCallSiteLoc; 494 } 495 496 assert((!AllowCreate || ContextNode) && 497 "Node must exist if creation is allowed"); 498 return ContextNode; 499 } 500 501 ContextTrieNode *SampleContextTracker::getTopLevelContextNode(StringRef FName) { 502 assert(!FName.empty() && "Top level node query must provide valid name"); 503 return RootContext.getChildContext(LineLocation(0, 0), FName); 504 } 505 506 ContextTrieNode &SampleContextTracker::addTopLevelContextNode(StringRef FName) { 507 assert(!getTopLevelContextNode(FName) && "Node to add must not exist"); 508 return *RootContext.getOrCreateChildContext(LineLocation(0, 0), FName); 509 } 510 511 void SampleContextTracker::mergeContextNode(ContextTrieNode &FromNode, 512 ContextTrieNode &ToNode, 513 StringRef ContextStrToRemove) { 514 FunctionSamples *FromSamples = FromNode.getFunctionSamples(); 515 FunctionSamples *ToSamples = ToNode.getFunctionSamples(); 516 if (FromSamples && ToSamples) { 517 // Merge/duplicate FromSamples into ToSamples 518 ToSamples->merge(*FromSamples); 519 ToSamples->getContext().setState(SyntheticContext); 520 FromSamples->getContext().setState(MergedContext); 521 } else if (FromSamples) { 522 // Transfer FromSamples from FromNode to ToNode 523 ToNode.setFunctionSamples(FromSamples); 524 FromSamples->getContext().setState(SyntheticContext); 525 FromSamples->getContext().promoteOnPath(ContextStrToRemove); 526 FromNode.setFunctionSamples(nullptr); 527 } 528 } 529 530 ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree( 531 ContextTrieNode &FromNode, ContextTrieNode &ToNodeParent, 532 StringRef ContextStrToRemove) { 533 assert(!ContextStrToRemove.empty() && "Context to remove can't be empty"); 534 535 // Ignore call site location if destination is top level under root 536 LineLocation NewCallSiteLoc = LineLocation(0, 0); 537 LineLocation OldCallSiteLoc = FromNode.getCallSiteLoc(); 538 ContextTrieNode &FromNodeParent = *FromNode.getParentContext(); 539 ContextTrieNode *ToNode = nullptr; 540 bool MoveToRoot = (&ToNodeParent == &RootContext); 541 if (!MoveToRoot) { 542 NewCallSiteLoc = OldCallSiteLoc; 543 } 544 545 // Locate destination node, create/move if not existing 546 ToNode = ToNodeParent.getChildContext(NewCallSiteLoc, FromNode.getFuncName()); 547 if (!ToNode) { 548 // Do not delete node to move from its parent here because 549 // caller is iterating over children of that parent node. 550 ToNode = &ToNodeParent.moveToChildContext( 551 NewCallSiteLoc, std::move(FromNode), ContextStrToRemove, false); 552 } else { 553 // Destination node exists, merge samples for the context tree 554 mergeContextNode(FromNode, *ToNode, ContextStrToRemove); 555 LLVM_DEBUG({ 556 if (ToNode->getFunctionSamples()) 557 dbgs() << " Context promoted and merged to: " 558 << ToNode->getFunctionSamples()->getContext() << "\n"; 559 }); 560 561 // Recursively promote and merge children 562 for (auto &It : FromNode.getAllChildContext()) { 563 ContextTrieNode &FromChildNode = It.second; 564 promoteMergeContextSamplesTree(FromChildNode, *ToNode, 565 ContextStrToRemove); 566 } 567 568 // Remove children once they're all merged 569 FromNode.getAllChildContext().clear(); 570 } 571 572 // For root of subtree, remove itself from old parent too 573 if (MoveToRoot) 574 FromNodeParent.removeChildContext(OldCallSiteLoc, ToNode->getFuncName()); 575 576 return *ToNode; 577 } 578 } // namespace llvm 579