1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // Instrumentation-based profile-guided optimization 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "CodeGenPGO.h" 15 #include "CodeGenFunction.h" 16 #include "CoverageMappingGen.h" 17 #include "clang/AST/RecursiveASTVisitor.h" 18 #include "clang/AST/StmtVisitor.h" 19 #include "llvm/IR/Intrinsics.h" 20 #include "llvm/IR/MDBuilder.h" 21 #include "llvm/Support/Endian.h" 22 #include "llvm/Support/FileSystem.h" 23 #include "llvm/Support/MD5.h" 24 25 static llvm::cl::opt<bool> EnableValueProfiling( 26 "enable-value-profiling", llvm::cl::ZeroOrMore, 27 llvm::cl::desc("Enable value profiling"), llvm::cl::init(false)); 28 29 using namespace clang; 30 using namespace CodeGen; 31 32 void CodeGenPGO::setFuncName(StringRef Name, 33 llvm::GlobalValue::LinkageTypes Linkage) { 34 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 35 FuncName = llvm::getPGOFuncName( 36 Name, Linkage, CGM.getCodeGenOpts().MainFileName, 37 PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version); 38 39 // If we're generating a profile, create a variable for the name. 40 if (CGM.getCodeGenOpts().hasProfileClangInstr()) 41 FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName); 42 } 43 44 void CodeGenPGO::setFuncName(llvm::Function *Fn) { 45 setFuncName(Fn->getName(), Fn->getLinkage()); 46 } 47 48 namespace { 49 /// \brief Stable hasher for PGO region counters. 50 /// 51 /// PGOHash produces a stable hash of a given function's control flow. 52 /// 53 /// Changing the output of this hash will invalidate all previously generated 54 /// profiles -- i.e., don't do it. 55 /// 56 /// \note When this hash does eventually change (years?), we still need to 57 /// support old hashes. We'll need to pull in the version number from the 58 /// profile data format and use the matching hash function. 59 class PGOHash { 60 uint64_t Working; 61 unsigned Count; 62 llvm::MD5 MD5; 63 64 static const int NumBitsPerType = 6; 65 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType; 66 static const unsigned TooBig = 1u << NumBitsPerType; 67 68 public: 69 /// \brief Hash values for AST nodes. 70 /// 71 /// Distinct values for AST nodes that have region counters attached. 72 /// 73 /// These values must be stable. All new members must be added at the end, 74 /// and no members should be removed. Changing the enumeration value for an 75 /// AST node will affect the hash of every function that contains that node. 76 enum HashType : unsigned char { 77 None = 0, 78 LabelStmt = 1, 79 WhileStmt, 80 DoStmt, 81 ForStmt, 82 CXXForRangeStmt, 83 ObjCForCollectionStmt, 84 SwitchStmt, 85 CaseStmt, 86 DefaultStmt, 87 IfStmt, 88 CXXTryStmt, 89 CXXCatchStmt, 90 ConditionalOperator, 91 BinaryOperatorLAnd, 92 BinaryOperatorLOr, 93 BinaryConditionalOperator, 94 95 // Keep this last. It's for the static assert that follows. 96 LastHashType 97 }; 98 static_assert(LastHashType <= TooBig, "Too many types in HashType"); 99 100 // TODO: When this format changes, take in a version number here, and use the 101 // old hash calculation for file formats that used the old hash. 102 PGOHash() : Working(0), Count(0) {} 103 void combine(HashType Type); 104 uint64_t finalize(); 105 }; 106 const int PGOHash::NumBitsPerType; 107 const unsigned PGOHash::NumTypesPerWord; 108 const unsigned PGOHash::TooBig; 109 110 /// A RecursiveASTVisitor that fills a map of statements to PGO counters. 111 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> { 112 /// The next counter value to assign. 113 unsigned NextCounter; 114 /// The function hash. 115 PGOHash Hash; 116 /// The map of statements to counters. 117 llvm::DenseMap<const Stmt *, unsigned> &CounterMap; 118 119 MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap) 120 : NextCounter(0), CounterMap(CounterMap) {} 121 122 // Blocks and lambdas are handled as separate functions, so we need not 123 // traverse them in the parent context. 124 bool TraverseBlockExpr(BlockExpr *BE) { return true; } 125 bool TraverseLambdaBody(LambdaExpr *LE) { return true; } 126 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; } 127 128 bool VisitDecl(const Decl *D) { 129 switch (D->getKind()) { 130 default: 131 break; 132 case Decl::Function: 133 case Decl::CXXMethod: 134 case Decl::CXXConstructor: 135 case Decl::CXXDestructor: 136 case Decl::CXXConversion: 137 case Decl::ObjCMethod: 138 case Decl::Block: 139 case Decl::Captured: 140 CounterMap[D->getBody()] = NextCounter++; 141 break; 142 } 143 return true; 144 } 145 146 bool VisitStmt(const Stmt *S) { 147 auto Type = getHashType(S); 148 if (Type == PGOHash::None) 149 return true; 150 151 CounterMap[S] = NextCounter++; 152 Hash.combine(Type); 153 return true; 154 } 155 PGOHash::HashType getHashType(const Stmt *S) { 156 switch (S->getStmtClass()) { 157 default: 158 break; 159 case Stmt::LabelStmtClass: 160 return PGOHash::LabelStmt; 161 case Stmt::WhileStmtClass: 162 return PGOHash::WhileStmt; 163 case Stmt::DoStmtClass: 164 return PGOHash::DoStmt; 165 case Stmt::ForStmtClass: 166 return PGOHash::ForStmt; 167 case Stmt::CXXForRangeStmtClass: 168 return PGOHash::CXXForRangeStmt; 169 case Stmt::ObjCForCollectionStmtClass: 170 return PGOHash::ObjCForCollectionStmt; 171 case Stmt::SwitchStmtClass: 172 return PGOHash::SwitchStmt; 173 case Stmt::CaseStmtClass: 174 return PGOHash::CaseStmt; 175 case Stmt::DefaultStmtClass: 176 return PGOHash::DefaultStmt; 177 case Stmt::IfStmtClass: 178 return PGOHash::IfStmt; 179 case Stmt::CXXTryStmtClass: 180 return PGOHash::CXXTryStmt; 181 case Stmt::CXXCatchStmtClass: 182 return PGOHash::CXXCatchStmt; 183 case Stmt::ConditionalOperatorClass: 184 return PGOHash::ConditionalOperator; 185 case Stmt::BinaryConditionalOperatorClass: 186 return PGOHash::BinaryConditionalOperator; 187 case Stmt::BinaryOperatorClass: { 188 const BinaryOperator *BO = cast<BinaryOperator>(S); 189 if (BO->getOpcode() == BO_LAnd) 190 return PGOHash::BinaryOperatorLAnd; 191 if (BO->getOpcode() == BO_LOr) 192 return PGOHash::BinaryOperatorLOr; 193 break; 194 } 195 } 196 return PGOHash::None; 197 } 198 }; 199 200 /// A StmtVisitor that propagates the raw counts through the AST and 201 /// records the count at statements where the value may change. 202 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { 203 /// PGO state. 204 CodeGenPGO &PGO; 205 206 /// A flag that is set when the current count should be recorded on the 207 /// next statement, such as at the exit of a loop. 208 bool RecordNextStmtCount; 209 210 /// The count at the current location in the traversal. 211 uint64_t CurrentCount; 212 213 /// The map of statements to count values. 214 llvm::DenseMap<const Stmt *, uint64_t> &CountMap; 215 216 /// BreakContinueStack - Keep counts of breaks and continues inside loops. 217 struct BreakContinue { 218 uint64_t BreakCount; 219 uint64_t ContinueCount; 220 BreakContinue() : BreakCount(0), ContinueCount(0) {} 221 }; 222 SmallVector<BreakContinue, 8> BreakContinueStack; 223 224 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap, 225 CodeGenPGO &PGO) 226 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {} 227 228 void RecordStmtCount(const Stmt *S) { 229 if (RecordNextStmtCount) { 230 CountMap[S] = CurrentCount; 231 RecordNextStmtCount = false; 232 } 233 } 234 235 /// Set and return the current count. 236 uint64_t setCount(uint64_t Count) { 237 CurrentCount = Count; 238 return Count; 239 } 240 241 void VisitStmt(const Stmt *S) { 242 RecordStmtCount(S); 243 for (const Stmt *Child : S->children()) 244 if (Child) 245 this->Visit(Child); 246 } 247 248 void VisitFunctionDecl(const FunctionDecl *D) { 249 // Counter tracks entry to the function body. 250 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 251 CountMap[D->getBody()] = BodyCount; 252 Visit(D->getBody()); 253 } 254 255 // Skip lambda expressions. We visit these as FunctionDecls when we're 256 // generating them and aren't interested in the body when generating a 257 // parent context. 258 void VisitLambdaExpr(const LambdaExpr *LE) {} 259 260 void VisitCapturedDecl(const CapturedDecl *D) { 261 // Counter tracks entry to the capture body. 262 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 263 CountMap[D->getBody()] = BodyCount; 264 Visit(D->getBody()); 265 } 266 267 void VisitObjCMethodDecl(const ObjCMethodDecl *D) { 268 // Counter tracks entry to the method body. 269 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 270 CountMap[D->getBody()] = BodyCount; 271 Visit(D->getBody()); 272 } 273 274 void VisitBlockDecl(const BlockDecl *D) { 275 // Counter tracks entry to the block body. 276 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 277 CountMap[D->getBody()] = BodyCount; 278 Visit(D->getBody()); 279 } 280 281 void VisitReturnStmt(const ReturnStmt *S) { 282 RecordStmtCount(S); 283 if (S->getRetValue()) 284 Visit(S->getRetValue()); 285 CurrentCount = 0; 286 RecordNextStmtCount = true; 287 } 288 289 void VisitCXXThrowExpr(const CXXThrowExpr *E) { 290 RecordStmtCount(E); 291 if (E->getSubExpr()) 292 Visit(E->getSubExpr()); 293 CurrentCount = 0; 294 RecordNextStmtCount = true; 295 } 296 297 void VisitGotoStmt(const GotoStmt *S) { 298 RecordStmtCount(S); 299 CurrentCount = 0; 300 RecordNextStmtCount = true; 301 } 302 303 void VisitLabelStmt(const LabelStmt *S) { 304 RecordNextStmtCount = false; 305 // Counter tracks the block following the label. 306 uint64_t BlockCount = setCount(PGO.getRegionCount(S)); 307 CountMap[S] = BlockCount; 308 Visit(S->getSubStmt()); 309 } 310 311 void VisitBreakStmt(const BreakStmt *S) { 312 RecordStmtCount(S); 313 assert(!BreakContinueStack.empty() && "break not in a loop or switch!"); 314 BreakContinueStack.back().BreakCount += CurrentCount; 315 CurrentCount = 0; 316 RecordNextStmtCount = true; 317 } 318 319 void VisitContinueStmt(const ContinueStmt *S) { 320 RecordStmtCount(S); 321 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!"); 322 BreakContinueStack.back().ContinueCount += CurrentCount; 323 CurrentCount = 0; 324 RecordNextStmtCount = true; 325 } 326 327 void VisitWhileStmt(const WhileStmt *S) { 328 RecordStmtCount(S); 329 uint64_t ParentCount = CurrentCount; 330 331 BreakContinueStack.push_back(BreakContinue()); 332 // Visit the body region first so the break/continue adjustments can be 333 // included when visiting the condition. 334 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 335 CountMap[S->getBody()] = CurrentCount; 336 Visit(S->getBody()); 337 uint64_t BackedgeCount = CurrentCount; 338 339 // ...then go back and propagate counts through the condition. The count 340 // at the start of the condition is the sum of the incoming edges, 341 // the backedge from the end of the loop body, and the edges from 342 // continue statements. 343 BreakContinue BC = BreakContinueStack.pop_back_val(); 344 uint64_t CondCount = 345 setCount(ParentCount + BackedgeCount + BC.ContinueCount); 346 CountMap[S->getCond()] = CondCount; 347 Visit(S->getCond()); 348 setCount(BC.BreakCount + CondCount - BodyCount); 349 RecordNextStmtCount = true; 350 } 351 352 void VisitDoStmt(const DoStmt *S) { 353 RecordStmtCount(S); 354 uint64_t LoopCount = PGO.getRegionCount(S); 355 356 BreakContinueStack.push_back(BreakContinue()); 357 // The count doesn't include the fallthrough from the parent scope. Add it. 358 uint64_t BodyCount = setCount(LoopCount + CurrentCount); 359 CountMap[S->getBody()] = BodyCount; 360 Visit(S->getBody()); 361 uint64_t BackedgeCount = CurrentCount; 362 363 BreakContinue BC = BreakContinueStack.pop_back_val(); 364 // The count at the start of the condition is equal to the count at the 365 // end of the body, plus any continues. 366 uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount); 367 CountMap[S->getCond()] = CondCount; 368 Visit(S->getCond()); 369 setCount(BC.BreakCount + CondCount - LoopCount); 370 RecordNextStmtCount = true; 371 } 372 373 void VisitForStmt(const ForStmt *S) { 374 RecordStmtCount(S); 375 if (S->getInit()) 376 Visit(S->getInit()); 377 378 uint64_t ParentCount = CurrentCount; 379 380 BreakContinueStack.push_back(BreakContinue()); 381 // Visit the body region first. (This is basically the same as a while 382 // loop; see further comments in VisitWhileStmt.) 383 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 384 CountMap[S->getBody()] = BodyCount; 385 Visit(S->getBody()); 386 uint64_t BackedgeCount = CurrentCount; 387 BreakContinue BC = BreakContinueStack.pop_back_val(); 388 389 // The increment is essentially part of the body but it needs to include 390 // the count for all the continue statements. 391 if (S->getInc()) { 392 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount); 393 CountMap[S->getInc()] = IncCount; 394 Visit(S->getInc()); 395 } 396 397 // ...then go back and propagate counts through the condition. 398 uint64_t CondCount = 399 setCount(ParentCount + BackedgeCount + BC.ContinueCount); 400 if (S->getCond()) { 401 CountMap[S->getCond()] = CondCount; 402 Visit(S->getCond()); 403 } 404 setCount(BC.BreakCount + CondCount - BodyCount); 405 RecordNextStmtCount = true; 406 } 407 408 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) { 409 RecordStmtCount(S); 410 Visit(S->getLoopVarStmt()); 411 Visit(S->getRangeStmt()); 412 Visit(S->getBeginStmt()); 413 Visit(S->getEndStmt()); 414 415 uint64_t ParentCount = CurrentCount; 416 BreakContinueStack.push_back(BreakContinue()); 417 // Visit the body region first. (This is basically the same as a while 418 // loop; see further comments in VisitWhileStmt.) 419 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 420 CountMap[S->getBody()] = BodyCount; 421 Visit(S->getBody()); 422 uint64_t BackedgeCount = CurrentCount; 423 BreakContinue BC = BreakContinueStack.pop_back_val(); 424 425 // The increment is essentially part of the body but it needs to include 426 // the count for all the continue statements. 427 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount); 428 CountMap[S->getInc()] = IncCount; 429 Visit(S->getInc()); 430 431 // ...then go back and propagate counts through the condition. 432 uint64_t CondCount = 433 setCount(ParentCount + BackedgeCount + BC.ContinueCount); 434 CountMap[S->getCond()] = CondCount; 435 Visit(S->getCond()); 436 setCount(BC.BreakCount + CondCount - BodyCount); 437 RecordNextStmtCount = true; 438 } 439 440 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { 441 RecordStmtCount(S); 442 Visit(S->getElement()); 443 uint64_t ParentCount = CurrentCount; 444 BreakContinueStack.push_back(BreakContinue()); 445 // Counter tracks the body of the loop. 446 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 447 CountMap[S->getBody()] = BodyCount; 448 Visit(S->getBody()); 449 uint64_t BackedgeCount = CurrentCount; 450 BreakContinue BC = BreakContinueStack.pop_back_val(); 451 452 setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount - 453 BodyCount); 454 RecordNextStmtCount = true; 455 } 456 457 void VisitSwitchStmt(const SwitchStmt *S) { 458 RecordStmtCount(S); 459 Visit(S->getCond()); 460 CurrentCount = 0; 461 BreakContinueStack.push_back(BreakContinue()); 462 Visit(S->getBody()); 463 // If the switch is inside a loop, add the continue counts. 464 BreakContinue BC = BreakContinueStack.pop_back_val(); 465 if (!BreakContinueStack.empty()) 466 BreakContinueStack.back().ContinueCount += BC.ContinueCount; 467 // Counter tracks the exit block of the switch. 468 setCount(PGO.getRegionCount(S)); 469 RecordNextStmtCount = true; 470 } 471 472 void VisitSwitchCase(const SwitchCase *S) { 473 RecordNextStmtCount = false; 474 // Counter for this particular case. This counts only jumps from the 475 // switch header and does not include fallthrough from the case before 476 // this one. 477 uint64_t CaseCount = PGO.getRegionCount(S); 478 setCount(CurrentCount + CaseCount); 479 // We need the count without fallthrough in the mapping, so it's more useful 480 // for branch probabilities. 481 CountMap[S] = CaseCount; 482 RecordNextStmtCount = true; 483 Visit(S->getSubStmt()); 484 } 485 486 void VisitIfStmt(const IfStmt *S) { 487 RecordStmtCount(S); 488 uint64_t ParentCount = CurrentCount; 489 Visit(S->getCond()); 490 491 // Counter tracks the "then" part of an if statement. The count for 492 // the "else" part, if it exists, will be calculated from this counter. 493 uint64_t ThenCount = setCount(PGO.getRegionCount(S)); 494 CountMap[S->getThen()] = ThenCount; 495 Visit(S->getThen()); 496 uint64_t OutCount = CurrentCount; 497 498 uint64_t ElseCount = ParentCount - ThenCount; 499 if (S->getElse()) { 500 setCount(ElseCount); 501 CountMap[S->getElse()] = ElseCount; 502 Visit(S->getElse()); 503 OutCount += CurrentCount; 504 } else 505 OutCount += ElseCount; 506 setCount(OutCount); 507 RecordNextStmtCount = true; 508 } 509 510 void VisitCXXTryStmt(const CXXTryStmt *S) { 511 RecordStmtCount(S); 512 Visit(S->getTryBlock()); 513 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I) 514 Visit(S->getHandler(I)); 515 // Counter tracks the continuation block of the try statement. 516 setCount(PGO.getRegionCount(S)); 517 RecordNextStmtCount = true; 518 } 519 520 void VisitCXXCatchStmt(const CXXCatchStmt *S) { 521 RecordNextStmtCount = false; 522 // Counter tracks the catch statement's handler block. 523 uint64_t CatchCount = setCount(PGO.getRegionCount(S)); 524 CountMap[S] = CatchCount; 525 Visit(S->getHandlerBlock()); 526 } 527 528 void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) { 529 RecordStmtCount(E); 530 uint64_t ParentCount = CurrentCount; 531 Visit(E->getCond()); 532 533 // Counter tracks the "true" part of a conditional operator. The 534 // count in the "false" part will be calculated from this counter. 535 uint64_t TrueCount = setCount(PGO.getRegionCount(E)); 536 CountMap[E->getTrueExpr()] = TrueCount; 537 Visit(E->getTrueExpr()); 538 uint64_t OutCount = CurrentCount; 539 540 uint64_t FalseCount = setCount(ParentCount - TrueCount); 541 CountMap[E->getFalseExpr()] = FalseCount; 542 Visit(E->getFalseExpr()); 543 OutCount += CurrentCount; 544 545 setCount(OutCount); 546 RecordNextStmtCount = true; 547 } 548 549 void VisitBinLAnd(const BinaryOperator *E) { 550 RecordStmtCount(E); 551 uint64_t ParentCount = CurrentCount; 552 Visit(E->getLHS()); 553 // Counter tracks the right hand side of a logical and operator. 554 uint64_t RHSCount = setCount(PGO.getRegionCount(E)); 555 CountMap[E->getRHS()] = RHSCount; 556 Visit(E->getRHS()); 557 setCount(ParentCount + RHSCount - CurrentCount); 558 RecordNextStmtCount = true; 559 } 560 561 void VisitBinLOr(const BinaryOperator *E) { 562 RecordStmtCount(E); 563 uint64_t ParentCount = CurrentCount; 564 Visit(E->getLHS()); 565 // Counter tracks the right hand side of a logical or operator. 566 uint64_t RHSCount = setCount(PGO.getRegionCount(E)); 567 CountMap[E->getRHS()] = RHSCount; 568 Visit(E->getRHS()); 569 setCount(ParentCount + RHSCount - CurrentCount); 570 RecordNextStmtCount = true; 571 } 572 }; 573 } // end anonymous namespace 574 575 void PGOHash::combine(HashType Type) { 576 // Check that we never combine 0 and only have six bits. 577 assert(Type && "Hash is invalid: unexpected type 0"); 578 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types"); 579 580 // Pass through MD5 if enough work has built up. 581 if (Count && Count % NumTypesPerWord == 0) { 582 using namespace llvm::support; 583 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working); 584 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped))); 585 Working = 0; 586 } 587 588 // Accumulate the current type. 589 ++Count; 590 Working = Working << NumBitsPerType | Type; 591 } 592 593 uint64_t PGOHash::finalize() { 594 // Use Working as the hash directly if we never used MD5. 595 if (Count <= NumTypesPerWord) 596 // No need to byte swap here, since none of the math was endian-dependent. 597 // This number will be byte-swapped as required on endianness transitions, 598 // so we will see the same value on the other side. 599 return Working; 600 601 // Check for remaining work in Working. 602 if (Working) 603 MD5.update(Working); 604 605 // Finalize the MD5 and return the hash. 606 llvm::MD5::MD5Result Result; 607 MD5.final(Result); 608 using namespace llvm::support; 609 return endian::read<uint64_t, little, unaligned>(Result); 610 } 611 612 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) { 613 const Decl *D = GD.getDecl(); 614 bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr(); 615 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 616 if (!InstrumentRegions && !PGOReader) 617 return; 618 if (D->isImplicit()) 619 return; 620 // Constructors and destructors may be represented by several functions in IR. 621 // If so, instrument only base variant, others are implemented by delegation 622 // to the base one, it would be counted twice otherwise. 623 if (CGM.getTarget().getCXXABI().hasConstructorVariants() && 624 ((isa<CXXConstructorDecl>(GD.getDecl()) && 625 GD.getCtorType() != Ctor_Base) || 626 (isa<CXXDestructorDecl>(GD.getDecl()) && 627 GD.getDtorType() != Dtor_Base))) { 628 return; 629 } 630 CGM.ClearUnusedCoverageMapping(D); 631 setFuncName(Fn); 632 633 mapRegionCounters(D); 634 if (CGM.getCodeGenOpts().CoverageMapping) 635 emitCounterRegionMapping(D); 636 if (PGOReader) { 637 SourceManager &SM = CGM.getContext().getSourceManager(); 638 loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation())); 639 computeRegionCounts(D); 640 applyFunctionAttributes(PGOReader, Fn); 641 } 642 } 643 644 void CodeGenPGO::mapRegionCounters(const Decl *D) { 645 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>); 646 MapRegionCounters Walker(*RegionCounterMap); 647 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 648 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD)); 649 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D)) 650 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD)); 651 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D)) 652 Walker.TraverseDecl(const_cast<BlockDecl *>(BD)); 653 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D)) 654 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD)); 655 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl"); 656 NumRegionCounters = Walker.NextCounter; 657 FunctionHash = Walker.Hash.finalize(); 658 } 659 660 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) { 661 if (SkipCoverageMapping) 662 return; 663 // Don't map the functions inside the system headers 664 auto Loc = D->getBody()->getLocStart(); 665 if (CGM.getContext().getSourceManager().isInSystemHeader(Loc)) 666 return; 667 668 std::string CoverageMapping; 669 llvm::raw_string_ostream OS(CoverageMapping); 670 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(), 671 CGM.getContext().getSourceManager(), 672 CGM.getLangOpts(), RegionCounterMap.get()); 673 MappingGen.emitCounterMapping(D, OS); 674 OS.flush(); 675 676 if (CoverageMapping.empty()) 677 return; 678 679 CGM.getCoverageMapping()->addFunctionMappingRecord( 680 FuncNameVar, FuncName, FunctionHash, CoverageMapping); 681 } 682 683 void 684 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name, 685 llvm::GlobalValue::LinkageTypes Linkage) { 686 if (SkipCoverageMapping) 687 return; 688 // Don't map the functions inside the system headers 689 auto Loc = D->getBody()->getLocStart(); 690 if (CGM.getContext().getSourceManager().isInSystemHeader(Loc)) 691 return; 692 693 std::string CoverageMapping; 694 llvm::raw_string_ostream OS(CoverageMapping); 695 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(), 696 CGM.getContext().getSourceManager(), 697 CGM.getLangOpts()); 698 MappingGen.emitEmptyMapping(D, OS); 699 OS.flush(); 700 701 if (CoverageMapping.empty()) 702 return; 703 704 setFuncName(Name, Linkage); 705 CGM.getCoverageMapping()->addFunctionMappingRecord( 706 FuncNameVar, FuncName, FunctionHash, CoverageMapping, false); 707 } 708 709 void CodeGenPGO::computeRegionCounts(const Decl *D) { 710 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>); 711 ComputeRegionCounts Walker(*StmtCountMap, *this); 712 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 713 Walker.VisitFunctionDecl(FD); 714 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D)) 715 Walker.VisitObjCMethodDecl(MD); 716 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D)) 717 Walker.VisitBlockDecl(BD); 718 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D)) 719 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD)); 720 } 721 722 void 723 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader, 724 llvm::Function *Fn) { 725 if (!haveRegionCounts()) 726 return; 727 728 uint64_t FunctionCount = getRegionCount(nullptr); 729 Fn->setEntryCount(FunctionCount); 730 } 731 732 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S) { 733 if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap) 734 return; 735 if (!Builder.GetInsertBlock()) 736 return; 737 738 unsigned Counter = (*RegionCounterMap)[S]; 739 auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext()); 740 Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment), 741 {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy), 742 Builder.getInt64(FunctionHash), 743 Builder.getInt32(NumRegionCounters), 744 Builder.getInt32(Counter)}); 745 } 746 747 // This method either inserts a call to the profile run-time during 748 // instrumentation or puts profile data into metadata for PGO use. 749 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind, 750 llvm::Instruction *ValueSite, llvm::Value *ValuePtr) { 751 752 if (!EnableValueProfiling) 753 return; 754 755 if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock()) 756 return; 757 758 if (isa<llvm::Constant>(ValuePtr)) 759 return; 760 761 bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr(); 762 if (InstrumentValueSites && RegionCounterMap) { 763 auto BuilderInsertPoint = Builder.saveIP(); 764 Builder.SetInsertPoint(ValueSite); 765 llvm::Value *Args[5] = { 766 llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()), 767 Builder.getInt64(FunctionHash), 768 Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()), 769 Builder.getInt32(ValueKind), 770 Builder.getInt32(NumValueSites[ValueKind]++) 771 }; 772 Builder.CreateCall( 773 CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args); 774 Builder.restoreIP(BuilderInsertPoint); 775 return; 776 } 777 778 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 779 if (PGOReader && haveRegionCounts()) { 780 // We record the top most called three functions at each call site. 781 // Profile metadata contains "VP" string identifying this metadata 782 // as value profiling data, then a uint32_t value for the value profiling 783 // kind, a uint64_t value for the total number of times the call is 784 // executed, followed by the function hash and execution count (uint64_t) 785 // pairs for each function. 786 if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind)) 787 return; 788 789 llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord, 790 (llvm::InstrProfValueKind)ValueKind, 791 NumValueSites[ValueKind]); 792 793 NumValueSites[ValueKind]++; 794 } 795 } 796 797 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader, 798 bool IsInMainFile) { 799 CGM.getPGOStats().addVisited(IsInMainFile); 800 RegionCounts.clear(); 801 llvm::ErrorOr<llvm::InstrProfRecord> RecordErrorOr = 802 PGOReader->getInstrProfRecord(FuncName, FunctionHash); 803 if (std::error_code EC = RecordErrorOr.getError()) { 804 if (EC == llvm::instrprof_error::unknown_function) 805 CGM.getPGOStats().addMissing(IsInMainFile); 806 else if (EC == llvm::instrprof_error::hash_mismatch) 807 CGM.getPGOStats().addMismatched(IsInMainFile); 808 else if (EC == llvm::instrprof_error::malformed) 809 // TODO: Consider a more specific warning for this case. 810 CGM.getPGOStats().addMismatched(IsInMainFile); 811 return; 812 } 813 ProfRecord = 814 llvm::make_unique<llvm::InstrProfRecord>(std::move(RecordErrorOr.get())); 815 RegionCounts = ProfRecord->Counts; 816 } 817 818 /// \brief Calculate what to divide by to scale weights. 819 /// 820 /// Given the maximum weight, calculate a divisor that will scale all the 821 /// weights to strictly less than UINT32_MAX. 822 static uint64_t calculateWeightScale(uint64_t MaxWeight) { 823 return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1; 824 } 825 826 /// \brief Scale an individual branch weight (and add 1). 827 /// 828 /// Scale a 64-bit weight down to 32-bits using \c Scale. 829 /// 830 /// According to Laplace's Rule of Succession, it is better to compute the 831 /// weight based on the count plus 1, so universally add 1 to the value. 832 /// 833 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no 834 /// greater than \c Weight. 835 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) { 836 assert(Scale && "scale by 0?"); 837 uint64_t Scaled = Weight / Scale + 1; 838 assert(Scaled <= UINT32_MAX && "overflow 32-bits"); 839 return Scaled; 840 } 841 842 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount, 843 uint64_t FalseCount) { 844 // Check for empty weights. 845 if (!TrueCount && !FalseCount) 846 return nullptr; 847 848 // Calculate how to scale down to 32-bits. 849 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount)); 850 851 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 852 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale), 853 scaleBranchWeight(FalseCount, Scale)); 854 } 855 856 llvm::MDNode * 857 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) { 858 // We need at least two elements to create meaningful weights. 859 if (Weights.size() < 2) 860 return nullptr; 861 862 // Check for empty weights. 863 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end()); 864 if (MaxWeight == 0) 865 return nullptr; 866 867 // Calculate how to scale down to 32-bits. 868 uint64_t Scale = calculateWeightScale(MaxWeight); 869 870 SmallVector<uint32_t, 16> ScaledWeights; 871 ScaledWeights.reserve(Weights.size()); 872 for (uint64_t W : Weights) 873 ScaledWeights.push_back(scaleBranchWeight(W, Scale)); 874 875 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 876 return MDHelper.createBranchWeights(ScaledWeights); 877 } 878 879 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond, 880 uint64_t LoopCount) { 881 if (!PGO.haveRegionCounts()) 882 return nullptr; 883 Optional<uint64_t> CondCount = PGO.getStmtCount(Cond); 884 assert(CondCount.hasValue() && "missing expected loop condition count"); 885 if (*CondCount == 0) 886 return nullptr; 887 return createProfileWeights(LoopCount, 888 std::max(*CondCount, LoopCount) - LoopCount); 889 } 890