//===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Instrumentation-based profile-guided optimization
//
//===----------------------------------------------------------------------===//

#include "CodeGenPGO.h"
#include "CodeGenFunction.h"
#include "CoverageMappingGen.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/StmtVisitor.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/MD5.h"

static llvm::cl::opt<bool>
    EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
                         llvm::cl::desc("Enable value profiling"),
                         llvm::cl::Hidden, llvm::cl::init(false));

using namespace clang;
using namespace CodeGen;

/// Compute and cache the PGO name for \p Name. The name encoding depends on
/// the linkage (local-linkage names get a file-name prefix) and on the
/// profile format version in use: when reading an indexed profile, the
/// reader's version is used so lookups match; otherwise the current
/// (write-time) version is used.
void CodeGenPGO::setFuncName(StringRef Name,
                             llvm::GlobalValue::LinkageTypes Linkage) {
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  FuncName = llvm::getPGOFuncName(
      Name, Linkage, CGM.getCodeGenOpts().MainFileName,
      PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);

  // If we're generating a profile, create a variable for the name.
  if (CGM.getCodeGenOpts().hasProfileClangInstr())
    FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
}

/// Convenience overload: derive the PGO name from the IR function itself and
/// attach it to the function as metadata so later IR passes can recover it.
void CodeGenPGO::setFuncName(llvm::Function *Fn) {
  setFuncName(Fn->getName(), Fn->getLinkage());
  // Create PGOFuncName meta data.
  llvm::createPGOFuncNameMetadata(*Fn, FuncName);
}

/// The version of the PGO hash algorithm.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1,
  PGO_HASH_V2,
  PGO_HASH_V3,

  // Keep this set to the latest hash version.
  PGO_HASH_LATEST = PGO_HASH_V3
};

namespace {
/// Stable hasher for PGO region counters.
///
/// PGOHash produces a stable hash of a given function's control flow.
///
/// Changing the output of this hash will invalidate all previously generated
/// profiles -- i.e., don't do it.
///
/// \note When this hash does eventually change (years?), we still need to
/// support old hashes.  We'll need to pull in the version number from the
/// profile data format and use the matching hash function.
class PGOHash {
  // Pending hash state: up to NumTypesPerWord 6-bit type codes are packed
  // into Working before being flushed into the MD5 stream.
  uint64_t Working;
  unsigned Count;
  PGOHashVersion HashVersion;
  llvm::MD5 MD5;

  static const int NumBitsPerType = 6;
  static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
  static const unsigned TooBig = 1u << NumBitsPerType;

public:
  /// Hash values for AST nodes.
  ///
  /// Distinct values for AST nodes that have region counters attached.
  ///
  /// These values must be stable.  All new members must be added at the end,
  /// and no members should be removed.  Changing the enumeration value for an
  /// AST node will affect the hash of every function that contains that node.
  enum HashType : unsigned char {
    None = 0,
    LabelStmt = 1,
    WhileStmt,
    DoStmt,
    ForStmt,
    CXXForRangeStmt,
    ObjCForCollectionStmt,
    SwitchStmt,
    CaseStmt,
    DefaultStmt,
    IfStmt,
    CXXTryStmt,
    CXXCatchStmt,
    ConditionalOperator,
    BinaryOperatorLAnd,
    BinaryOperatorLOr,
    BinaryConditionalOperator,
    // The preceding values are available with PGO_HASH_V1.

    EndOfScope,
    IfThenBranch,
    IfElseBranch,
    GotoStmt,
    IndirectGotoStmt,
    BreakStmt,
    ContinueStmt,
    ReturnStmt,
    ThrowExpr,
    UnaryOperatorLNot,
    BinaryOperatorLT,
    BinaryOperatorGT,
    BinaryOperatorLE,
    BinaryOperatorGE,
    BinaryOperatorEQ,
    BinaryOperatorNE,
    // The preceding values are available since PGO_HASH_V2.

    // Keep this last.  It's for the static assert that follows.
    LastHashType
  };
  static_assert(LastHashType <= TooBig, "Too many types in HashType");

  PGOHash(PGOHashVersion HashVersion)
      : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
  void combine(HashType Type);
  uint64_t finalize();
  PGOHashVersion getHashVersion() const { return HashVersion; }
};
// Out-of-line definitions for odr-used static class constants (pre-C++17
// style; required if their address is ever taken).
const int PGOHash::NumBitsPerType;
const unsigned PGOHash::NumTypesPerWord;
const unsigned PGOHash::TooBig;

/// Get the PGO hash version used in the given indexed profile.
/// Profile format versions <= 4 used hash V1, version 5 used V2, and
/// everything newer uses V3.
static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
                                        CodeGenModule &CGM) {
  if (PGOReader->getVersion() <= 4)
    return PGO_HASH_V1;
  if (PGOReader->getVersion() <= 5)
    return PGO_HASH_V2;
  return PGO_HASH_V3;
}

/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
  using Base = RecursiveASTVisitor<MapRegionCounters>;

  /// The next counter value to assign.
  unsigned NextCounter;
  /// The function hash.
  PGOHash Hash;
  /// The map of statements to counters.
  llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
  /// The profile version.
  uint64_t ProfileVersion;

  MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
                    llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
      : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
        ProfileVersion(ProfileVersion) {}

  // Blocks and lambdas are handled as separate functions, so we need not
  // traverse them in the parent context.
  bool TraverseBlockExpr(BlockExpr *BE) { return true; }
  bool TraverseLambdaExpr(LambdaExpr *LE) {
    // Traverse the captures, but not the body.
    for (auto C : zip(LE->captures(), LE->capture_inits()))
      TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
    return true;
  }
  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }

  bool VisitDecl(const Decl *D) {
    // Each function-like declaration gets an entry counter on its body.
    switch (D->getKind()) {
    default:
      break;
    case Decl::Function:
    case Decl::CXXMethod:
    case Decl::CXXConstructor:
    case Decl::CXXDestructor:
    case Decl::CXXConversion:
    case Decl::ObjCMethod:
    case Decl::Block:
    case Decl::Captured:
      CounterMap[D->getBody()] = NextCounter++;
      break;
    }
    return true;
  }

  /// If \p S gets a fresh counter, update the counter mappings. Return the
  /// V1 hash of \p S.
  PGOHash::HashType updateCounterMappings(Stmt *S) {
    auto Type = getHashType(PGO_HASH_V1, S);
    if (Type != PGOHash::None)
      CounterMap[S] = NextCounter++;
    return Type;
  }

  /// The RHS of all logical operators gets a fresh counter in order to count
  /// how many times the RHS evaluates to true or false, depending on the
  /// semantics of the operator. This is only valid for ">= v7" of the profile
  /// version so that we facilitate backward compatibility.
  bool VisitBinaryOperator(BinaryOperator *S) {
    if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
      if (S->isLogicalOp() &&
          CodeGenFunction::isInstrumentedCondition(S->getRHS()))
        CounterMap[S->getRHS()] = NextCounter++;
    return Base::VisitBinaryOperator(S);
  }

  /// Include \p S in the function hash.
  bool VisitStmt(Stmt *S) {
    // Counter assignment always uses the V1 classification, even when a newer
    // hash version widens the set of statements folded into the hash.
    auto Type = updateCounterMappings(S);
    if (Hash.getHashVersion() != PGO_HASH_V1)
      Type = getHashType(Hash.getHashVersion(), S);
    if (Type != PGOHash::None)
      Hash.combine(Type);
    return true;
  }

  bool TraverseIfStmt(IfStmt *If) {
    // If we used the V1 hash, use the default traversal.
    if (Hash.getHashVersion() == PGO_HASH_V1)
      return Base::TraverseIfStmt(If);

    // Otherwise, keep track of which branch we're in while traversing.
    VisitStmt(If);
    for (Stmt *CS : If->children()) {
      if (!CS)
        continue;
      if (CS == If->getThen())
        Hash.combine(PGOHash::IfThenBranch);
      else if (CS == If->getElse())
        Hash.combine(PGOHash::IfElseBranch);
      TraverseStmt(CS);
    }
    Hash.combine(PGOHash::EndOfScope);
    return true;
  }

// If the statement type \p N is nestable, and its nesting impacts profile
// stability, define a custom traversal which tracks the end of the statement
// in the hash (provided we're not using the V1 hash).
#define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
  bool Traverse##N(N *S) {                                                     \
    Base::Traverse##N(S);                                                      \
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
      Hash.combine(PGOHash::EndOfScope);                                       \
    return true;                                                               \
  }

  DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
  DEFINE_NESTABLE_TRAVERSAL(DoStmt)
  DEFINE_NESTABLE_TRAVERSAL(ForStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
  DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)

  /// Get version \p HashVersion of the PGO hash for \p S.
  PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
    switch (S->getStmtClass()) {
    default:
      break;
    case Stmt::LabelStmtClass:
      return PGOHash::LabelStmt;
    case Stmt::WhileStmtClass:
      return PGOHash::WhileStmt;
    case Stmt::DoStmtClass:
      return PGOHash::DoStmt;
    case Stmt::ForStmtClass:
      return PGOHash::ForStmt;
    case Stmt::CXXForRangeStmtClass:
      return PGOHash::CXXForRangeStmt;
    case Stmt::ObjCForCollectionStmtClass:
      return PGOHash::ObjCForCollectionStmt;
    case Stmt::SwitchStmtClass:
      return PGOHash::SwitchStmt;
    case Stmt::CaseStmtClass:
      return PGOHash::CaseStmt;
    case Stmt::DefaultStmtClass:
      return PGOHash::DefaultStmt;
    case Stmt::IfStmtClass:
      return PGOHash::IfStmt;
    case Stmt::CXXTryStmtClass:
      return PGOHash::CXXTryStmt;
    case Stmt::CXXCatchStmtClass:
      return PGOHash::CXXCatchStmt;
    case Stmt::ConditionalOperatorClass:
      return PGOHash::ConditionalOperator;
    case Stmt::BinaryConditionalOperatorClass:
      return PGOHash::BinaryConditionalOperator;
    case Stmt::BinaryOperatorClass: {
      const BinaryOperator *BO = cast<BinaryOperator>(S);
      if (BO->getOpcode() == BO_LAnd)
        return PGOHash::BinaryOperatorLAnd;
      if (BO->getOpcode() == BO_LOr)
        return PGOHash::BinaryOperatorLOr;
      if (HashVersion >= PGO_HASH_V2) {
        // V2+ also folds comparison operators into the hash.
        switch (BO->getOpcode()) {
        default:
          break;
        case BO_LT:
          return PGOHash::BinaryOperatorLT;
        case BO_GT:
          return PGOHash::BinaryOperatorGT;
        case BO_LE:
          return PGOHash::BinaryOperatorLE;
        case BO_GE:
          return PGOHash::BinaryOperatorGE;
        case BO_EQ:
          return PGOHash::BinaryOperatorEQ;
        case BO_NE:
          return PGOHash::BinaryOperatorNE;
        }
      }
      break;
    }
    }

    if (HashVersion >= PGO_HASH_V2) {
      // V2+ additionally hashes control-flow transfer statements.
      switch (S->getStmtClass()) {
      default:
        break;
      case Stmt::GotoStmtClass:
        return PGOHash::GotoStmt;
      case Stmt::IndirectGotoStmtClass:
        return PGOHash::IndirectGotoStmt;
      case Stmt::BreakStmtClass:
        return PGOHash::BreakStmt;
      case Stmt::ContinueStmtClass:
        return PGOHash::ContinueStmt;
      case Stmt::ReturnStmtClass:
        return PGOHash::ReturnStmt;
      case Stmt::CXXThrowExprClass:
        return PGOHash::ThrowExpr;
      case Stmt::UnaryOperatorClass: {
        const UnaryOperator *UO = cast<UnaryOperator>(S);
        if (UO->getOpcode() == UO_LNot)
          return PGOHash::UnaryOperatorLNot;
        break;
      }
      }
    }

    return PGOHash::None;
  }
};

/// A StmtVisitor that propagates the raw counts through the AST and
/// records the count at statements where the value may change.
struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
  /// PGO state.
  CodeGenPGO &PGO;

  /// A flag that is set when the current count should be recorded on the
  /// next statement, such as at the exit of a loop.
  bool RecordNextStmtCount;

  /// The count at the current location in the traversal.
  uint64_t CurrentCount;

  /// The map of statements to count values.
  llvm::DenseMap<const Stmt *, uint64_t> &CountMap;

  /// BreakContinueStack - Keep counts of breaks and continues inside loops.
  struct BreakContinue {
    uint64_t BreakCount;
    uint64_t ContinueCount;
    BreakContinue() : BreakCount(0), ContinueCount(0) {}
  };
  SmallVector<BreakContinue, 8> BreakContinueStack;

  ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
                      CodeGenPGO &PGO)
      : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}

  /// Record CurrentCount on \p S iff a preceding statement (return, break,
  /// loop exit, ...) requested that the next statement capture the count.
  void RecordStmtCount(const Stmt *S) {
    if (RecordNextStmtCount) {
      CountMap[S] = CurrentCount;
      RecordNextStmtCount = false;
    }
  }

  /// Set and return the current count.
  uint64_t setCount(uint64_t Count) {
    CurrentCount = Count;
    return Count;
  }

  void VisitStmt(const Stmt *S) {
    RecordStmtCount(S);
    for (const Stmt *Child : S->children())
      if (Child)
        this->Visit(Child);
  }

  void VisitFunctionDecl(const FunctionDecl *D) {
    // Counter tracks entry to the function body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  // Skip lambda expressions. We visit these as FunctionDecls when we're
  // generating them and aren't interested in the body when generating a
  // parent context.
  void VisitLambdaExpr(const LambdaExpr *LE) {}

  void VisitCapturedDecl(const CapturedDecl *D) {
    // Counter tracks entry to the capture body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
    // Counter tracks entry to the method body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitBlockDecl(const BlockDecl *D) {
    // Counter tracks entry to the block body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitReturnStmt(const ReturnStmt *S) {
    RecordStmtCount(S);
    if (S->getRetValue())
      Visit(S->getRetValue());
    // Control does not fall through a return; the next statement reached is
    // only via some other edge, so reset the count and ask it to be recorded.
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitCXXThrowExpr(const CXXThrowExpr *E) {
    RecordStmtCount(E);
    if (E->getSubExpr())
      Visit(E->getSubExpr());
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitGotoStmt(const GotoStmt *S) {
    RecordStmtCount(S);
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitLabelStmt(const LabelStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the block following the label.
    uint64_t BlockCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = BlockCount;
    Visit(S->getSubStmt());
  }

  void VisitBreakStmt(const BreakStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
    BreakContinueStack.back().BreakCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitContinueStmt(const ContinueStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
    BreakContinueStack.back().ContinueCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitWhileStmt(const WhileStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first so the break/continue adjustments can be
    // included when visiting the condition.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = CurrentCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    // ...then go back and propagate counts through the condition. The count
    // at the start of the condition is the sum of the incoming edges,
    // the backedge from the end of the loop body, and the edges from
    // continue statements.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    // Exit count: condition-false edges plus any breaks out of the loop.
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitDoStmt(const DoStmt *S) {
    RecordStmtCount(S);
    uint64_t LoopCount = PGO.getRegionCount(S);

    BreakContinueStack.push_back(BreakContinue());
    // The count doesn't include the fallthrough from the parent scope. Add it.
    uint64_t BodyCount = setCount(LoopCount + CurrentCount);
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    BreakContinue BC = BreakContinueStack.pop_back_val();
    // The count at the start of the condition is equal to the count at the
    // end of the body, plus any continues.
    uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - LoopCount);
    RecordNextStmtCount = true;
  }

  void VisitForStmt(const ForStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());

    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    if (S->getInc()) {
      uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
      CountMap[S->getInc()] = IncCount;
      Visit(S->getInc());
    }

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    if (S->getCond()) {
      CountMap[S->getCond()] = CondCount;
      Visit(S->getCond());
    }
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getLoopVarStmt());
    Visit(S->getRangeStmt());
    Visit(S->getBeginStmt());
    Visit(S->getEndStmt());

    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getInc()] = IncCount;
    Visit(S->getInc());

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
    RecordStmtCount(S);
    Visit(S->getElement());
    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Counter tracks the body of the loop.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
             BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitSwitchStmt(const SwitchStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());
    // No direct fallthrough into the body: each case starts from its own
    // counter, so zero the count before visiting.
    CurrentCount = 0;
    BreakContinueStack.push_back(BreakContinue());
    Visit(S->getBody());
    // If the switch is inside a loop, add the continue counts.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    if (!BreakContinueStack.empty())
      BreakContinueStack.back().ContinueCount += BC.ContinueCount;
    // Counter tracks the exit block of the switch.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitSwitchCase(const SwitchCase *S) {
    RecordNextStmtCount = false;
    // Counter for this particular case. This counts only jumps from the
    // switch header and does not include fallthrough from the case before
    // this one.
    uint64_t CaseCount = PGO.getRegionCount(S);
    setCount(CurrentCount + CaseCount);
    // We need the count without fallthrough in the mapping, so it's more useful
    // for branch probabilities.
    CountMap[S] = CaseCount;
    RecordNextStmtCount = true;
    Visit(S->getSubStmt());
  }

  void VisitIfStmt(const IfStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());

    // Counter tracks the "then" part of an if statement. The count for
    // the "else" part, if it exists, will be calculated from this counter.
    uint64_t ThenCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getThen()] = ThenCount;
    Visit(S->getThen());
    uint64_t OutCount = CurrentCount;

    uint64_t ElseCount = ParentCount - ThenCount;
    if (S->getElse()) {
      setCount(ElseCount);
      CountMap[S->getElse()] = ElseCount;
      Visit(S->getElse());
      OutCount += CurrentCount;
    } else
      OutCount += ElseCount;
    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXTryStmt(const CXXTryStmt *S) {
    RecordStmtCount(S);
    Visit(S->getTryBlock());
    for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
      Visit(S->getHandler(I));
    // Counter tracks the continuation block of the try statement.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitCXXCatchStmt(const CXXCatchStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the catch statement's handler block.
    uint64_t CatchCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = CatchCount;
    Visit(S->getHandlerBlock());
  }

  void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getCond());

    // Counter tracks the "true" part of a conditional operator. The
    // count in the "false" part will be calculated from this counter.
    uint64_t TrueCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getTrueExpr()] = TrueCount;
    Visit(E->getTrueExpr());
    uint64_t OutCount = CurrentCount;

    uint64_t FalseCount = setCount(ParentCount - TrueCount);
    CountMap[E->getFalseExpr()] = FalseCount;
    Visit(E->getFalseExpr());
    OutCount += CurrentCount;

    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLAnd(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical and operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLOr(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical or operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }
};
} // end anonymous namespace

void PGOHash::combine(HashType Type) {
  // Check that we never combine 0 and only have six bits.
  assert(Type && "Hash is invalid: unexpected type 0");
  assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");

  // Pass through MD5 if enough work has built up.
  if (Count && Count % NumTypesPerWord == 0) {
    using namespace llvm::support;
    uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
    MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    Working = 0;
  }

  // Accumulate the current type.
  ++Count;
  Working = Working << NumBitsPerType | Type;
}

uint64_t PGOHash::finalize() {
  // Use Working as the hash directly if we never used MD5.
  if (Count <= NumTypesPerWord)
    // No need to byte swap here, since none of the math was endian-dependent.
    // This number will be byte-swapped as required on endianness transitions,
    // so we will see the same value on the other side.
    return Working;

  // Check for remaining work in Working.
  if (Working) {
    // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
    // is buggy because it converts a uint64_t into an array of uint8_t.
    if (HashVersion < PGO_HASH_V3) {
      MD5.update({(uint8_t)Working});
    } else {
      using namespace llvm::support;
      uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
      MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    }
  }

  // Finalize the MD5 and return the hash.
  llvm::MD5::MD5Result Result;
  MD5.final(Result);
  return Result.low();
}

/// Decide whether \p GD / \p Fn should get region counters and, if so, build
/// the counter map, the function hash, the coverage mapping (when requested),
/// and — when reading a profile — load and propagate the recorded counts.
void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
  const Decl *D = GD.getDecl();
  if (!D->hasBody())
    return;

  // Skip CUDA/HIP kernel launch stub functions.
  if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
      D->hasAttr<CUDAGlobalAttr>())
    return;

  bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (!InstrumentRegions && !PGOReader)
    return;
  if (D->isImplicit())
    return;
  // Constructors and destructors may be represented by several functions in IR.
  // If so, instrument only base variant, others are implemented by delegation
  // to the base one, it would be counted twice otherwise.
  if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
    if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
      if (GD.getCtorType() != Ctor_Base &&
          CodeGenFunction::IsConstructorDelegationValid(CCD))
        return;
  }
  if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
    return;

  CGM.ClearUnusedCoverageMapping(D);
  setFuncName(Fn);

  mapRegionCounters(D);
  if (CGM.getCodeGenOpts().CoverageMapping)
    emitCounterRegionMapping(D);
  if (PGOReader) {
    SourceManager &SM = CGM.getContext().getSourceManager();
    loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
    computeRegionCounts(D);
    applyFunctionAttributes(PGOReader, Fn);
  }
}

/// Walk \p D's body, assigning a counter index to every countable statement
/// and computing the function's control-flow hash as a side effect.
void CodeGenPGO::mapRegionCounters(const Decl *D) {
  // Use the latest hash version when inserting instrumentation, but use the
  // version in the indexed profile if we're reading PGO data.
  PGOHashVersion HashVersion = PGO_HASH_LATEST;
  uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
  if (auto *PGOReader = CGM.getPGOReader()) {
    HashVersion = getPGOHashVersion(PGOReader, CGM);
    ProfileVersion = PGOReader->getVersion();
  }

  RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
  MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap);
  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
    Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
    Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
    Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
    Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
  assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
  NumRegionCounters = Walker.NextCounter;
  FunctionHash =
      Walker.Hash.finalize();
}

bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
  if (!D->getBody())
    return true;

  // Skip host-only functions in the CUDA device compilation and device-only
  // functions in the host compilation. Just roughly filter them out based on
  // the function attributes. If there are effectively host-only or device-only
  // ones, their coverage mapping may still be generated.
  if (CGM.getLangOpts().CUDA &&
      ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
        !D->hasAttr<CUDAGlobalAttr>()) ||
       (!CGM.getLangOpts().CUDAIsDevice &&
        (D->hasAttr<CUDAGlobalAttr>() ||
         (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
    return true;

  // Don't map the functions in system headers.
  const auto &SM = CGM.getContext().getSourceManager();
  auto Loc = D->getBody()->getBeginLoc();
  return SM.isInSystemHeader(Loc);
}

/// Emit the coverage mapping for \p D, built from the region counter map,
/// and register it with the module's coverage-mapping record.
void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
  if (skipRegionMappingForDecl(D))
    return;

  std::string CoverageMapping;
  llvm::raw_string_ostream OS(CoverageMapping);
  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
                                CGM.getContext().getSourceManager(),
                                CGM.getLangOpts(), RegionCounterMap.get());
  MappingGen.emitCounterMapping(D, OS);
  OS.flush();

  if (CoverageMapping.empty())
    return;

  CGM.getCoverageMapping()->addFunctionMappingRecord(
      FuncNameVar, FuncName, FunctionHash, CoverageMapping);
}

/// Emit a coverage mapping with no counters for \p D (used for unused
/// functions so they still appear in coverage reports).
void
CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
                                    llvm::GlobalValue::LinkageTypes Linkage) {
  if (skipRegionMappingForDecl(D))
    return;

  std::string CoverageMapping;
  llvm::raw_string_ostream OS(CoverageMapping);
  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
                                CGM.getContext().getSourceManager(),
                                CGM.getLangOpts());
  MappingGen.emitEmptyMapping(D, OS);
  OS.flush();

  if (CoverageMapping.empty())
    return;

  setFuncName(Name, Linkage);
  CGM.getCoverageMapping()->addFunctionMappingRecord(
      FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
}

/// Propagate the loaded region counts through \p D's body, filling
/// StmtCountMap with a count for every statement whose count may differ
/// from its predecessor's.
void CodeGenPGO::computeRegionCounts(const Decl *D) {
  StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
  ComputeRegionCounts Walker(*StmtCountMap, *this);
  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
    Walker.VisitFunctionDecl(FD);
  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
    Walker.VisitObjCMethodDecl(MD);
  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
    Walker.VisitBlockDecl(BD);
  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
    Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
}

/// Attach profile-derived attributes to \p Fn; currently just the entry
/// count (counter index nullptr maps to the function-entry counter).
void
CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
                                    llvm::Function *Fn) {
  if (!haveRegionCounts())
    return;

  uint64_t FunctionCount = getRegionCount(nullptr);
  Fn->setEntryCount(FunctionCount);
}

/// Emit a call to the llvm.instrprof.increment (or .increment.step when
/// \p StepV is non-null) intrinsic for the counter assigned to \p S.
void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
                                      llvm::Value *StepV) {
  if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
    return;
  if (!Builder.GetInsertBlock())
    return;

  unsigned Counter = (*RegionCounterMap)[S];
  auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());

  llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
                         Builder.getInt64(FunctionHash),
                         Builder.getInt32(NumRegionCounters),
                         Builder.getInt32(Counter), StepV};
  if (!StepV)
    // Plain increment takes only the first four arguments; the step form
    // takes all five.
    Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
                       makeArrayRef(Args, 4));
  else
    Builder.CreateCall(
        CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
        makeArrayRef(Args));
}

// This method either inserts a call to the profile run-time during
// instrumentation or puts profile data into metadata for PGO use. 964 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind, 965 llvm::Instruction *ValueSite, llvm::Value *ValuePtr) { 966 967 if (!EnableValueProfiling) 968 return; 969 970 if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock()) 971 return; 972 973 if (isa<llvm::Constant>(ValuePtr)) 974 return; 975 976 bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr(); 977 if (InstrumentValueSites && RegionCounterMap) { 978 auto BuilderInsertPoint = Builder.saveIP(); 979 Builder.SetInsertPoint(ValueSite); 980 llvm::Value *Args[5] = { 981 llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()), 982 Builder.getInt64(FunctionHash), 983 Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()), 984 Builder.getInt32(ValueKind), 985 Builder.getInt32(NumValueSites[ValueKind]++) 986 }; 987 Builder.CreateCall( 988 CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args); 989 Builder.restoreIP(BuilderInsertPoint); 990 return; 991 } 992 993 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 994 if (PGOReader && haveRegionCounts()) { 995 // We record the top most called three functions at each call site. 996 // Profile metadata contains "VP" string identifying this metadata 997 // as value profiling data, then a uint32_t value for the value profiling 998 // kind, a uint64_t value for the total number of times the call is 999 // executed, followed by the function hash and execution count (uint64_t) 1000 // pairs for each function. 
1001 if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind)) 1002 return; 1003 1004 llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord, 1005 (llvm::InstrProfValueKind)ValueKind, 1006 NumValueSites[ValueKind]); 1007 1008 NumValueSites[ValueKind]++; 1009 } 1010 } 1011 1012 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader, 1013 bool IsInMainFile) { 1014 CGM.getPGOStats().addVisited(IsInMainFile); 1015 RegionCounts.clear(); 1016 llvm::Expected<llvm::InstrProfRecord> RecordExpected = 1017 PGOReader->getInstrProfRecord(FuncName, FunctionHash); 1018 if (auto E = RecordExpected.takeError()) { 1019 auto IPE = llvm::InstrProfError::take(std::move(E)); 1020 if (IPE == llvm::instrprof_error::unknown_function) 1021 CGM.getPGOStats().addMissing(IsInMainFile); 1022 else if (IPE == llvm::instrprof_error::hash_mismatch) 1023 CGM.getPGOStats().addMismatched(IsInMainFile); 1024 else if (IPE == llvm::instrprof_error::malformed) 1025 // TODO: Consider a more specific warning for this case. 1026 CGM.getPGOStats().addMismatched(IsInMainFile); 1027 return; 1028 } 1029 ProfRecord = 1030 std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get())); 1031 RegionCounts = ProfRecord->Counts; 1032 } 1033 1034 /// Calculate what to divide by to scale weights. 1035 /// 1036 /// Given the maximum weight, calculate a divisor that will scale all the 1037 /// weights to strictly less than UINT32_MAX. 1038 static uint64_t calculateWeightScale(uint64_t MaxWeight) { 1039 return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1; 1040 } 1041 1042 /// Scale an individual branch weight (and add 1). 1043 /// 1044 /// Scale a 64-bit weight down to 32-bits using \c Scale. 1045 /// 1046 /// According to Laplace's Rule of Succession, it is better to compute the 1047 /// weight based on the count plus 1, so universally add 1 to the value. 
1048 /// 1049 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no 1050 /// greater than \c Weight. 1051 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) { 1052 assert(Scale && "scale by 0?"); 1053 uint64_t Scaled = Weight / Scale + 1; 1054 assert(Scaled <= UINT32_MAX && "overflow 32-bits"); 1055 return Scaled; 1056 } 1057 1058 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount, 1059 uint64_t FalseCount) const { 1060 // Check for empty weights. 1061 if (!TrueCount && !FalseCount) 1062 return nullptr; 1063 1064 // Calculate how to scale down to 32-bits. 1065 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount)); 1066 1067 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 1068 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale), 1069 scaleBranchWeight(FalseCount, Scale)); 1070 } 1071 1072 llvm::MDNode * 1073 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const { 1074 // We need at least two elements to create meaningful weights. 1075 if (Weights.size() < 2) 1076 return nullptr; 1077 1078 // Check for empty weights. 1079 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end()); 1080 if (MaxWeight == 0) 1081 return nullptr; 1082 1083 // Calculate how to scale down to 32-bits. 
1084 uint64_t Scale = calculateWeightScale(MaxWeight); 1085 1086 SmallVector<uint32_t, 16> ScaledWeights; 1087 ScaledWeights.reserve(Weights.size()); 1088 for (uint64_t W : Weights) 1089 ScaledWeights.push_back(scaleBranchWeight(W, Scale)); 1090 1091 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 1092 return MDHelper.createBranchWeights(ScaledWeights); 1093 } 1094 1095 llvm::MDNode * 1096 CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond, 1097 uint64_t LoopCount) const { 1098 if (!PGO.haveRegionCounts()) 1099 return nullptr; 1100 Optional<uint64_t> CondCount = PGO.getStmtCount(Cond); 1101 if (!CondCount || *CondCount == 0) 1102 return nullptr; 1103 return createProfileWeights(LoopCount, 1104 std::max(*CondCount, LoopCount) - LoopCount); 1105 } 1106