//===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// Instrumentation-based profile-guided optimization
//
//===----------------------------------------------------------------------===//

#include "CodeGenPGO.h"
#include "CodeGenFunction.h"
#include "CoverageMappingGen.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/StmtVisitor.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/MD5.h"

static llvm::cl::opt<bool>
    EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
                         llvm::cl::desc("Enable value profiling"),
                         llvm::cl::Hidden, llvm::cl::init(false));

using namespace clang;
using namespace CodeGen;

void CodeGenPGO::setFuncName(StringRef Name,
                             llvm::GlobalValue::LinkageTypes Linkage) {
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  FuncName = llvm::getPGOFuncName(
      Name, Linkage, CGM.getCodeGenOpts().MainFileName,
      PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);

  // If we're generating a profile, create a variable for the name.
  if (CGM.getCodeGenOpts().hasProfileClangInstr())
    FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
}

void CodeGenPGO::setFuncName(llvm::Function *Fn) {
  setFuncName(Fn->getName(), Fn->getLinkage());
  // Create PGOFuncName metadata.
  llvm::createPGOFuncNameMetadata(*Fn, FuncName);
}

/// The version of the PGO hash algorithm.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1,
  PGO_HASH_V2,

  // Keep this set to the latest hash version.
  PGO_HASH_LATEST = PGO_HASH_V2
};

namespace {
/// \brief Stable hasher for PGO region counters.
///
/// PGOHash produces a stable hash of a given function's control flow.
///
/// Changing the output of this hash will invalidate all previously generated
/// profiles -- i.e., don't do it.
///
/// \note When this hash does eventually change (years?), we still need to
/// support old hashes. We'll need to pull in the version number from the
/// profile data format and use the matching hash function.
class PGOHash {
  uint64_t Working;
  unsigned Count;
  PGOHashVersion HashVersion;
  llvm::MD5 MD5;

  static const int NumBitsPerType = 6;
  static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
  static const unsigned TooBig = 1u << NumBitsPerType;

public:
  /// \brief Hash values for AST nodes.
  ///
  /// Distinct values for AST nodes that have region counters attached.
  ///
  /// These values must be stable. All new members must be added at the end,
  /// and no members should be removed. Changing the enumeration value for an
  /// AST node will affect the hash of every function that contains that node.
  enum HashType : unsigned char {
    None = 0,
    LabelStmt = 1,
    WhileStmt,
    DoStmt,
    ForStmt,
    CXXForRangeStmt,
    ObjCForCollectionStmt,
    SwitchStmt,
    CaseStmt,
    DefaultStmt,
    IfStmt,
    CXXTryStmt,
    CXXCatchStmt,
    ConditionalOperator,
    BinaryOperatorLAnd,
    BinaryOperatorLOr,
    BinaryConditionalOperator,
    // The preceding values are available with PGO_HASH_V1.

    EndOfScope,
    IfThenBranch,
    IfElseBranch,
    GotoStmt,
    IndirectGotoStmt,
    BreakStmt,
    ContinueStmt,
    ReturnStmt,
    ThrowExpr,
    UnaryOperatorLNot,
    BinaryOperatorLT,
    BinaryOperatorGT,
    BinaryOperatorLE,
    BinaryOperatorGE,
    BinaryOperatorEQ,
    BinaryOperatorNE,
    // The preceding values are available with PGO_HASH_V2.

    // Keep this last. It's for the static assert that follows.
    LastHashType
  };
  static_assert(LastHashType <= TooBig, "Too many types in HashType");

  PGOHash(PGOHashVersion HashVersion)
      : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
  void combine(HashType Type);
  uint64_t finalize();
  PGOHashVersion getHashVersion() const { return HashVersion; }
};
const int PGOHash::NumBitsPerType;
const unsigned PGOHash::NumTypesPerWord;
const unsigned PGOHash::TooBig;

/// Get the PGO hash version used in the given indexed profile.
static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
                                        CodeGenModule &CGM) {
  if (PGOReader->getVersion() <= 4)
    return PGO_HASH_V1;
  return PGO_HASH_V2;
}

/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
  using Base = RecursiveASTVisitor<MapRegionCounters>;

  /// The next counter value to assign.
  unsigned NextCounter;
  /// The function hash.
  PGOHash Hash;
  /// The map of statements to counters.
  llvm::DenseMap<const Stmt *, unsigned> &CounterMap;

  MapRegionCounters(PGOHashVersion HashVersion,
                    llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
      : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap) {}

  // Blocks and lambdas are handled as separate functions, so we need not
  // traverse them in the parent context.
  bool TraverseBlockExpr(BlockExpr *BE) { return true; }
  bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }

  bool VisitDecl(const Decl *D) {
    switch (D->getKind()) {
    default:
      break;
    case Decl::Function:
    case Decl::CXXMethod:
    case Decl::CXXConstructor:
    case Decl::CXXDestructor:
    case Decl::CXXConversion:
    case Decl::ObjCMethod:
    case Decl::Block:
    case Decl::Captured:
      CounterMap[D->getBody()] = NextCounter++;
      break;
    }
    return true;
  }

  /// If \p S gets a fresh counter, update the counter mappings. Return the
  /// V1 hash of \p S.
  PGOHash::HashType updateCounterMappings(Stmt *S) {
    auto Type = getHashType(PGO_HASH_V1, S);
    if (Type != PGOHash::None)
      CounterMap[S] = NextCounter++;
    return Type;
  }
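
  // Illustrative note (not part of the original source): for a function like
  //   void f(int x) { if (x) g(); }
  // the traversal assigns counter 0 to the function body (via VisitDecl) and
  // counter 1 to the IfStmt (via updateCounterMappings); with PGO_HASH_V2 the
  // hash additionally mixes in IfThenBranch and EndOfScope markers while
  // TraverseIfStmt walks the branches.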

  /// Include \p S in the function hash.
  bool VisitStmt(Stmt *S) {
    auto Type = updateCounterMappings(S);
    if (Hash.getHashVersion() != PGO_HASH_V1)
      Type = getHashType(Hash.getHashVersion(), S);
    if (Type != PGOHash::None)
      Hash.combine(Type);
    return true;
  }

  bool TraverseIfStmt(IfStmt *If) {
    // If we used the V1 hash, use the default traversal.
    if (Hash.getHashVersion() == PGO_HASH_V1)
      return Base::TraverseIfStmt(If);

    // Otherwise, keep track of which branch we're in while traversing.
    VisitStmt(If);
    for (Stmt *CS : If->children()) {
      if (!CS)
        continue;
      if (CS == If->getThen())
        Hash.combine(PGOHash::IfThenBranch);
      else if (CS == If->getElse())
        Hash.combine(PGOHash::IfElseBranch);
      TraverseStmt(CS);
    }
    Hash.combine(PGOHash::EndOfScope);
    return true;
  }

  // If the statement type \p N is nestable, and its nesting impacts profile
  // stability, define a custom traversal which tracks the end of the statement
  // in the hash (provided we're not using the V1 hash).
#define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
  bool Traverse##N(N *S) {                                                     \
    Base::Traverse##N(S);                                                      \
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
      Hash.combine(PGOHash::EndOfScope);                                       \
    return true;                                                               \
  }

  DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
  DEFINE_NESTABLE_TRAVERSAL(DoStmt)
  DEFINE_NESTABLE_TRAVERSAL(ForStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
  DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)

  /// Get version \p HashVersion of the PGO hash for \p S.
  PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
    switch (S->getStmtClass()) {
    default:
      break;
    case Stmt::LabelStmtClass:
      return PGOHash::LabelStmt;
    case Stmt::WhileStmtClass:
      return PGOHash::WhileStmt;
    case Stmt::DoStmtClass:
      return PGOHash::DoStmt;
    case Stmt::ForStmtClass:
      return PGOHash::ForStmt;
    case Stmt::CXXForRangeStmtClass:
      return PGOHash::CXXForRangeStmt;
    case Stmt::ObjCForCollectionStmtClass:
      return PGOHash::ObjCForCollectionStmt;
    case Stmt::SwitchStmtClass:
      return PGOHash::SwitchStmt;
    case Stmt::CaseStmtClass:
      return PGOHash::CaseStmt;
    case Stmt::DefaultStmtClass:
      return PGOHash::DefaultStmt;
    case Stmt::IfStmtClass:
      return PGOHash::IfStmt;
    case Stmt::CXXTryStmtClass:
      return PGOHash::CXXTryStmt;
    case Stmt::CXXCatchStmtClass:
      return PGOHash::CXXCatchStmt;
    case Stmt::ConditionalOperatorClass:
      return PGOHash::ConditionalOperator;
    case Stmt::BinaryConditionalOperatorClass:
      return PGOHash::BinaryConditionalOperator;
    case Stmt::BinaryOperatorClass: {
      const BinaryOperator *BO = cast<BinaryOperator>(S);
      if (BO->getOpcode() == BO_LAnd)
        return PGOHash::BinaryOperatorLAnd;
      if (BO->getOpcode() == BO_LOr)
        return PGOHash::BinaryOperatorLOr;
      if (HashVersion == PGO_HASH_V2) {
        switch (BO->getOpcode()) {
        default:
          break;
        case BO_LT:
          return PGOHash::BinaryOperatorLT;
        case BO_GT:
          return PGOHash::BinaryOperatorGT;
        case BO_LE:
          return PGOHash::BinaryOperatorLE;
        case BO_GE:
          return PGOHash::BinaryOperatorGE;
        case BO_EQ:
          return PGOHash::BinaryOperatorEQ;
        case BO_NE:
          return PGOHash::BinaryOperatorNE;
        }
      }
      break;
    }
    }

    if (HashVersion == PGO_HASH_V2) {
      switch (S->getStmtClass()) {
      default:
        break;
      case Stmt::GotoStmtClass:
        return PGOHash::GotoStmt;
      case Stmt::IndirectGotoStmtClass:
        return PGOHash::IndirectGotoStmt;
      case Stmt::BreakStmtClass:
        return PGOHash::BreakStmt;
      case Stmt::ContinueStmtClass:
        return PGOHash::ContinueStmt;
      case Stmt::ReturnStmtClass:
        return PGOHash::ReturnStmt;
      case Stmt::CXXThrowExprClass:
        return PGOHash::ThrowExpr;
      case Stmt::UnaryOperatorClass: {
        const UnaryOperator *UO = cast<UnaryOperator>(S);
        if (UO->getOpcode() == UO_LNot)
          return PGOHash::UnaryOperatorLNot;
        break;
      }
      }
    }

    return PGOHash::None;
  }
};

/// A StmtVisitor that propagates the raw counts through the AST and
/// records the count at statements where the value may change.
struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
  /// PGO state.
  CodeGenPGO &PGO;

  /// A flag that is set when the current count should be recorded on the
  /// next statement, such as at the exit of a loop.
  bool RecordNextStmtCount;

  /// The count at the current location in the traversal.
  uint64_t CurrentCount;

  /// The map of statements to count values.
  llvm::DenseMap<const Stmt *, uint64_t> &CountMap;

  /// BreakContinueStack - Keep counts of breaks and continues inside loops.
  struct BreakContinue {
    uint64_t BreakCount;
    uint64_t ContinueCount;
    BreakContinue() : BreakCount(0), ContinueCount(0) {}
  };
  SmallVector<BreakContinue, 8> BreakContinueStack;

  ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
                      CodeGenPGO &PGO)
      : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}

  void RecordStmtCount(const Stmt *S) {
    if (RecordNextStmtCount) {
      CountMap[S] = CurrentCount;
      RecordNextStmtCount = false;
    }
  }

  /// Set and return the current count.
  uint64_t setCount(uint64_t Count) {
    CurrentCount = Count;
    return Count;
  }

  void VisitStmt(const Stmt *S) {
    RecordStmtCount(S);
    for (const Stmt *Child : S->children())
      if (Child)
        this->Visit(Child);
  }

  void VisitFunctionDecl(const FunctionDecl *D) {
    // Counter tracks entry to the function body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  // Skip lambda expressions. We visit these as FunctionDecls when we're
  // generating them and aren't interested in the body when generating a
  // parent context.
  void VisitLambdaExpr(const LambdaExpr *LE) {}

  void VisitCapturedDecl(const CapturedDecl *D) {
    // Counter tracks entry to the capture body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
    // Counter tracks entry to the method body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitBlockDecl(const BlockDecl *D) {
    // Counter tracks entry to the block body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitReturnStmt(const ReturnStmt *S) {
    RecordStmtCount(S);
    if (S->getRetValue())
      Visit(S->getRetValue());
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitCXXThrowExpr(const CXXThrowExpr *E) {
    RecordStmtCount(E);
    if (E->getSubExpr())
      Visit(E->getSubExpr());
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitGotoStmt(const GotoStmt *S) {
    RecordStmtCount(S);
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitLabelStmt(const LabelStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the block following the label.
    uint64_t BlockCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = BlockCount;
    Visit(S->getSubStmt());
  }

  void VisitBreakStmt(const BreakStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
    BreakContinueStack.back().BreakCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitContinueStmt(const ContinueStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
    BreakContinueStack.back().ContinueCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitWhileStmt(const WhileStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first so the break/continue adjustments can be
    // included when visiting the condition.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = CurrentCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    // ...then go back and propagate counts through the condition. The count
    // at the start of the condition is the sum of the incoming edges,
    // the backedge from the end of the loop body, and the edges from
    // continue statements.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitDoStmt(const DoStmt *S) {
    RecordStmtCount(S);
    uint64_t LoopCount = PGO.getRegionCount(S);

    BreakContinueStack.push_back(BreakContinue());
    // The count doesn't include the fallthrough from the parent scope. Add it.
    uint64_t BodyCount = setCount(LoopCount + CurrentCount);
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    BreakContinue BC = BreakContinueStack.pop_back_val();
    // The count at the start of the condition is equal to the count at the
    // end of the body, plus any continues.
    uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - LoopCount);
    RecordNextStmtCount = true;
  }
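
  // Worked example for the while-loop propagation above (illustrative, not
  // part of the original source): for a while loop entered 10 times whose
  // body executes 25 times in total, ParentCount is 10 and the body counter
  // yields BodyCount = 25. With no breaks or continues, BackedgeCount = 25,
  // so CondCount = 10 + 25 = 35 and the count flowing out of the loop is
  // 0 + 35 - 25 = 10, matching the entry count.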

  void VisitForStmt(const ForStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());

    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    if (S->getInc()) {
      uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
      CountMap[S->getInc()] = IncCount;
      Visit(S->getInc());
    }

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    if (S->getCond()) {
      CountMap[S->getCond()] = CondCount;
      Visit(S->getCond());
    }
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
    RecordStmtCount(S);
    Visit(S->getLoopVarStmt());
    Visit(S->getRangeStmt());
    Visit(S->getBeginStmt());
    Visit(S->getEndStmt());

    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getInc()] = IncCount;
    Visit(S->getInc());

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
    RecordStmtCount(S);
    Visit(S->getElement());
    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Counter tracks the body of the loop.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
             BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitSwitchStmt(const SwitchStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());
    CurrentCount = 0;
    BreakContinueStack.push_back(BreakContinue());
    Visit(S->getBody());
    // If the switch is inside a loop, add the continue counts.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    if (!BreakContinueStack.empty())
      BreakContinueStack.back().ContinueCount += BC.ContinueCount;
    // Counter tracks the exit block of the switch.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitSwitchCase(const SwitchCase *S) {
    RecordNextStmtCount = false;
    // Counter for this particular case. This counts only jumps from the
    // switch header and does not include fallthrough from the case before
    // this one.
    uint64_t CaseCount = PGO.getRegionCount(S);
    setCount(CurrentCount + CaseCount);
    // We need the count without fallthrough in the mapping, so it's more useful
    // for branch probabilities.
    CountMap[S] = CaseCount;
    RecordNextStmtCount = true;
    Visit(S->getSubStmt());
  }

  void VisitIfStmt(const IfStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());

    // Counter tracks the "then" part of an if statement. The count for
    // the "else" part, if it exists, will be calculated from this counter.
    uint64_t ThenCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getThen()] = ThenCount;
    Visit(S->getThen());
    uint64_t OutCount = CurrentCount;

    uint64_t ElseCount = ParentCount - ThenCount;
    if (S->getElse()) {
      setCount(ElseCount);
      CountMap[S->getElse()] = ElseCount;
      Visit(S->getElse());
      OutCount += CurrentCount;
    } else
      OutCount += ElseCount;
    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXTryStmt(const CXXTryStmt *S) {
    RecordStmtCount(S);
    Visit(S->getTryBlock());
    for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
      Visit(S->getHandler(I));
    // Counter tracks the continuation block of the try statement.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitCXXCatchStmt(const CXXCatchStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the catch statement's handler block.
    uint64_t CatchCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = CatchCount;
    Visit(S->getHandlerBlock());
  }

  void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getCond());

    // Counter tracks the "true" part of a conditional operator. The
    // count in the "false" part will be calculated from this counter.
    uint64_t TrueCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getTrueExpr()] = TrueCount;
    Visit(E->getTrueExpr());
    uint64_t OutCount = CurrentCount;

    uint64_t FalseCount = setCount(ParentCount - TrueCount);
    CountMap[E->getFalseExpr()] = FalseCount;
    Visit(E->getFalseExpr());
    OutCount += CurrentCount;

    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLAnd(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical and operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLOr(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical or operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }
};
} // end anonymous namespace
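
// Illustrative note (not part of the original source): with NumBitsPerType
// = 6, NumTypesPerWord is 10, so ten hash types pack into the low 60 bits of
// one 64-bit word. Combining, say, IfStmt (10) followed by WhileStmt (2)
// leaves Working = (10 << 6) | 2 = 642. When an eleventh type arrives, the
// full word is byte-swapped to little-endian, folded into the MD5, and
// packing starts over in a fresh word.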

void PGOHash::combine(HashType Type) {
  // Check that we never combine 0 and only have six bits.
  assert(Type && "Hash is invalid: unexpected type 0");
  assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");

  // Pass through MD5 if enough work has built up.
  if (Count && Count % NumTypesPerWord == 0) {
    using namespace llvm::support;
    uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
    MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    Working = 0;
  }

  // Accumulate the current type.
  ++Count;
  Working = Working << NumBitsPerType | Type;
}

uint64_t PGOHash::finalize() {
  // Use Working as the hash directly if we never used MD5.
  if (Count <= NumTypesPerWord)
    // No need to byte swap here, since none of the math was endian-dependent.
    // This number will be byte-swapped as required on endianness transitions,
    // so we will see the same value on the other side.
    return Working;

  // Check for remaining work in Working.
  if (Working)
    MD5.update(Working);

  // Finalize the MD5 and return the hash.
  llvm::MD5::MD5Result Result;
  MD5.final(Result);
  using namespace llvm::support;
  return Result.low();
}

void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
  const Decl *D = GD.getDecl();
  if (!D->hasBody())
    return;

  bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (!InstrumentRegions && !PGOReader)
    return;
  if (D->isImplicit())
    return;
  // Constructors and destructors may be represented by several functions in
  // IR. If so, instrument only the base variant; the others are implemented
  // by delegation to the base one and would otherwise be counted twice.
  if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
    if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
      return;

    if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
      if (GD.getCtorType() != Ctor_Base &&
          CodeGenFunction::IsConstructorDelegationValid(CCD))
        return;
  }
  CGM.ClearUnusedCoverageMapping(D);
  setFuncName(Fn);

  mapRegionCounters(D);
  if (CGM.getCodeGenOpts().CoverageMapping)
    emitCounterRegionMapping(D);
  if (PGOReader) {
    SourceManager &SM = CGM.getContext().getSourceManager();
    loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
    computeRegionCounts(D);
    applyFunctionAttributes(PGOReader, Fn);
  }
}

void CodeGenPGO::mapRegionCounters(const Decl *D) {
  // Use the latest hash version when inserting instrumentation, but use the
  // version in the indexed profile if we're reading PGO data.
  PGOHashVersion HashVersion = PGO_HASH_LATEST;
  if (auto *PGOReader = CGM.getPGOReader())
    HashVersion = getPGOHashVersion(PGOReader, CGM);

  RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
  MapRegionCounters Walker(HashVersion, *RegionCounterMap);
  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
    Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
    Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
    Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
    Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
  assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
  NumRegionCounters = Walker.NextCounter;
  FunctionHash = Walker.Hash.finalize();
}

bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
  if (!D->getBody())
    return true;

  // Don't map the functions in system headers.
  const auto &SM = CGM.getContext().getSourceManager();
  auto Loc = D->getBody()->getLocStart();
  return SM.isInSystemHeader(Loc);
}

void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
  if (skipRegionMappingForDecl(D))
    return;

  std::string CoverageMapping;
  llvm::raw_string_ostream OS(CoverageMapping);
  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
                                CGM.getContext().getSourceManager(),
                                CGM.getLangOpts(), RegionCounterMap.get());
  MappingGen.emitCounterMapping(D, OS);
  OS.flush();

  if (CoverageMapping.empty())
    return;

  CGM.getCoverageMapping()->addFunctionMappingRecord(
      FuncNameVar, FuncName, FunctionHash, CoverageMapping);
}

void
CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
                                    llvm::GlobalValue::LinkageTypes Linkage) {
  if (skipRegionMappingForDecl(D))
    return;

  std::string CoverageMapping;
  llvm::raw_string_ostream OS(CoverageMapping);
  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
                                CGM.getContext().getSourceManager(),
                                CGM.getLangOpts());
  MappingGen.emitEmptyMapping(D, OS);
  OS.flush();

  if (CoverageMapping.empty())
    return;

  setFuncName(Name, Linkage);
  CGM.getCoverageMapping()->addFunctionMappingRecord(
      FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
}

void CodeGenPGO::computeRegionCounts(const Decl *D) {
  StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
  ComputeRegionCounts Walker(*StmtCountMap, *this);
  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
    Walker.VisitFunctionDecl(FD);
  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
    Walker.VisitObjCMethodDecl(MD);
  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
    Walker.VisitBlockDecl(BD);
  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
    Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
}

void
CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
                                    llvm::Function *Fn) {
  if (!haveRegionCounts())
    return;

  uint64_t FunctionCount = getRegionCount(nullptr);
  Fn->setEntryCount(FunctionCount);
}
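
// Illustrative note (not part of the original source): the increment emitted
// below lowers to IR along the lines of
//   call void @llvm.instrprof.increment(i8* <name var>, i64 <hash>,
//                                        i32 <num-counters>, i32 <index>)
// with an additional i64 step operand when llvm.instrprof.increment.step is
// used instead.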

void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
                                      llvm::Value *StepV) {
  if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
    return;
  if (!Builder.GetInsertBlock())
    return;

  unsigned Counter = (*RegionCounterMap)[S];
  auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());

  llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
                         Builder.getInt64(FunctionHash),
                         Builder.getInt32(NumRegionCounters),
                         Builder.getInt32(Counter), StepV};
  if (!StepV)
    Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
                       makeArrayRef(Args, 4));
  else
    Builder.CreateCall(
        CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
        makeArrayRef(Args));
}

// This method either inserts a call to the profile run-time during
// instrumentation or puts profile data into metadata for PGO use.
void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
                              llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {

  if (!EnableValueProfiling)
    return;

  if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
    return;

  if (isa<llvm::Constant>(ValuePtr))
    return;

  bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
  if (InstrumentValueSites && RegionCounterMap) {
    auto BuilderInsertPoint = Builder.saveIP();
    Builder.SetInsertPoint(ValueSite);
    llvm::Value *Args[5] = {
        llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
        Builder.getInt64(FunctionHash),
        Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
        Builder.getInt32(ValueKind),
        Builder.getInt32(NumValueSites[ValueKind]++)
    };
    Builder.CreateCall(
        CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
    Builder.restoreIP(BuilderInsertPoint);
    return;
  }

  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (PGOReader && haveRegionCounts()) {
    // We record at most the three most frequently called target functions
    // at each call site. Profile metadata contains a "VP" string identifying
    // this metadata as value profiling data, then a uint32_t value for the
    // value profiling kind, a uint64_t value for the total number of times
    // the call is executed, followed by the function hash and execution
    // count (uint64_t) pairs for each function.
    if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
      return;

    llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
                            (llvm::InstrProfValueKind)ValueKind,
                            NumValueSites[ValueKind]);

    NumValueSites[ValueKind]++;
  }
}
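
// Illustrative note (not part of the original source): for indirect-call
// value profiling, the metadata attached by annotateValueSite above has the
// shape
//   !{!"VP", i32 <kind>, i64 <total count>, i64 <target hash>, i64 <count>, ...}
// with one hash/count pair per recorded target.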

void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
                                  bool IsInMainFile) {
  CGM.getPGOStats().addVisited(IsInMainFile);
  RegionCounts.clear();
  llvm::Expected<llvm::InstrProfRecord> RecordExpected =
      PGOReader->getInstrProfRecord(FuncName, FunctionHash);
  if (auto E = RecordExpected.takeError()) {
    auto IPE = llvm::InstrProfError::take(std::move(E));
    if (IPE == llvm::instrprof_error::unknown_function)
      CGM.getPGOStats().addMissing(IsInMainFile);
    else if (IPE == llvm::instrprof_error::hash_mismatch)
      CGM.getPGOStats().addMismatched(IsInMainFile);
    else if (IPE == llvm::instrprof_error::malformed)
      // TODO: Consider a more specific warning for this case.
      CGM.getPGOStats().addMismatched(IsInMainFile);
    return;
  }
  ProfRecord =
      llvm::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
  RegionCounts = ProfRecord->Counts;
}

/// \brief Calculate what to divide by to scale weights.
///
/// Given the maximum weight, calculate a divisor that will scale all the
/// weights to strictly less than UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1;
}

/// \brief Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// greater than \c Weight.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  return Scaled;
}

llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
                                                    uint64_t FalseCount) {
  // Check for empty weights.
  if (!TrueCount && !FalseCount)
    return nullptr;

  // Calculate how to scale down to 32-bits.
  uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));

  llvm::MDBuilder MDHelper(CGM.getLLVMContext());
  return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
                                      scaleBranchWeight(FalseCount, Scale));
}

llvm::MDNode *
CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
  // We need at least two elements to create meaningful weights.
  if (Weights.size() < 2)
    return nullptr;

  // Check for empty weights.
  uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
  if (MaxWeight == 0)
    return nullptr;

  // Calculate how to scale down to 32-bits.
  uint64_t Scale = calculateWeightScale(MaxWeight);

  SmallVector<uint32_t, 16> ScaledWeights;
  ScaledWeights.reserve(Weights.size());
  for (uint64_t W : Weights)
    ScaledWeights.push_back(scaleBranchWeight(W, Scale));

  llvm::MDBuilder MDHelper(CGM.getLLVMContext());
  return MDHelper.createBranchWeights(ScaledWeights);
}

llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
                                                           uint64_t LoopCount) {
  if (!PGO.haveRegionCounts())
    return nullptr;
  Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
  assert(CondCount.hasValue() && "missing expected loop condition count");
  if (*CondCount == 0)
    return nullptr;
  return createProfileWeights(LoopCount,
                              std::max(*CondCount, LoopCount) - LoopCount);
}
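
// Worked example for the weight scaling above (illustrative, not part of the
// original source): with TrueCount = 6 and FalseCount = 2, both counts fit in
// 32 bits, so calculateWeightScale returns 1 and createProfileWeights emits
//   !{!"branch_weights", i32 7, i32 3}
// since each weight is count / Scale plus 1, per Laplace's rule of succession.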