1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // Instrumentation-based profile-guided optimization 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "CodeGenPGO.h" 15 #include "CodeGenFunction.h" 16 #include "CoverageMappingGen.h" 17 #include "clang/AST/RecursiveASTVisitor.h" 18 #include "clang/AST/StmtVisitor.h" 19 #include "llvm/IR/Intrinsics.h" 20 #include "llvm/IR/MDBuilder.h" 21 #include "llvm/Support/Endian.h" 22 #include "llvm/Support/FileSystem.h" 23 #include "llvm/Support/MD5.h" 24 25 static llvm::cl::opt<bool> EnableValueProfiling( 26 "enable-value-profiling", llvm::cl::ZeroOrMore, 27 llvm::cl::desc("Enable value profiling"), llvm::cl::init(false)); 28 29 using namespace clang; 30 using namespace CodeGen; 31 32 void CodeGenPGO::setFuncName(StringRef Name, 33 llvm::GlobalValue::LinkageTypes Linkage) { 34 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 35 FuncName = llvm::getPGOFuncName( 36 Name, Linkage, CGM.getCodeGenOpts().MainFileName, 37 PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version); 38 39 // If we're generating a profile, create a variable for the name. 40 if (CGM.getCodeGenOpts().hasProfileClangInstr()) 41 FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName); 42 } 43 44 void CodeGenPGO::setFuncName(llvm::Function *Fn) { 45 setFuncName(Fn->getName(), Fn->getLinkage()); 46 // Create PGOFuncName meta data. 47 llvm::createPGOFuncNameMetadata(*Fn, FuncName); 48 } 49 50 /// The version of the PGO hash algorithm. 51 enum PGOHashVersion : unsigned { 52 PGO_HASH_V1, 53 PGO_HASH_V2, 54 55 // Keep this set to the latest hash version. 56 PGO_HASH_LATEST = PGO_HASH_V2 57 }; 58 59 namespace { 60 /// \brief Stable hasher for PGO region counters. 61 /// 62 /// PGOHash produces a stable hash of a given function's control flow. 63 /// 64 /// Changing the output of this hash will invalidate all previously generated 65 /// profiles -- i.e., don't do it. 66 /// 67 /// \note When this hash does eventually change (years?), we still need to 68 /// support old hashes. We'll need to pull in the version number from the 69 /// profile data format and use the matching hash function. 70 class PGOHash { 71 uint64_t Working; 72 unsigned Count; 73 PGOHashVersion HashVersion; 74 llvm::MD5 MD5; 75 76 static const int NumBitsPerType = 6; 77 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType; 78 static const unsigned TooBig = 1u << NumBitsPerType; 79 80 public: 81 /// \brief Hash values for AST nodes. 82 /// 83 /// Distinct values for AST nodes that have region counters attached. 84 /// 85 /// These values must be stable. All new members must be added at the end, 86 /// and no members should be removed. Changing the enumeration value for an 87 /// AST node will affect the hash of every function that contains that node. 88 enum HashType : unsigned char { 89 None = 0, 90 LabelStmt = 1, 91 WhileStmt, 92 DoStmt, 93 ForStmt, 94 CXXForRangeStmt, 95 ObjCForCollectionStmt, 96 SwitchStmt, 97 CaseStmt, 98 DefaultStmt, 99 IfStmt, 100 CXXTryStmt, 101 CXXCatchStmt, 102 ConditionalOperator, 103 BinaryOperatorLAnd, 104 BinaryOperatorLOr, 105 BinaryConditionalOperator, 106 // The preceding values are available with PGO_HASH_V1. 107 108 EndOfScope, 109 IfThenBranch, 110 IfElseBranch, 111 GotoStmt, 112 IndirectGotoStmt, 113 BreakStmt, 114 ContinueStmt, 115 ReturnStmt, 116 ThrowExpr, 117 UnaryOperatorLNot, 118 BinaryOperatorLT, 119 BinaryOperatorGT, 120 BinaryOperatorLE, 121 BinaryOperatorGE, 122 BinaryOperatorEQ, 123 BinaryOperatorNE, 124 // The preceding values are available with PGO_HASH_V2. 125 126 // Keep this last. It's for the static assert that follows. 127 LastHashType 128 }; 129 static_assert(LastHashType <= TooBig, "Too many types in HashType"); 130 131 PGOHash(PGOHashVersion HashVersion) 132 : Working(0), Count(0), HashVersion(HashVersion), MD5() {} 133 void combine(HashType Type); 134 uint64_t finalize(); 135 PGOHashVersion getHashVersion() const { return HashVersion; } 136 }; 137 const int PGOHash::NumBitsPerType; 138 const unsigned PGOHash::NumTypesPerWord; 139 const unsigned PGOHash::TooBig; 140 141 /// Get the PGO hash version used in the given indexed profile. 142 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader, 143 CodeGenModule &CGM) { 144 if (PGOReader->getVersion() <= 4) 145 return PGO_HASH_V1; 146 return PGO_HASH_V2; 147 } 148 149 /// A RecursiveASTVisitor that fills a map of statements to PGO counters. 150 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> { 151 using Base = RecursiveASTVisitor<MapRegionCounters>; 152 153 /// The next counter value to assign. 154 unsigned NextCounter; 155 /// The function hash. 156 PGOHash Hash; 157 /// The map of statements to counters. 158 llvm::DenseMap<const Stmt *, unsigned> &CounterMap; 159 160 MapRegionCounters(PGOHashVersion HashVersion, 161 llvm::DenseMap<const Stmt *, unsigned> &CounterMap) 162 : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap) {} 163 164 // Blocks and lambdas are handled as separate functions, so we need not 165 // traverse them in the parent context. 166 bool TraverseBlockExpr(BlockExpr *BE) { return true; } 167 bool TraverseLambdaBody(LambdaExpr *LE) { return true; } 168 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; } 169 170 bool VisitDecl(const Decl *D) { 171 switch (D->getKind()) { 172 default: 173 break; 174 case Decl::Function: 175 case Decl::CXXMethod: 176 case Decl::CXXConstructor: 177 case Decl::CXXDestructor: 178 case Decl::CXXConversion: 179 case Decl::ObjCMethod: 180 case Decl::Block: 181 case Decl::Captured: 182 CounterMap[D->getBody()] = NextCounter++; 183 break; 184 } 185 return true; 186 } 187 188 /// If \p S gets a fresh counter, update the counter mappings. Return the 189 /// V1 hash of \p S. 190 PGOHash::HashType updateCounterMappings(Stmt *S) { 191 auto Type = getHashType(PGO_HASH_V1, S); 192 if (Type != PGOHash::None) 193 CounterMap[S] = NextCounter++; 194 return Type; 195 } 196 197 /// Include \p S in the function hash. 198 bool VisitStmt(Stmt *S) { 199 auto Type = updateCounterMappings(S); 200 if (Hash.getHashVersion() != PGO_HASH_V1) 201 Type = getHashType(Hash.getHashVersion(), S); 202 if (Type != PGOHash::None) 203 Hash.combine(Type); 204 return true; 205 } 206 207 bool TraverseIfStmt(IfStmt *If) { 208 // If we used the V1 hash, use the default traversal. 209 if (Hash.getHashVersion() == PGO_HASH_V1) 210 return Base::TraverseIfStmt(If); 211 212 // Otherwise, keep track of which branch we're in while traversing. 213 VisitStmt(If); 214 for (Stmt *CS : If->children()) { 215 if (!CS) 216 continue; 217 if (CS == If->getThen()) 218 Hash.combine(PGOHash::IfThenBranch); 219 else if (CS == If->getElse()) 220 Hash.combine(PGOHash::IfElseBranch); 221 TraverseStmt(CS); 222 } 223 Hash.combine(PGOHash::EndOfScope); 224 return true; 225 } 226 227 // If the statement type \p N is nestable, and its nesting impacts profile 228 // stability, define a custom traversal which tracks the end of the statement 229 // in the hash (provided we're not using the V1 hash). 230 #define DEFINE_NESTABLE_TRAVERSAL(N) \ 231 bool Traverse##N(N *S) { \ 232 Base::Traverse##N(S); \ 233 if (Hash.getHashVersion() != PGO_HASH_V1) \ 234 Hash.combine(PGOHash::EndOfScope); \ 235 return true; \ 236 } 237 238 DEFINE_NESTABLE_TRAVERSAL(WhileStmt) 239 DEFINE_NESTABLE_TRAVERSAL(DoStmt) 240 DEFINE_NESTABLE_TRAVERSAL(ForStmt) 241 DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt) 242 DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt) 243 DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt) 244 DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt) 245 246 /// Get version \p HashVersion of the PGO hash for \p S. 247 PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) { 248 switch (S->getStmtClass()) { 249 default: 250 break; 251 case Stmt::LabelStmtClass: 252 return PGOHash::LabelStmt; 253 case Stmt::WhileStmtClass: 254 return PGOHash::WhileStmt; 255 case Stmt::DoStmtClass: 256 return PGOHash::DoStmt; 257 case Stmt::ForStmtClass: 258 return PGOHash::ForStmt; 259 case Stmt::CXXForRangeStmtClass: 260 return PGOHash::CXXForRangeStmt; 261 case Stmt::ObjCForCollectionStmtClass: 262 return PGOHash::ObjCForCollectionStmt; 263 case Stmt::SwitchStmtClass: 264 return PGOHash::SwitchStmt; 265 case Stmt::CaseStmtClass: 266 return PGOHash::CaseStmt; 267 case Stmt::DefaultStmtClass: 268 return PGOHash::DefaultStmt; 269 case Stmt::IfStmtClass: 270 return PGOHash::IfStmt; 271 case Stmt::CXXTryStmtClass: 272 return PGOHash::CXXTryStmt; 273 case Stmt::CXXCatchStmtClass: 274 return PGOHash::CXXCatchStmt; 275 case Stmt::ConditionalOperatorClass: 276 return PGOHash::ConditionalOperator; 277 case Stmt::BinaryConditionalOperatorClass: 278 return PGOHash::BinaryConditionalOperator; 279 case Stmt::BinaryOperatorClass: { 280 const BinaryOperator *BO = cast<BinaryOperator>(S); 281 if (BO->getOpcode() == BO_LAnd) 282 return PGOHash::BinaryOperatorLAnd; 283 if (BO->getOpcode() == BO_LOr) 284 return PGOHash::BinaryOperatorLOr; 285 if (HashVersion == PGO_HASH_V2) { 286 switch (BO->getOpcode()) { 287 default: 288 break; 289 case BO_LT: 290 return PGOHash::BinaryOperatorLT; 291 case BO_GT: 292 return PGOHash::BinaryOperatorGT; 293 case BO_LE: 294 return PGOHash::BinaryOperatorLE; 295 case BO_GE: 296 return PGOHash::BinaryOperatorGE; 297 case BO_EQ: 298 return PGOHash::BinaryOperatorEQ; 299 case BO_NE: 300 return PGOHash::BinaryOperatorNE; 301 } 302 } 303 break; 304 } 305 } 306 307 if (HashVersion == PGO_HASH_V2) { 308 switch (S->getStmtClass()) { 309 default: 310 break; 311 case Stmt::GotoStmtClass: 312 return PGOHash::GotoStmt; 313 case Stmt::IndirectGotoStmtClass: 314 return PGOHash::IndirectGotoStmt; 315 case Stmt::BreakStmtClass: 316 return PGOHash::BreakStmt; 317 case Stmt::ContinueStmtClass: 318 return PGOHash::ContinueStmt; 319 case Stmt::ReturnStmtClass: 320 return PGOHash::ReturnStmt; 321 case Stmt::CXXThrowExprClass: 322 return PGOHash::ThrowExpr; 323 case Stmt::UnaryOperatorClass: { 324 const UnaryOperator *UO = cast<UnaryOperator>(S); 325 if (UO->getOpcode() == UO_LNot) 326 return PGOHash::UnaryOperatorLNot; 327 break; 328 } 329 } 330 } 331 332 return PGOHash::None; 333 } 334 }; 335 336 /// A StmtVisitor that propagates the raw counts through the AST and 337 /// records the count at statements where the value may change. 338 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { 339 /// PGO state. 340 CodeGenPGO &PGO; 341 342 /// A flag that is set when the current count should be recorded on the 343 /// next statement, such as at the exit of a loop. 344 bool RecordNextStmtCount; 345 346 /// The count at the current location in the traversal. 347 uint64_t CurrentCount; 348 349 /// The map of statements to count values. 350 llvm::DenseMap<const Stmt *, uint64_t> &CountMap; 351 352 /// BreakContinueStack - Keep counts of breaks and continues inside loops. 353 struct BreakContinue { 354 uint64_t BreakCount; 355 uint64_t ContinueCount; 356 BreakContinue() : BreakCount(0), ContinueCount(0) {} 357 }; 358 SmallVector<BreakContinue, 8> BreakContinueStack; 359 360 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap, 361 CodeGenPGO &PGO) 362 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {} 363 364 void RecordStmtCount(const Stmt *S) { 365 if (RecordNextStmtCount) { 366 CountMap[S] = CurrentCount; 367 RecordNextStmtCount = false; 368 } 369 } 370 371 /// Set and return the current count. 372 uint64_t setCount(uint64_t Count) { 373 CurrentCount = Count; 374 return Count; 375 } 376 377 void VisitStmt(const Stmt *S) { 378 RecordStmtCount(S); 379 for (const Stmt *Child : S->children()) 380 if (Child) 381 this->Visit(Child); 382 } 383 384 void VisitFunctionDecl(const FunctionDecl *D) { 385 // Counter tracks entry to the function body. 386 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 387 CountMap[D->getBody()] = BodyCount; 388 Visit(D->getBody()); 389 } 390 391 // Skip lambda expressions. We visit these as FunctionDecls when we're 392 // generating them and aren't interested in the body when generating a 393 // parent context. 394 void VisitLambdaExpr(const LambdaExpr *LE) {} 395 396 void VisitCapturedDecl(const CapturedDecl *D) { 397 // Counter tracks entry to the capture body. 398 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 399 CountMap[D->getBody()] = BodyCount; 400 Visit(D->getBody()); 401 } 402 403 void VisitObjCMethodDecl(const ObjCMethodDecl *D) { 404 // Counter tracks entry to the method body. 405 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 406 CountMap[D->getBody()] = BodyCount; 407 Visit(D->getBody()); 408 } 409 410 void VisitBlockDecl(const BlockDecl *D) { 411 // Counter tracks entry to the block body. 412 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 413 CountMap[D->getBody()] = BodyCount; 414 Visit(D->getBody()); 415 } 416 417 void VisitReturnStmt(const ReturnStmt *S) { 418 RecordStmtCount(S); 419 if (S->getRetValue()) 420 Visit(S->getRetValue()); 421 CurrentCount = 0; 422 RecordNextStmtCount = true; 423 } 424 425 void VisitCXXThrowExpr(const CXXThrowExpr *E) { 426 RecordStmtCount(E); 427 if (E->getSubExpr()) 428 Visit(E->getSubExpr()); 429 CurrentCount = 0; 430 RecordNextStmtCount = true; 431 } 432 433 void VisitGotoStmt(const GotoStmt *S) { 434 RecordStmtCount(S); 435 CurrentCount = 0; 436 RecordNextStmtCount = true; 437 } 438 439 void VisitLabelStmt(const LabelStmt *S) { 440 RecordNextStmtCount = false; 441 // Counter tracks the block following the label. 442 uint64_t BlockCount = setCount(PGO.getRegionCount(S)); 443 CountMap[S] = BlockCount; 444 Visit(S->getSubStmt()); 445 } 446 447 void VisitBreakStmt(const BreakStmt *S) { 448 RecordStmtCount(S); 449 assert(!BreakContinueStack.empty() && "break not in a loop or switch!"); 450 BreakContinueStack.back().BreakCount += CurrentCount; 451 CurrentCount = 0; 452 RecordNextStmtCount = true; 453 } 454 455 void VisitContinueStmt(const ContinueStmt *S) { 456 RecordStmtCount(S); 457 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!"); 458 BreakContinueStack.back().ContinueCount += CurrentCount; 459 CurrentCount = 0; 460 RecordNextStmtCount = true; 461 } 462 463 void VisitWhileStmt(const WhileStmt *S) { 464 RecordStmtCount(S); 465 uint64_t ParentCount = CurrentCount; 466 467 BreakContinueStack.push_back(BreakContinue()); 468 // Visit the body region first so the break/continue adjustments can be 469 // included when visiting the condition. 470 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 471 CountMap[S->getBody()] = CurrentCount; 472 Visit(S->getBody()); 473 uint64_t BackedgeCount = CurrentCount; 474 475 // ...then go back and propagate counts through the condition. The count 476 // at the start of the condition is the sum of the incoming edges, 477 // the backedge from the end of the loop body, and the edges from 478 // continue statements. 479 BreakContinue BC = BreakContinueStack.pop_back_val(); 480 uint64_t CondCount = 481 setCount(ParentCount + BackedgeCount + BC.ContinueCount); 482 CountMap[S->getCond()] = CondCount; 483 Visit(S->getCond()); 484 setCount(BC.BreakCount + CondCount - BodyCount); 485 RecordNextStmtCount = true; 486 } 487 488 void VisitDoStmt(const DoStmt *S) { 489 RecordStmtCount(S); 490 uint64_t LoopCount = PGO.getRegionCount(S); 491 492 BreakContinueStack.push_back(BreakContinue()); 493 // The count doesn't include the fallthrough from the parent scope. Add it. 494 uint64_t BodyCount = setCount(LoopCount + CurrentCount); 495 CountMap[S->getBody()] = BodyCount; 496 Visit(S->getBody()); 497 uint64_t BackedgeCount = CurrentCount; 498 499 BreakContinue BC = BreakContinueStack.pop_back_val(); 500 // The count at the start of the condition is equal to the count at the 501 // end of the body, plus any continues. 502 uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount); 503 CountMap[S->getCond()] = CondCount; 504 Visit(S->getCond()); 505 setCount(BC.BreakCount + CondCount - LoopCount); 506 RecordNextStmtCount = true; 507 } 508 509 void VisitForStmt(const ForStmt *S) { 510 RecordStmtCount(S); 511 if (S->getInit()) 512 Visit(S->getInit()); 513 514 uint64_t ParentCount = CurrentCount; 515 516 BreakContinueStack.push_back(BreakContinue()); 517 // Visit the body region first. (This is basically the same as a while 518 // loop; see further comments in VisitWhileStmt.) 519 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 520 CountMap[S->getBody()] = BodyCount; 521 Visit(S->getBody()); 522 uint64_t BackedgeCount = CurrentCount; 523 BreakContinue BC = BreakContinueStack.pop_back_val(); 524 525 // The increment is essentially part of the body but it needs to include 526 // the count for all the continue statements. 527 if (S->getInc()) { 528 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount); 529 CountMap[S->getInc()] = IncCount; 530 Visit(S->getInc()); 531 } 532 533 // ...then go back and propagate counts through the condition. 534 uint64_t CondCount = 535 setCount(ParentCount + BackedgeCount + BC.ContinueCount); 536 if (S->getCond()) { 537 CountMap[S->getCond()] = CondCount; 538 Visit(S->getCond()); 539 } 540 setCount(BC.BreakCount + CondCount - BodyCount); 541 RecordNextStmtCount = true; 542 } 543 544 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) { 545 RecordStmtCount(S); 546 Visit(S->getLoopVarStmt()); 547 Visit(S->getRangeStmt()); 548 Visit(S->getBeginStmt()); 549 Visit(S->getEndStmt()); 550 551 uint64_t ParentCount = CurrentCount; 552 BreakContinueStack.push_back(BreakContinue()); 553 // Visit the body region first. (This is basically the same as a while 554 // loop; see further comments in VisitWhileStmt.) 555 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 556 CountMap[S->getBody()] = BodyCount; 557 Visit(S->getBody()); 558 uint64_t BackedgeCount = CurrentCount; 559 BreakContinue BC = BreakContinueStack.pop_back_val(); 560 561 // The increment is essentially part of the body but it needs to include 562 // the count for all the continue statements. 563 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount); 564 CountMap[S->getInc()] = IncCount; 565 Visit(S->getInc()); 566 567 // ...then go back and propagate counts through the condition. 568 uint64_t CondCount = 569 setCount(ParentCount + BackedgeCount + BC.ContinueCount); 570 CountMap[S->getCond()] = CondCount; 571 Visit(S->getCond()); 572 setCount(BC.BreakCount + CondCount - BodyCount); 573 RecordNextStmtCount = true; 574 } 575 576 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { 577 RecordStmtCount(S); 578 Visit(S->getElement()); 579 uint64_t ParentCount = CurrentCount; 580 BreakContinueStack.push_back(BreakContinue()); 581 // Counter tracks the body of the loop. 582 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 583 CountMap[S->getBody()] = BodyCount; 584 Visit(S->getBody()); 585 uint64_t BackedgeCount = CurrentCount; 586 BreakContinue BC = BreakContinueStack.pop_back_val(); 587 588 setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount - 589 BodyCount); 590 RecordNextStmtCount = true; 591 } 592 593 void VisitSwitchStmt(const SwitchStmt *S) { 594 RecordStmtCount(S); 595 if (S->getInit()) 596 Visit(S->getInit()); 597 Visit(S->getCond()); 598 CurrentCount = 0; 599 BreakContinueStack.push_back(BreakContinue()); 600 Visit(S->getBody()); 601 // If the switch is inside a loop, add the continue counts. 602 BreakContinue BC = BreakContinueStack.pop_back_val(); 603 if (!BreakContinueStack.empty()) 604 BreakContinueStack.back().ContinueCount += BC.ContinueCount; 605 // Counter tracks the exit block of the switch. 606 setCount(PGO.getRegionCount(S)); 607 RecordNextStmtCount = true; 608 } 609 610 void VisitSwitchCase(const SwitchCase *S) { 611 RecordNextStmtCount = false; 612 // Counter for this particular case. This counts only jumps from the 613 // switch header and does not include fallthrough from the case before 614 // this one. 615 uint64_t CaseCount = PGO.getRegionCount(S); 616 setCount(CurrentCount + CaseCount); 617 // We need the count without fallthrough in the mapping, so it's more useful 618 // for branch probabilities. 619 CountMap[S] = CaseCount; 620 RecordNextStmtCount = true; 621 Visit(S->getSubStmt()); 622 } 623 624 void VisitIfStmt(const IfStmt *S) { 625 RecordStmtCount(S); 626 uint64_t ParentCount = CurrentCount; 627 if (S->getInit()) 628 Visit(S->getInit()); 629 Visit(S->getCond()); 630 631 // Counter tracks the "then" part of an if statement. The count for 632 // the "else" part, if it exists, will be calculated from this counter. 633 uint64_t ThenCount = setCount(PGO.getRegionCount(S)); 634 CountMap[S->getThen()] = ThenCount; 635 Visit(S->getThen()); 636 uint64_t OutCount = CurrentCount; 637 638 uint64_t ElseCount = ParentCount - ThenCount; 639 if (S->getElse()) { 640 setCount(ElseCount); 641 CountMap[S->getElse()] = ElseCount; 642 Visit(S->getElse()); 643 OutCount += CurrentCount; 644 } else 645 OutCount += ElseCount; 646 setCount(OutCount); 647 RecordNextStmtCount = true; 648 } 649 650 void VisitCXXTryStmt(const CXXTryStmt *S) { 651 RecordStmtCount(S); 652 Visit(S->getTryBlock()); 653 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I) 654 Visit(S->getHandler(I)); 655 // Counter tracks the continuation block of the try statement. 656 setCount(PGO.getRegionCount(S)); 657 RecordNextStmtCount = true; 658 } 659 660 void VisitCXXCatchStmt(const CXXCatchStmt *S) { 661 RecordNextStmtCount = false; 662 // Counter tracks the catch statement's handler block. 663 uint64_t CatchCount = setCount(PGO.getRegionCount(S)); 664 CountMap[S] = CatchCount; 665 Visit(S->getHandlerBlock()); 666 } 667 668 void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) { 669 RecordStmtCount(E); 670 uint64_t ParentCount = CurrentCount; 671 Visit(E->getCond()); 672 673 // Counter tracks the "true" part of a conditional operator. The 674 // count in the "false" part will be calculated from this counter. 675 uint64_t TrueCount = setCount(PGO.getRegionCount(E)); 676 CountMap[E->getTrueExpr()] = TrueCount; 677 Visit(E->getTrueExpr()); 678 uint64_t OutCount = CurrentCount; 679 680 uint64_t FalseCount = setCount(ParentCount - TrueCount); 681 CountMap[E->getFalseExpr()] = FalseCount; 682 Visit(E->getFalseExpr()); 683 OutCount += CurrentCount; 684 685 setCount(OutCount); 686 RecordNextStmtCount = true; 687 } 688 689 void VisitBinLAnd(const BinaryOperator *E) { 690 RecordStmtCount(E); 691 uint64_t ParentCount = CurrentCount; 692 Visit(E->getLHS()); 693 // Counter tracks the right hand side of a logical and operator. 694 uint64_t RHSCount = setCount(PGO.getRegionCount(E)); 695 CountMap[E->getRHS()] = RHSCount; 696 Visit(E->getRHS()); 697 setCount(ParentCount + RHSCount - CurrentCount); 698 RecordNextStmtCount = true; 699 } 700 701 void VisitBinLOr(const BinaryOperator *E) { 702 RecordStmtCount(E); 703 uint64_t ParentCount = CurrentCount; 704 Visit(E->getLHS()); 705 // Counter tracks the right hand side of a logical or operator. 706 uint64_t RHSCount = setCount(PGO.getRegionCount(E)); 707 CountMap[E->getRHS()] = RHSCount; 708 Visit(E->getRHS()); 709 setCount(ParentCount + RHSCount - CurrentCount); 710 RecordNextStmtCount = true; 711 } 712 }; 713 } // end anonymous namespace 714 715 void PGOHash::combine(HashType Type) { 716 // Check that we never combine 0 and only have six bits. 717 assert(Type && "Hash is invalid: unexpected type 0"); 718 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types"); 719 720 // Pass through MD5 if enough work has built up. 721 if (Count && Count % NumTypesPerWord == 0) { 722 using namespace llvm::support; 723 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working); 724 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped))); 725 Working = 0; 726 } 727 728 // Accumulate the current type. 729 ++Count; 730 Working = Working << NumBitsPerType | Type; 731 } 732 733 uint64_t PGOHash::finalize() { 734 // Use Working as the hash directly if we never used MD5. 735 if (Count <= NumTypesPerWord) 736 // No need to byte swap here, since none of the math was endian-dependent. 737 // This number will be byte-swapped as required on endianness transitions, 738 // so we will see the same value on the other side. 739 return Working; 740 741 // Check for remaining work in Working. 742 if (Working) 743 MD5.update(Working); 744 745 // Finalize the MD5 and return the hash. 746 llvm::MD5::MD5Result Result; 747 MD5.final(Result); 748 using namespace llvm::support; 749 return Result.low(); 750 } 751 752 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) { 753 const Decl *D = GD.getDecl(); 754 if (!D->hasBody()) 755 return; 756 757 bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr(); 758 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 759 if (!InstrumentRegions && !PGOReader) 760 return; 761 if (D->isImplicit()) 762 return; 763 // Constructors and destructors may be represented by several functions in IR. 764 // If so, instrument only base variant, others are implemented by delegation 765 // to the base one, it would be counted twice otherwise. 766 if (CGM.getTarget().getCXXABI().hasConstructorVariants()) { 767 if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base) 768 return; 769 770 if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D)) 771 if (GD.getCtorType() != Ctor_Base && 772 CodeGenFunction::IsConstructorDelegationValid(CCD)) 773 return; 774 } 775 CGM.ClearUnusedCoverageMapping(D); 776 setFuncName(Fn); 777 778 mapRegionCounters(D); 779 if (CGM.getCodeGenOpts().CoverageMapping) 780 emitCounterRegionMapping(D); 781 if (PGOReader) { 782 SourceManager &SM = CGM.getContext().getSourceManager(); 783 loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation())); 784 computeRegionCounts(D); 785 applyFunctionAttributes(PGOReader, Fn); 786 } 787 } 788 789 void CodeGenPGO::mapRegionCounters(const Decl *D) { 790 // Use the latest hash version when inserting instrumentation, but use the 791 // version in the indexed profile if we're reading PGO data. 792 PGOHashVersion HashVersion = PGO_HASH_LATEST; 793 if (auto *PGOReader = CGM.getPGOReader()) 794 HashVersion = getPGOHashVersion(PGOReader, CGM); 795 796 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>); 797 MapRegionCounters Walker(HashVersion, *RegionCounterMap); 798 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 799 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD)); 800 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D)) 801 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD)); 802 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D)) 803 Walker.TraverseDecl(const_cast<BlockDecl *>(BD)); 804 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D)) 805 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD)); 806 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl"); 807 NumRegionCounters = Walker.NextCounter; 808 FunctionHash = Walker.Hash.finalize(); 809 } 810 811 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) { 812 if (!D->getBody()) 813 return true; 814 815 // Don't map the functions in system headers. 816 const auto &SM = CGM.getContext().getSourceManager(); 817 auto Loc = D->getBody()->getLocStart(); 818 return SM.isInSystemHeader(Loc); 819 } 820 821 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) { 822 if (skipRegionMappingForDecl(D)) 823 return; 824 825 std::string CoverageMapping; 826 llvm::raw_string_ostream OS(CoverageMapping); 827 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(), 828 CGM.getContext().getSourceManager(), 829 CGM.getLangOpts(), RegionCounterMap.get()); 830 MappingGen.emitCounterMapping(D, OS); 831 OS.flush(); 832 833 if (CoverageMapping.empty()) 834 return; 835 836 CGM.getCoverageMapping()->addFunctionMappingRecord( 837 FuncNameVar, FuncName, FunctionHash, CoverageMapping); 838 } 839 840 void 841 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name, 842 llvm::GlobalValue::LinkageTypes Linkage) { 843 if (skipRegionMappingForDecl(D)) 844 return; 845 846 std::string CoverageMapping; 847 llvm::raw_string_ostream OS(CoverageMapping); 848 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(), 849 CGM.getContext().getSourceManager(), 850 CGM.getLangOpts()); 851 MappingGen.emitEmptyMapping(D, OS); 852 OS.flush(); 853 854 if (CoverageMapping.empty()) 855 return; 856 857 setFuncName(Name, Linkage); 858 CGM.getCoverageMapping()->addFunctionMappingRecord( 859 FuncNameVar, FuncName, FunctionHash, CoverageMapping, false); 860 } 861 862 void CodeGenPGO::computeRegionCounts(const Decl *D) { 863 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>); 864 ComputeRegionCounts Walker(*StmtCountMap, *this); 865 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 866 Walker.VisitFunctionDecl(FD); 867 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D)) 868 Walker.VisitObjCMethodDecl(MD); 869 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D)) 870 Walker.VisitBlockDecl(BD); 871 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D)) 872 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD)); 873 } 874 875 void 876 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader, 877 llvm::Function *Fn) { 878 if (!haveRegionCounts()) 879 return; 880 881 uint64_t FunctionCount = getRegionCount(nullptr); 882 Fn->setEntryCount(FunctionCount); 883 } 884 885 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S, 886 llvm::Value *StepV) { 887 if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap) 888 return; 889 if (!Builder.GetInsertBlock()) 890 return; 891 892 unsigned Counter = (*RegionCounterMap)[S]; 893 auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext()); 894 895 llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy), 896 Builder.getInt64(FunctionHash), 897 Builder.getInt32(NumRegionCounters), 898 Builder.getInt32(Counter), StepV}; 899 if (!StepV) 900 Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment), 901 makeArrayRef(Args, 4)); 902 else 903 Builder.CreateCall( 904 CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step), 905 makeArrayRef(Args)); 906 } 907 908 // This method either inserts a call to the profile run-time during 909 // instrumentation or puts profile data into metadata for PGO use. 910 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind, 911 llvm::Instruction *ValueSite, llvm::Value *ValuePtr) { 912 913 if (!EnableValueProfiling) 914 return; 915 916 if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock()) 917 return; 918 919 if (isa<llvm::Constant>(ValuePtr)) 920 return; 921 922 bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr(); 923 if (InstrumentValueSites && RegionCounterMap) { 924 auto BuilderInsertPoint = Builder.saveIP(); 925 Builder.SetInsertPoint(ValueSite); 926 llvm::Value *Args[5] = { 927 llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()), 928 Builder.getInt64(FunctionHash), 929 Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()), 930 Builder.getInt32(ValueKind), 931 Builder.getInt32(NumValueSites[ValueKind]++) 932 }; 933 Builder.CreateCall( 934 CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args); 935 Builder.restoreIP(BuilderInsertPoint); 936 return; 937 } 938 939 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 940 if (PGOReader && haveRegionCounts()) { 941 // We record the top most called three functions at each call site. 942 // Profile metadata contains "VP" string identifying this metadata 943 // as value profiling data, then a uint32_t value for the value profiling 944 // kind, a uint64_t value for the total number of times the call is 945 // executed, followed by the function hash and execution count (uint64_t) 946 // pairs for each function. 947 if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind)) 948 return; 949 950 llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord, 951 (llvm::InstrProfValueKind)ValueKind, 952 NumValueSites[ValueKind]); 953 954 NumValueSites[ValueKind]++; 955 } 956 } 957 958 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader, 959 bool IsInMainFile) { 960 CGM.getPGOStats().addVisited(IsInMainFile); 961 RegionCounts.clear(); 962 llvm::Expected<llvm::InstrProfRecord> RecordExpected = 963 PGOReader->getInstrProfRecord(FuncName, FunctionHash); 964 if (auto E = RecordExpected.takeError()) { 965 auto IPE = llvm::InstrProfError::take(std::move(E)); 966 if (IPE == llvm::instrprof_error::unknown_function) 967 CGM.getPGOStats().addMissing(IsInMainFile); 968 else if (IPE == llvm::instrprof_error::hash_mismatch) 969 CGM.getPGOStats().addMismatched(IsInMainFile); 970 else if (IPE == llvm::instrprof_error::malformed) 971 // TODO: Consider a more specific warning for this case. 972 CGM.getPGOStats().addMismatched(IsInMainFile); 973 return; 974 } 975 ProfRecord = 976 llvm::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get())); 977 RegionCounts = ProfRecord->Counts; 978 } 979 980 /// \brief Calculate what to divide by to scale weights. 981 /// 982 /// Given the maximum weight, calculate a divisor that will scale all the 983 /// weights to strictly less than UINT32_MAX. 984 static uint64_t calculateWeightScale(uint64_t MaxWeight) { 985 return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1; 986 } 987 988 /// \brief Scale an individual branch weight (and add 1). 989 /// 990 /// Scale a 64-bit weight down to 32-bits using \c Scale. 991 /// 992 /// According to Laplace's Rule of Succession, it is better to compute the 993 /// weight based on the count plus 1, so universally add 1 to the value. 994 /// 995 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no 996 /// greater than \c Weight. 997 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) { 998 assert(Scale && "scale by 0?"); 999 uint64_t Scaled = Weight / Scale + 1; 1000 assert(Scaled <= UINT32_MAX && "overflow 32-bits"); 1001 return Scaled; 1002 } 1003 1004 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount, 1005 uint64_t FalseCount) { 1006 // Check for empty weights. 1007 if (!TrueCount && !FalseCount) 1008 return nullptr; 1009 1010 // Calculate how to scale down to 32-bits. 1011 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount)); 1012 1013 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 1014 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale), 1015 scaleBranchWeight(FalseCount, Scale)); 1016 } 1017 1018 llvm::MDNode * 1019 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) { 1020 // We need at least two elements to create meaningful weights. 1021 if (Weights.size() < 2) 1022 return nullptr; 1023 1024 // Check for empty weights. 1025 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end()); 1026 if (MaxWeight == 0) 1027 return nullptr; 1028 1029 // Calculate how to scale down to 32-bits. 1030 uint64_t Scale = calculateWeightScale(MaxWeight); 1031 1032 SmallVector<uint32_t, 16> ScaledWeights; 1033 ScaledWeights.reserve(Weights.size()); 1034 for (uint64_t W : Weights) 1035 ScaledWeights.push_back(scaleBranchWeight(W, Scale)); 1036 1037 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 1038 return MDHelper.createBranchWeights(ScaledWeights); 1039 } 1040 1041 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond, 1042 uint64_t LoopCount) { 1043 if (!PGO.haveRegionCounts()) 1044 return nullptr; 1045 Optional<uint64_t> CondCount = PGO.getStmtCount(Cond); 1046 assert(CondCount.hasValue() && "missing expected loop condition count"); 1047 if (*CondCount == 0) 1048 return nullptr; 1049 return createProfileWeights(LoopCount, 1050 std::max(*CondCount, LoopCount) - LoopCount); 1051 } 1052