1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // Instrumentation-based profile-guided optimization 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "CodeGenPGO.h" 15 #include "CodeGenFunction.h" 16 #include "CoverageMappingGen.h" 17 #include "clang/AST/RecursiveASTVisitor.h" 18 #include "clang/AST/StmtVisitor.h" 19 #include "llvm/IR/Intrinsics.h" 20 #include "llvm/IR/MDBuilder.h" 21 #include "llvm/Support/Endian.h" 22 #include "llvm/Support/FileSystem.h" 23 #include "llvm/Support/MD5.h" 24 25 static llvm::cl::opt<bool> 26 EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore, 27 llvm::cl::desc("Enable value profiling"), 28 llvm::cl::Hidden, llvm::cl::init(false)); 29 30 using namespace clang; 31 using namespace CodeGen; 32 33 void CodeGenPGO::setFuncName(StringRef Name, 34 llvm::GlobalValue::LinkageTypes Linkage) { 35 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 36 FuncName = llvm::getPGOFuncName( 37 Name, Linkage, CGM.getCodeGenOpts().MainFileName, 38 PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version); 39 40 // If we're generating a profile, create a variable for the name. 41 if (CGM.getCodeGenOpts().hasProfileClangInstr()) 42 FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName); 43 } 44 45 void CodeGenPGO::setFuncName(llvm::Function *Fn) { 46 setFuncName(Fn->getName(), Fn->getLinkage()); 47 // Create PGOFuncName meta data. 48 llvm::createPGOFuncNameMetadata(*Fn, FuncName); 49 } 50 51 /// The version of the PGO hash algorithm. 52 enum PGOHashVersion : unsigned { 53 PGO_HASH_V1, 54 PGO_HASH_V2, 55 56 // Keep this set to the latest hash version. 57 PGO_HASH_LATEST = PGO_HASH_V2 58 }; 59 60 namespace { 61 /// Stable hasher for PGO region counters. 62 /// 63 /// PGOHash produces a stable hash of a given function's control flow. 64 /// 65 /// Changing the output of this hash will invalidate all previously generated 66 /// profiles -- i.e., don't do it. 67 /// 68 /// \note When this hash does eventually change (years?), we still need to 69 /// support old hashes. We'll need to pull in the version number from the 70 /// profile data format and use the matching hash function. 71 class PGOHash { 72 uint64_t Working; 73 unsigned Count; 74 PGOHashVersion HashVersion; 75 llvm::MD5 MD5; 76 77 static const int NumBitsPerType = 6; 78 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType; 79 static const unsigned TooBig = 1u << NumBitsPerType; 80 81 public: 82 /// Hash values for AST nodes. 83 /// 84 /// Distinct values for AST nodes that have region counters attached. 85 /// 86 /// These values must be stable. All new members must be added at the end, 87 /// and no members should be removed. Changing the enumeration value for an 88 /// AST node will affect the hash of every function that contains that node. 89 enum HashType : unsigned char { 90 None = 0, 91 LabelStmt = 1, 92 WhileStmt, 93 DoStmt, 94 ForStmt, 95 CXXForRangeStmt, 96 ObjCForCollectionStmt, 97 SwitchStmt, 98 CaseStmt, 99 DefaultStmt, 100 IfStmt, 101 CXXTryStmt, 102 CXXCatchStmt, 103 ConditionalOperator, 104 BinaryOperatorLAnd, 105 BinaryOperatorLOr, 106 BinaryConditionalOperator, 107 // The preceding values are available with PGO_HASH_V1. 108 109 EndOfScope, 110 IfThenBranch, 111 IfElseBranch, 112 GotoStmt, 113 IndirectGotoStmt, 114 BreakStmt, 115 ContinueStmt, 116 ReturnStmt, 117 ThrowExpr, 118 UnaryOperatorLNot, 119 BinaryOperatorLT, 120 BinaryOperatorGT, 121 BinaryOperatorLE, 122 BinaryOperatorGE, 123 BinaryOperatorEQ, 124 BinaryOperatorNE, 125 // The preceding values are available with PGO_HASH_V2. 126 127 // Keep this last. It's for the static assert that follows. 128 LastHashType 129 }; 130 static_assert(LastHashType <= TooBig, "Too many types in HashType"); 131 132 PGOHash(PGOHashVersion HashVersion) 133 : Working(0), Count(0), HashVersion(HashVersion), MD5() {} 134 void combine(HashType Type); 135 uint64_t finalize(); 136 PGOHashVersion getHashVersion() const { return HashVersion; } 137 }; 138 const int PGOHash::NumBitsPerType; 139 const unsigned PGOHash::NumTypesPerWord; 140 const unsigned PGOHash::TooBig; 141 142 /// Get the PGO hash version used in the given indexed profile. 143 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader, 144 CodeGenModule &CGM) { 145 if (PGOReader->getVersion() <= 4) 146 return PGO_HASH_V1; 147 return PGO_HASH_V2; 148 } 149 150 /// A RecursiveASTVisitor that fills a map of statements to PGO counters. 151 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> { 152 using Base = RecursiveASTVisitor<MapRegionCounters>; 153 154 /// The next counter value to assign. 155 unsigned NextCounter; 156 /// The function hash. 157 PGOHash Hash; 158 /// The map of statements to counters. 159 llvm::DenseMap<const Stmt *, unsigned> &CounterMap; 160 161 MapRegionCounters(PGOHashVersion HashVersion, 162 llvm::DenseMap<const Stmt *, unsigned> &CounterMap) 163 : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap) {} 164 165 // Blocks and lambdas are handled as separate functions, so we need not 166 // traverse them in the parent context. 167 bool TraverseBlockExpr(BlockExpr *BE) { return true; } 168 bool TraverseLambdaBody(LambdaExpr *LE) { return true; } 169 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; } 170 171 bool VisitDecl(const Decl *D) { 172 switch (D->getKind()) { 173 default: 174 break; 175 case Decl::Function: 176 case Decl::CXXMethod: 177 case Decl::CXXConstructor: 178 case Decl::CXXDestructor: 179 case Decl::CXXConversion: 180 case Decl::ObjCMethod: 181 case Decl::Block: 182 case Decl::Captured: 183 CounterMap[D->getBody()] = NextCounter++; 184 break; 185 } 186 return true; 187 } 188 189 /// If \p S gets a fresh counter, update the counter mappings. Return the 190 /// V1 hash of \p S. 191 PGOHash::HashType updateCounterMappings(Stmt *S) { 192 auto Type = getHashType(PGO_HASH_V1, S); 193 if (Type != PGOHash::None) 194 CounterMap[S] = NextCounter++; 195 return Type; 196 } 197 198 /// Include \p S in the function hash. 199 bool VisitStmt(Stmt *S) { 200 auto Type = updateCounterMappings(S); 201 if (Hash.getHashVersion() != PGO_HASH_V1) 202 Type = getHashType(Hash.getHashVersion(), S); 203 if (Type != PGOHash::None) 204 Hash.combine(Type); 205 return true; 206 } 207 208 bool TraverseIfStmt(IfStmt *If) { 209 // If we used the V1 hash, use the default traversal. 210 if (Hash.getHashVersion() == PGO_HASH_V1) 211 return Base::TraverseIfStmt(If); 212 213 // Otherwise, keep track of which branch we're in while traversing. 214 VisitStmt(If); 215 for (Stmt *CS : If->children()) { 216 if (!CS) 217 continue; 218 if (CS == If->getThen()) 219 Hash.combine(PGOHash::IfThenBranch); 220 else if (CS == If->getElse()) 221 Hash.combine(PGOHash::IfElseBranch); 222 TraverseStmt(CS); 223 } 224 Hash.combine(PGOHash::EndOfScope); 225 return true; 226 } 227 228 // If the statement type \p N is nestable, and its nesting impacts profile 229 // stability, define a custom traversal which tracks the end of the statement 230 // in the hash (provided we're not using the V1 hash). 231 #define DEFINE_NESTABLE_TRAVERSAL(N) \ 232 bool Traverse##N(N *S) { \ 233 Base::Traverse##N(S); \ 234 if (Hash.getHashVersion() != PGO_HASH_V1) \ 235 Hash.combine(PGOHash::EndOfScope); \ 236 return true; \ 237 } 238 239 DEFINE_NESTABLE_TRAVERSAL(WhileStmt) 240 DEFINE_NESTABLE_TRAVERSAL(DoStmt) 241 DEFINE_NESTABLE_TRAVERSAL(ForStmt) 242 DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt) 243 DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt) 244 DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt) 245 DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt) 246 247 /// Get version \p HashVersion of the PGO hash for \p S. 248 PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) { 249 switch (S->getStmtClass()) { 250 default: 251 break; 252 case Stmt::LabelStmtClass: 253 return PGOHash::LabelStmt; 254 case Stmt::WhileStmtClass: 255 return PGOHash::WhileStmt; 256 case Stmt::DoStmtClass: 257 return PGOHash::DoStmt; 258 case Stmt::ForStmtClass: 259 return PGOHash::ForStmt; 260 case Stmt::CXXForRangeStmtClass: 261 return PGOHash::CXXForRangeStmt; 262 case Stmt::ObjCForCollectionStmtClass: 263 return PGOHash::ObjCForCollectionStmt; 264 case Stmt::SwitchStmtClass: 265 return PGOHash::SwitchStmt; 266 case Stmt::CaseStmtClass: 267 return PGOHash::CaseStmt; 268 case Stmt::DefaultStmtClass: 269 return PGOHash::DefaultStmt; 270 case Stmt::IfStmtClass: 271 return PGOHash::IfStmt; 272 case Stmt::CXXTryStmtClass: 273 return PGOHash::CXXTryStmt; 274 case Stmt::CXXCatchStmtClass: 275 return PGOHash::CXXCatchStmt; 276 case Stmt::ConditionalOperatorClass: 277 return PGOHash::ConditionalOperator; 278 case Stmt::BinaryConditionalOperatorClass: 279 return PGOHash::BinaryConditionalOperator; 280 case Stmt::BinaryOperatorClass: { 281 const BinaryOperator *BO = cast<BinaryOperator>(S); 282 if (BO->getOpcode() == BO_LAnd) 283 return PGOHash::BinaryOperatorLAnd; 284 if (BO->getOpcode() == BO_LOr) 285 return PGOHash::BinaryOperatorLOr; 286 if (HashVersion == PGO_HASH_V2) { 287 switch (BO->getOpcode()) { 288 default: 289 break; 290 case BO_LT: 291 return PGOHash::BinaryOperatorLT; 292 case BO_GT: 293 return PGOHash::BinaryOperatorGT; 294 case BO_LE: 295 return PGOHash::BinaryOperatorLE; 296 case BO_GE: 297 return PGOHash::BinaryOperatorGE; 298 case BO_EQ: 299 return PGOHash::BinaryOperatorEQ; 300 case BO_NE: 301 return PGOHash::BinaryOperatorNE; 302 } 303 } 304 break; 305 } 306 } 307 308 if (HashVersion == PGO_HASH_V2) { 309 switch (S->getStmtClass()) { 310 default: 311 break; 312 case Stmt::GotoStmtClass: 313 return PGOHash::GotoStmt; 314 case Stmt::IndirectGotoStmtClass: 315 return PGOHash::IndirectGotoStmt; 316 case Stmt::BreakStmtClass: 317 return PGOHash::BreakStmt; 318 case Stmt::ContinueStmtClass: 319 return PGOHash::ContinueStmt; 320 case Stmt::ReturnStmtClass: 321 return PGOHash::ReturnStmt; 322 case Stmt::CXXThrowExprClass: 323 return PGOHash::ThrowExpr; 324 case Stmt::UnaryOperatorClass: { 325 const UnaryOperator *UO = cast<UnaryOperator>(S); 326 if (UO->getOpcode() == UO_LNot) 327 return PGOHash::UnaryOperatorLNot; 328 break; 329 } 330 } 331 } 332 333 return PGOHash::None; 334 } 335 }; 336 337 /// A StmtVisitor that propagates the raw counts through the AST and 338 /// records the count at statements where the value may change. 339 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { 340 /// PGO state. 341 CodeGenPGO &PGO; 342 343 /// A flag that is set when the current count should be recorded on the 344 /// next statement, such as at the exit of a loop. 345 bool RecordNextStmtCount; 346 347 /// The count at the current location in the traversal. 348 uint64_t CurrentCount; 349 350 /// The map of statements to count values. 351 llvm::DenseMap<const Stmt *, uint64_t> &CountMap; 352 353 /// BreakContinueStack - Keep counts of breaks and continues inside loops. 354 struct BreakContinue { 355 uint64_t BreakCount; 356 uint64_t ContinueCount; 357 BreakContinue() : BreakCount(0), ContinueCount(0) {} 358 }; 359 SmallVector<BreakContinue, 8> BreakContinueStack; 360 361 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap, 362 CodeGenPGO &PGO) 363 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {} 364 365 void RecordStmtCount(const Stmt *S) { 366 if (RecordNextStmtCount) { 367 CountMap[S] = CurrentCount; 368 RecordNextStmtCount = false; 369 } 370 } 371 372 /// Set and return the current count. 373 uint64_t setCount(uint64_t Count) { 374 CurrentCount = Count; 375 return Count; 376 } 377 378 void VisitStmt(const Stmt *S) { 379 RecordStmtCount(S); 380 for (const Stmt *Child : S->children()) 381 if (Child) 382 this->Visit(Child); 383 } 384 385 void VisitFunctionDecl(const FunctionDecl *D) { 386 // Counter tracks entry to the function body. 387 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 388 CountMap[D->getBody()] = BodyCount; 389 Visit(D->getBody()); 390 } 391 392 // Skip lambda expressions. We visit these as FunctionDecls when we're 393 // generating them and aren't interested in the body when generating a 394 // parent context. 395 void VisitLambdaExpr(const LambdaExpr *LE) {} 396 397 void VisitCapturedDecl(const CapturedDecl *D) { 398 // Counter tracks entry to the capture body. 399 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 400 CountMap[D->getBody()] = BodyCount; 401 Visit(D->getBody()); 402 } 403 404 void VisitObjCMethodDecl(const ObjCMethodDecl *D) { 405 // Counter tracks entry to the method body. 406 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 407 CountMap[D->getBody()] = BodyCount; 408 Visit(D->getBody()); 409 } 410 411 void VisitBlockDecl(const BlockDecl *D) { 412 // Counter tracks entry to the block body. 413 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 414 CountMap[D->getBody()] = BodyCount; 415 Visit(D->getBody()); 416 } 417 418 void VisitReturnStmt(const ReturnStmt *S) { 419 RecordStmtCount(S); 420 if (S->getRetValue()) 421 Visit(S->getRetValue()); 422 CurrentCount = 0; 423 RecordNextStmtCount = true; 424 } 425 426 void VisitCXXThrowExpr(const CXXThrowExpr *E) { 427 RecordStmtCount(E); 428 if (E->getSubExpr()) 429 Visit(E->getSubExpr()); 430 CurrentCount = 0; 431 RecordNextStmtCount = true; 432 } 433 434 void VisitGotoStmt(const GotoStmt *S) { 435 RecordStmtCount(S); 436 CurrentCount = 0; 437 RecordNextStmtCount = true; 438 } 439 440 void VisitLabelStmt(const LabelStmt *S) { 441 RecordNextStmtCount = false; 442 // Counter tracks the block following the label. 443 uint64_t BlockCount = setCount(PGO.getRegionCount(S)); 444 CountMap[S] = BlockCount; 445 Visit(S->getSubStmt()); 446 } 447 448 void VisitBreakStmt(const BreakStmt *S) { 449 RecordStmtCount(S); 450 assert(!BreakContinueStack.empty() && "break not in a loop or switch!"); 451 BreakContinueStack.back().BreakCount += CurrentCount; 452 CurrentCount = 0; 453 RecordNextStmtCount = true; 454 } 455 456 void VisitContinueStmt(const ContinueStmt *S) { 457 RecordStmtCount(S); 458 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!"); 459 BreakContinueStack.back().ContinueCount += CurrentCount; 460 CurrentCount = 0; 461 RecordNextStmtCount = true; 462 } 463 464 void VisitWhileStmt(const WhileStmt *S) { 465 RecordStmtCount(S); 466 uint64_t ParentCount = CurrentCount; 467 468 BreakContinueStack.push_back(BreakContinue()); 469 // Visit the body region first so the break/continue adjustments can be 470 // included when visiting the condition. 471 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 472 CountMap[S->getBody()] = CurrentCount; 473 Visit(S->getBody()); 474 uint64_t BackedgeCount = CurrentCount; 475 476 // ...then go back and propagate counts through the condition. The count 477 // at the start of the condition is the sum of the incoming edges, 478 // the backedge from the end of the loop body, and the edges from 479 // continue statements. 480 BreakContinue BC = BreakContinueStack.pop_back_val(); 481 uint64_t CondCount = 482 setCount(ParentCount + BackedgeCount + BC.ContinueCount); 483 CountMap[S->getCond()] = CondCount; 484 Visit(S->getCond()); 485 setCount(BC.BreakCount + CondCount - BodyCount); 486 RecordNextStmtCount = true; 487 } 488 489 void VisitDoStmt(const DoStmt *S) { 490 RecordStmtCount(S); 491 uint64_t LoopCount = PGO.getRegionCount(S); 492 493 BreakContinueStack.push_back(BreakContinue()); 494 // The count doesn't include the fallthrough from the parent scope. Add it. 495 uint64_t BodyCount = setCount(LoopCount + CurrentCount); 496 CountMap[S->getBody()] = BodyCount; 497 Visit(S->getBody()); 498 uint64_t BackedgeCount = CurrentCount; 499 500 BreakContinue BC = BreakContinueStack.pop_back_val(); 501 // The count at the start of the condition is equal to the count at the 502 // end of the body, plus any continues. 503 uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount); 504 CountMap[S->getCond()] = CondCount; 505 Visit(S->getCond()); 506 setCount(BC.BreakCount + CondCount - LoopCount); 507 RecordNextStmtCount = true; 508 } 509 510 void VisitForStmt(const ForStmt *S) { 511 RecordStmtCount(S); 512 if (S->getInit()) 513 Visit(S->getInit()); 514 515 uint64_t ParentCount = CurrentCount; 516 517 BreakContinueStack.push_back(BreakContinue()); 518 // Visit the body region first. (This is basically the same as a while 519 // loop; see further comments in VisitWhileStmt.) 520 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 521 CountMap[S->getBody()] = BodyCount; 522 Visit(S->getBody()); 523 uint64_t BackedgeCount = CurrentCount; 524 BreakContinue BC = BreakContinueStack.pop_back_val(); 525 526 // The increment is essentially part of the body but it needs to include 527 // the count for all the continue statements. 528 if (S->getInc()) { 529 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount); 530 CountMap[S->getInc()] = IncCount; 531 Visit(S->getInc()); 532 } 533 534 // ...then go back and propagate counts through the condition. 535 uint64_t CondCount = 536 setCount(ParentCount + BackedgeCount + BC.ContinueCount); 537 if (S->getCond()) { 538 CountMap[S->getCond()] = CondCount; 539 Visit(S->getCond()); 540 } 541 setCount(BC.BreakCount + CondCount - BodyCount); 542 RecordNextStmtCount = true; 543 } 544 545 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) { 546 RecordStmtCount(S); 547 if (S->getInit()) 548 Visit(S->getInit()); 549 Visit(S->getLoopVarStmt()); 550 Visit(S->getRangeStmt()); 551 Visit(S->getBeginStmt()); 552 Visit(S->getEndStmt()); 553 554 uint64_t ParentCount = CurrentCount; 555 BreakContinueStack.push_back(BreakContinue()); 556 // Visit the body region first. (This is basically the same as a while 557 // loop; see further comments in VisitWhileStmt.) 558 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 559 CountMap[S->getBody()] = BodyCount; 560 Visit(S->getBody()); 561 uint64_t BackedgeCount = CurrentCount; 562 BreakContinue BC = BreakContinueStack.pop_back_val(); 563 564 // The increment is essentially part of the body but it needs to include 565 // the count for all the continue statements. 566 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount); 567 CountMap[S->getInc()] = IncCount; 568 Visit(S->getInc()); 569 570 // ...then go back and propagate counts through the condition. 571 uint64_t CondCount = 572 setCount(ParentCount + BackedgeCount + BC.ContinueCount); 573 CountMap[S->getCond()] = CondCount; 574 Visit(S->getCond()); 575 setCount(BC.BreakCount + CondCount - BodyCount); 576 RecordNextStmtCount = true; 577 } 578 579 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { 580 RecordStmtCount(S); 581 Visit(S->getElement()); 582 uint64_t ParentCount = CurrentCount; 583 BreakContinueStack.push_back(BreakContinue()); 584 // Counter tracks the body of the loop. 585 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 586 CountMap[S->getBody()] = BodyCount; 587 Visit(S->getBody()); 588 uint64_t BackedgeCount = CurrentCount; 589 BreakContinue BC = BreakContinueStack.pop_back_val(); 590 591 setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount - 592 BodyCount); 593 RecordNextStmtCount = true; 594 } 595 596 void VisitSwitchStmt(const SwitchStmt *S) { 597 RecordStmtCount(S); 598 if (S->getInit()) 599 Visit(S->getInit()); 600 Visit(S->getCond()); 601 CurrentCount = 0; 602 BreakContinueStack.push_back(BreakContinue()); 603 Visit(S->getBody()); 604 // If the switch is inside a loop, add the continue counts. 605 BreakContinue BC = BreakContinueStack.pop_back_val(); 606 if (!BreakContinueStack.empty()) 607 BreakContinueStack.back().ContinueCount += BC.ContinueCount; 608 // Counter tracks the exit block of the switch. 609 setCount(PGO.getRegionCount(S)); 610 RecordNextStmtCount = true; 611 } 612 613 void VisitSwitchCase(const SwitchCase *S) { 614 RecordNextStmtCount = false; 615 // Counter for this particular case. This counts only jumps from the 616 // switch header and does not include fallthrough from the case before 617 // this one. 618 uint64_t CaseCount = PGO.getRegionCount(S); 619 setCount(CurrentCount + CaseCount); 620 // We need the count without fallthrough in the mapping, so it's more useful 621 // for branch probabilities. 622 CountMap[S] = CaseCount; 623 RecordNextStmtCount = true; 624 Visit(S->getSubStmt()); 625 } 626 627 void VisitIfStmt(const IfStmt *S) { 628 RecordStmtCount(S); 629 uint64_t ParentCount = CurrentCount; 630 if (S->getInit()) 631 Visit(S->getInit()); 632 Visit(S->getCond()); 633 634 // Counter tracks the "then" part of an if statement. The count for 635 // the "else" part, if it exists, will be calculated from this counter. 636 uint64_t ThenCount = setCount(PGO.getRegionCount(S)); 637 CountMap[S->getThen()] = ThenCount; 638 Visit(S->getThen()); 639 uint64_t OutCount = CurrentCount; 640 641 uint64_t ElseCount = ParentCount - ThenCount; 642 if (S->getElse()) { 643 setCount(ElseCount); 644 CountMap[S->getElse()] = ElseCount; 645 Visit(S->getElse()); 646 OutCount += CurrentCount; 647 } else 648 OutCount += ElseCount; 649 setCount(OutCount); 650 RecordNextStmtCount = true; 651 } 652 653 void VisitCXXTryStmt(const CXXTryStmt *S) { 654 RecordStmtCount(S); 655 Visit(S->getTryBlock()); 656 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I) 657 Visit(S->getHandler(I)); 658 // Counter tracks the continuation block of the try statement. 659 setCount(PGO.getRegionCount(S)); 660 RecordNextStmtCount = true; 661 } 662 663 void VisitCXXCatchStmt(const CXXCatchStmt *S) { 664 RecordNextStmtCount = false; 665 // Counter tracks the catch statement's handler block. 666 uint64_t CatchCount = setCount(PGO.getRegionCount(S)); 667 CountMap[S] = CatchCount; 668 Visit(S->getHandlerBlock()); 669 } 670 671 void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) { 672 RecordStmtCount(E); 673 uint64_t ParentCount = CurrentCount; 674 Visit(E->getCond()); 675 676 // Counter tracks the "true" part of a conditional operator. The 677 // count in the "false" part will be calculated from this counter. 678 uint64_t TrueCount = setCount(PGO.getRegionCount(E)); 679 CountMap[E->getTrueExpr()] = TrueCount; 680 Visit(E->getTrueExpr()); 681 uint64_t OutCount = CurrentCount; 682 683 uint64_t FalseCount = setCount(ParentCount - TrueCount); 684 CountMap[E->getFalseExpr()] = FalseCount; 685 Visit(E->getFalseExpr()); 686 OutCount += CurrentCount; 687 688 setCount(OutCount); 689 RecordNextStmtCount = true; 690 } 691 692 void VisitBinLAnd(const BinaryOperator *E) { 693 RecordStmtCount(E); 694 uint64_t ParentCount = CurrentCount; 695 Visit(E->getLHS()); 696 // Counter tracks the right hand side of a logical and operator. 697 uint64_t RHSCount = setCount(PGO.getRegionCount(E)); 698 CountMap[E->getRHS()] = RHSCount; 699 Visit(E->getRHS()); 700 setCount(ParentCount + RHSCount - CurrentCount); 701 RecordNextStmtCount = true; 702 } 703 704 void VisitBinLOr(const BinaryOperator *E) { 705 RecordStmtCount(E); 706 uint64_t ParentCount = CurrentCount; 707 Visit(E->getLHS()); 708 // Counter tracks the right hand side of a logical or operator. 709 uint64_t RHSCount = setCount(PGO.getRegionCount(E)); 710 CountMap[E->getRHS()] = RHSCount; 711 Visit(E->getRHS()); 712 setCount(ParentCount + RHSCount - CurrentCount); 713 RecordNextStmtCount = true; 714 } 715 }; 716 } // end anonymous namespace 717 718 void PGOHash::combine(HashType Type) { 719 // Check that we never combine 0 and only have six bits. 720 assert(Type && "Hash is invalid: unexpected type 0"); 721 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types"); 722 723 // Pass through MD5 if enough work has built up. 724 if (Count && Count % NumTypesPerWord == 0) { 725 using namespace llvm::support; 726 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working); 727 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped))); 728 Working = 0; 729 } 730 731 // Accumulate the current type. 732 ++Count; 733 Working = Working << NumBitsPerType | Type; 734 } 735 736 uint64_t PGOHash::finalize() { 737 // Use Working as the hash directly if we never used MD5. 738 if (Count <= NumTypesPerWord) 739 // No need to byte swap here, since none of the math was endian-dependent. 740 // This number will be byte-swapped as required on endianness transitions, 741 // so we will see the same value on the other side. 742 return Working; 743 744 // Check for remaining work in Working. 745 if (Working) 746 MD5.update(Working); 747 748 // Finalize the MD5 and return the hash. 749 llvm::MD5::MD5Result Result; 750 MD5.final(Result); 751 using namespace llvm::support; 752 return Result.low(); 753 } 754 755 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) { 756 const Decl *D = GD.getDecl(); 757 if (!D->hasBody()) 758 return; 759 760 bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr(); 761 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 762 if (!InstrumentRegions && !PGOReader) 763 return; 764 if (D->isImplicit()) 765 return; 766 // Constructors and destructors may be represented by several functions in IR. 767 // If so, instrument only base variant, others are implemented by delegation 768 // to the base one, it would be counted twice otherwise. 769 if (CGM.getTarget().getCXXABI().hasConstructorVariants()) { 770 if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base) 771 return; 772 773 if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D)) 774 if (GD.getCtorType() != Ctor_Base && 775 CodeGenFunction::IsConstructorDelegationValid(CCD)) 776 return; 777 } 778 CGM.ClearUnusedCoverageMapping(D); 779 setFuncName(Fn); 780 781 mapRegionCounters(D); 782 if (CGM.getCodeGenOpts().CoverageMapping) 783 emitCounterRegionMapping(D); 784 if (PGOReader) { 785 SourceManager &SM = CGM.getContext().getSourceManager(); 786 loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation())); 787 computeRegionCounts(D); 788 applyFunctionAttributes(PGOReader, Fn); 789 } 790 } 791 792 void CodeGenPGO::mapRegionCounters(const Decl *D) { 793 // Use the latest hash version when inserting instrumentation, but use the 794 // version in the indexed profile if we're reading PGO data. 795 PGOHashVersion HashVersion = PGO_HASH_LATEST; 796 if (auto *PGOReader = CGM.getPGOReader()) 797 HashVersion = getPGOHashVersion(PGOReader, CGM); 798 799 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>); 800 MapRegionCounters Walker(HashVersion, *RegionCounterMap); 801 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 802 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD)); 803 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D)) 804 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD)); 805 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D)) 806 Walker.TraverseDecl(const_cast<BlockDecl *>(BD)); 807 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D)) 808 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD)); 809 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl"); 810 NumRegionCounters = Walker.NextCounter; 811 FunctionHash = Walker.Hash.finalize(); 812 } 813 814 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) { 815 if (!D->getBody()) 816 return true; 817 818 // Don't map the functions in system headers. 819 const auto &SM = CGM.getContext().getSourceManager(); 820 auto Loc = D->getBody()->getBeginLoc(); 821 return SM.isInSystemHeader(Loc); 822 } 823 824 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) { 825 if (skipRegionMappingForDecl(D)) 826 return; 827 828 std::string CoverageMapping; 829 llvm::raw_string_ostream OS(CoverageMapping); 830 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(), 831 CGM.getContext().getSourceManager(), 832 CGM.getLangOpts(), RegionCounterMap.get()); 833 MappingGen.emitCounterMapping(D, OS); 834 OS.flush(); 835 836 if (CoverageMapping.empty()) 837 return; 838 839 CGM.getCoverageMapping()->addFunctionMappingRecord( 840 FuncNameVar, FuncName, FunctionHash, CoverageMapping); 841 } 842 843 void 844 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name, 845 llvm::GlobalValue::LinkageTypes Linkage) { 846 if (skipRegionMappingForDecl(D)) 847 return; 848 849 std::string CoverageMapping; 850 llvm::raw_string_ostream OS(CoverageMapping); 851 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(), 852 CGM.getContext().getSourceManager(), 853 CGM.getLangOpts()); 854 MappingGen.emitEmptyMapping(D, OS); 855 OS.flush(); 856 857 if (CoverageMapping.empty()) 858 return; 859 860 setFuncName(Name, Linkage); 861 CGM.getCoverageMapping()->addFunctionMappingRecord( 862 FuncNameVar, FuncName, FunctionHash, CoverageMapping, false); 863 } 864 865 void CodeGenPGO::computeRegionCounts(const Decl *D) { 866 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>); 867 ComputeRegionCounts Walker(*StmtCountMap, *this); 868 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 869 Walker.VisitFunctionDecl(FD); 870 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D)) 871 Walker.VisitObjCMethodDecl(MD); 872 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D)) 873 Walker.VisitBlockDecl(BD); 874 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D)) 875 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD)); 876 } 877 878 void 879 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader, 880 llvm::Function *Fn) { 881 if (!haveRegionCounts()) 882 return; 883 884 uint64_t FunctionCount = getRegionCount(nullptr); 885 Fn->setEntryCount(FunctionCount); 886 } 887 888 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S, 889 llvm::Value *StepV) { 890 if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap) 891 return; 892 if (!Builder.GetInsertBlock()) 893 return; 894 895 unsigned Counter = (*RegionCounterMap)[S]; 896 auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext()); 897 898 llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy), 899 Builder.getInt64(FunctionHash), 900 Builder.getInt32(NumRegionCounters), 901 Builder.getInt32(Counter), StepV}; 902 if (!StepV) 903 Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment), 904 makeArrayRef(Args, 4)); 905 else 906 Builder.CreateCall( 907 CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step), 908 makeArrayRef(Args)); 909 } 910 911 // This method either inserts a call to the profile run-time during 912 // instrumentation or puts profile data into metadata for PGO use. 913 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind, 914 llvm::Instruction *ValueSite, llvm::Value *ValuePtr) { 915 916 if (!EnableValueProfiling) 917 return; 918 919 if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock()) 920 return; 921 922 if (isa<llvm::Constant>(ValuePtr)) 923 return; 924 925 bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr(); 926 if (InstrumentValueSites && RegionCounterMap) { 927 auto BuilderInsertPoint = Builder.saveIP(); 928 Builder.SetInsertPoint(ValueSite); 929 llvm::Value *Args[5] = { 930 llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()), 931 Builder.getInt64(FunctionHash), 932 Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()), 933 Builder.getInt32(ValueKind), 934 Builder.getInt32(NumValueSites[ValueKind]++) 935 }; 936 Builder.CreateCall( 937 CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args); 938 Builder.restoreIP(BuilderInsertPoint); 939 return; 940 } 941 942 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 943 if (PGOReader && haveRegionCounts()) { 944 // We record the top most called three functions at each call site. 945 // Profile metadata contains "VP" string identifying this metadata 946 // as value profiling data, then a uint32_t value for the value profiling 947 // kind, a uint64_t value for the total number of times the call is 948 // executed, followed by the function hash and execution count (uint64_t) 949 // pairs for each function. 950 if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind)) 951 return; 952 953 llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord, 954 (llvm::InstrProfValueKind)ValueKind, 955 NumValueSites[ValueKind]); 956 957 NumValueSites[ValueKind]++; 958 } 959 } 960 961 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader, 962 bool IsInMainFile) { 963 CGM.getPGOStats().addVisited(IsInMainFile); 964 RegionCounts.clear(); 965 llvm::Expected<llvm::InstrProfRecord> RecordExpected = 966 PGOReader->getInstrProfRecord(FuncName, FunctionHash); 967 if (auto E = RecordExpected.takeError()) { 968 auto IPE = llvm::InstrProfError::take(std::move(E)); 969 if (IPE == llvm::instrprof_error::unknown_function) 970 CGM.getPGOStats().addMissing(IsInMainFile); 971 else if (IPE == llvm::instrprof_error::hash_mismatch) 972 CGM.getPGOStats().addMismatched(IsInMainFile); 973 else if (IPE == llvm::instrprof_error::malformed) 974 // TODO: Consider a more specific warning for this case. 975 CGM.getPGOStats().addMismatched(IsInMainFile); 976 return; 977 } 978 ProfRecord = 979 llvm::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get())); 980 RegionCounts = ProfRecord->Counts; 981 } 982 983 /// Calculate what to divide by to scale weights. 984 /// 985 /// Given the maximum weight, calculate a divisor that will scale all the 986 /// weights to strictly less than UINT32_MAX. 987 static uint64_t calculateWeightScale(uint64_t MaxWeight) { 988 return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1; 989 } 990 991 /// Scale an individual branch weight (and add 1). 992 /// 993 /// Scale a 64-bit weight down to 32-bits using \c Scale. 994 /// 995 /// According to Laplace's Rule of Succession, it is better to compute the 996 /// weight based on the count plus 1, so universally add 1 to the value. 997 /// 998 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no 999 /// greater than \c Weight. 1000 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) { 1001 assert(Scale && "scale by 0?"); 1002 uint64_t Scaled = Weight / Scale + 1; 1003 assert(Scaled <= UINT32_MAX && "overflow 32-bits"); 1004 return Scaled; 1005 } 1006 1007 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount, 1008 uint64_t FalseCount) { 1009 // Check for empty weights. 1010 if (!TrueCount && !FalseCount) 1011 return nullptr; 1012 1013 // Calculate how to scale down to 32-bits. 1014 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount)); 1015 1016 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 1017 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale), 1018 scaleBranchWeight(FalseCount, Scale)); 1019 } 1020 1021 llvm::MDNode * 1022 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) { 1023 // We need at least two elements to create meaningful weights. 1024 if (Weights.size() < 2) 1025 return nullptr; 1026 1027 // Check for empty weights. 1028 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end()); 1029 if (MaxWeight == 0) 1030 return nullptr; 1031 1032 // Calculate how to scale down to 32-bits. 1033 uint64_t Scale = calculateWeightScale(MaxWeight); 1034 1035 SmallVector<uint32_t, 16> ScaledWeights; 1036 ScaledWeights.reserve(Weights.size()); 1037 for (uint64_t W : Weights) 1038 ScaledWeights.push_back(scaleBranchWeight(W, Scale)); 1039 1040 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 1041 return MDHelper.createBranchWeights(ScaledWeights); 1042 } 1043 1044 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond, 1045 uint64_t LoopCount) { 1046 if (!PGO.haveRegionCounts()) 1047 return nullptr; 1048 Optional<uint64_t> CondCount = PGO.getStmtCount(Cond); 1049 assert(CondCount.hasValue() && "missing expected loop condition count"); 1050 if (*CondCount == 0) 1051 return nullptr; 1052 return createProfileWeights(LoopCount, 1053 std::max(*CondCount, LoopCount) - LoopCount); 1054 } 1055