1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // Instrumentation-based profile-guided optimization 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "CodeGenPGO.h" 15 #include "CodeGenFunction.h" 16 #include "CoverageMappingGen.h" 17 #include "clang/AST/RecursiveASTVisitor.h" 18 #include "clang/AST/StmtVisitor.h" 19 #include "llvm/IR/Intrinsics.h" 20 #include "llvm/IR/MDBuilder.h" 21 #include "llvm/ProfileData/InstrProfReader.h" 22 #include "llvm/Support/Endian.h" 23 #include "llvm/Support/FileSystem.h" 24 #include "llvm/Support/MD5.h" 25 26 using namespace clang; 27 using namespace CodeGen; 28 29 void CodeGenPGO::setFuncName(StringRef Name, 30 llvm::GlobalValue::LinkageTypes Linkage) { 31 FuncName = llvm::getPGOFuncName(Name, Linkage, CGM.getCodeGenOpts().MainFileName); 32 33 // If we're generating a profile, create a variable for the name. 34 if (CGM.getCodeGenOpts().ProfileInstrGenerate) 35 FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName); 36 } 37 38 void CodeGenPGO::setFuncName(llvm::Function *Fn) { 39 setFuncName(Fn->getName(), Fn->getLinkage()); 40 } 41 42 namespace { 43 /// \brief Stable hasher for PGO region counters. 44 /// 45 /// PGOHash produces a stable hash of a given function's control flow. 46 /// 47 /// Changing the output of this hash will invalidate all previously generated 48 /// profiles -- i.e., don't do it. 49 /// 50 /// \note When this hash does eventually change (years?), we still need to 51 /// support old hashes. We'll need to pull in the version number from the 52 /// profile data format and use the matching hash function. 53 class PGOHash { 54 uint64_t Working; 55 unsigned Count; 56 llvm::MD5 MD5; 57 58 static const int NumBitsPerType = 6; 59 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType; 60 static const unsigned TooBig = 1u << NumBitsPerType; 61 62 public: 63 /// \brief Hash values for AST nodes. 64 /// 65 /// Distinct values for AST nodes that have region counters attached. 66 /// 67 /// These values must be stable. All new members must be added at the end, 68 /// and no members should be removed. Changing the enumeration value for an 69 /// AST node will affect the hash of every function that contains that node. 70 enum HashType : unsigned char { 71 None = 0, 72 LabelStmt = 1, 73 WhileStmt, 74 DoStmt, 75 ForStmt, 76 CXXForRangeStmt, 77 ObjCForCollectionStmt, 78 SwitchStmt, 79 CaseStmt, 80 DefaultStmt, 81 IfStmt, 82 CXXTryStmt, 83 CXXCatchStmt, 84 ConditionalOperator, 85 BinaryOperatorLAnd, 86 BinaryOperatorLOr, 87 BinaryConditionalOperator, 88 89 // Keep this last. It's for the static assert that follows. 90 LastHashType 91 }; 92 static_assert(LastHashType <= TooBig, "Too many types in HashType"); 93 94 // TODO: When this format changes, take in a version number here, and use the 95 // old hash calculation for file formats that used the old hash. 96 PGOHash() : Working(0), Count(0) {} 97 void combine(HashType Type); 98 uint64_t finalize(); 99 }; 100 const int PGOHash::NumBitsPerType; 101 const unsigned PGOHash::NumTypesPerWord; 102 const unsigned PGOHash::TooBig; 103 104 /// A RecursiveASTVisitor that fills a map of statements to PGO counters. 105 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> { 106 /// The next counter value to assign. 107 unsigned NextCounter; 108 /// The function hash. 109 PGOHash Hash; 110 /// The map of statements to counters. 111 llvm::DenseMap<const Stmt *, unsigned> &CounterMap; 112 113 MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap) 114 : NextCounter(0), CounterMap(CounterMap) {} 115 116 // Blocks and lambdas are handled as separate functions, so we need not 117 // traverse them in the parent context. 118 bool TraverseBlockExpr(BlockExpr *BE) { return true; } 119 bool TraverseLambdaBody(LambdaExpr *LE) { return true; } 120 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; } 121 122 bool VisitDecl(const Decl *D) { 123 switch (D->getKind()) { 124 default: 125 break; 126 case Decl::Function: 127 case Decl::CXXMethod: 128 case Decl::CXXConstructor: 129 case Decl::CXXDestructor: 130 case Decl::CXXConversion: 131 case Decl::ObjCMethod: 132 case Decl::Block: 133 case Decl::Captured: 134 CounterMap[D->getBody()] = NextCounter++; 135 break; 136 } 137 return true; 138 } 139 140 bool VisitStmt(const Stmt *S) { 141 auto Type = getHashType(S); 142 if (Type == PGOHash::None) 143 return true; 144 145 CounterMap[S] = NextCounter++; 146 Hash.combine(Type); 147 return true; 148 } 149 PGOHash::HashType getHashType(const Stmt *S) { 150 switch (S->getStmtClass()) { 151 default: 152 break; 153 case Stmt::LabelStmtClass: 154 return PGOHash::LabelStmt; 155 case Stmt::WhileStmtClass: 156 return PGOHash::WhileStmt; 157 case Stmt::DoStmtClass: 158 return PGOHash::DoStmt; 159 case Stmt::ForStmtClass: 160 return PGOHash::ForStmt; 161 case Stmt::CXXForRangeStmtClass: 162 return PGOHash::CXXForRangeStmt; 163 case Stmt::ObjCForCollectionStmtClass: 164 return PGOHash::ObjCForCollectionStmt; 165 case Stmt::SwitchStmtClass: 166 return PGOHash::SwitchStmt; 167 case Stmt::CaseStmtClass: 168 return PGOHash::CaseStmt; 169 case Stmt::DefaultStmtClass: 170 return PGOHash::DefaultStmt; 171 case Stmt::IfStmtClass: 172 return PGOHash::IfStmt; 173 case Stmt::CXXTryStmtClass: 174 return PGOHash::CXXTryStmt; 175 case Stmt::CXXCatchStmtClass: 176 return PGOHash::CXXCatchStmt; 177 case Stmt::ConditionalOperatorClass: 178 return PGOHash::ConditionalOperator; 179 case Stmt::BinaryConditionalOperatorClass: 180 return PGOHash::BinaryConditionalOperator; 181 case Stmt::BinaryOperatorClass: { 182 const BinaryOperator *BO = cast<BinaryOperator>(S); 183 if (BO->getOpcode() == BO_LAnd) 184 return PGOHash::BinaryOperatorLAnd; 185 if (BO->getOpcode() == BO_LOr) 186 return PGOHash::BinaryOperatorLOr; 187 break; 188 } 189 } 190 return PGOHash::None; 191 } 192 }; 193 194 /// A StmtVisitor that propagates the raw counts through the AST and 195 /// records the count at statements where the value may change. 196 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { 197 /// PGO state. 198 CodeGenPGO &PGO; 199 200 /// A flag that is set when the current count should be recorded on the 201 /// next statement, such as at the exit of a loop. 202 bool RecordNextStmtCount; 203 204 /// The count at the current location in the traversal. 205 uint64_t CurrentCount; 206 207 /// The map of statements to count values. 208 llvm::DenseMap<const Stmt *, uint64_t> &CountMap; 209 210 /// BreakContinueStack - Keep counts of breaks and continues inside loops. 211 struct BreakContinue { 212 uint64_t BreakCount; 213 uint64_t ContinueCount; 214 BreakContinue() : BreakCount(0), ContinueCount(0) {} 215 }; 216 SmallVector<BreakContinue, 8> BreakContinueStack; 217 218 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap, 219 CodeGenPGO &PGO) 220 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {} 221 222 void RecordStmtCount(const Stmt *S) { 223 if (RecordNextStmtCount) { 224 CountMap[S] = CurrentCount; 225 RecordNextStmtCount = false; 226 } 227 } 228 229 /// Set and return the current count. 230 uint64_t setCount(uint64_t Count) { 231 CurrentCount = Count; 232 return Count; 233 } 234 235 void VisitStmt(const Stmt *S) { 236 RecordStmtCount(S); 237 for (const Stmt *Child : S->children()) 238 if (Child) 239 this->Visit(Child); 240 } 241 242 void VisitFunctionDecl(const FunctionDecl *D) { 243 // Counter tracks entry to the function body. 244 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 245 CountMap[D->getBody()] = BodyCount; 246 Visit(D->getBody()); 247 } 248 249 // Skip lambda expressions. We visit these as FunctionDecls when we're 250 // generating them and aren't interested in the body when generating a 251 // parent context. 252 void VisitLambdaExpr(const LambdaExpr *LE) {} 253 254 void VisitCapturedDecl(const CapturedDecl *D) { 255 // Counter tracks entry to the capture body. 256 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 257 CountMap[D->getBody()] = BodyCount; 258 Visit(D->getBody()); 259 } 260 261 void VisitObjCMethodDecl(const ObjCMethodDecl *D) { 262 // Counter tracks entry to the method body. 263 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 264 CountMap[D->getBody()] = BodyCount; 265 Visit(D->getBody()); 266 } 267 268 void VisitBlockDecl(const BlockDecl *D) { 269 // Counter tracks entry to the block body. 270 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 271 CountMap[D->getBody()] = BodyCount; 272 Visit(D->getBody()); 273 } 274 275 void VisitReturnStmt(const ReturnStmt *S) { 276 RecordStmtCount(S); 277 if (S->getRetValue()) 278 Visit(S->getRetValue()); 279 CurrentCount = 0; 280 RecordNextStmtCount = true; 281 } 282 283 void VisitCXXThrowExpr(const CXXThrowExpr *E) { 284 RecordStmtCount(E); 285 if (E->getSubExpr()) 286 Visit(E->getSubExpr()); 287 CurrentCount = 0; 288 RecordNextStmtCount = true; 289 } 290 291 void VisitGotoStmt(const GotoStmt *S) { 292 RecordStmtCount(S); 293 CurrentCount = 0; 294 RecordNextStmtCount = true; 295 } 296 297 void VisitLabelStmt(const LabelStmt *S) { 298 RecordNextStmtCount = false; 299 // Counter tracks the block following the label. 300 uint64_t BlockCount = setCount(PGO.getRegionCount(S)); 301 CountMap[S] = BlockCount; 302 Visit(S->getSubStmt()); 303 } 304 305 void VisitBreakStmt(const BreakStmt *S) { 306 RecordStmtCount(S); 307 assert(!BreakContinueStack.empty() && "break not in a loop or switch!"); 308 BreakContinueStack.back().BreakCount += CurrentCount; 309 CurrentCount = 0; 310 RecordNextStmtCount = true; 311 } 312 313 void VisitContinueStmt(const ContinueStmt *S) { 314 RecordStmtCount(S); 315 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!"); 316 BreakContinueStack.back().ContinueCount += CurrentCount; 317 CurrentCount = 0; 318 RecordNextStmtCount = true; 319 } 320 321 void VisitWhileStmt(const WhileStmt *S) { 322 RecordStmtCount(S); 323 uint64_t ParentCount = CurrentCount; 324 325 BreakContinueStack.push_back(BreakContinue()); 326 // Visit the body region first so the break/continue adjustments can be 327 // included when visiting the condition. 328 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 329 CountMap[S->getBody()] = CurrentCount; 330 Visit(S->getBody()); 331 uint64_t BackedgeCount = CurrentCount; 332 333 // ...then go back and propagate counts through the condition. The count 334 // at the start of the condition is the sum of the incoming edges, 335 // the backedge from the end of the loop body, and the edges from 336 // continue statements. 337 BreakContinue BC = BreakContinueStack.pop_back_val(); 338 uint64_t CondCount = 339 setCount(ParentCount + BackedgeCount + BC.ContinueCount); 340 CountMap[S->getCond()] = CondCount; 341 Visit(S->getCond()); 342 setCount(BC.BreakCount + CondCount - BodyCount); 343 RecordNextStmtCount = true; 344 } 345 346 void VisitDoStmt(const DoStmt *S) { 347 RecordStmtCount(S); 348 uint64_t LoopCount = PGO.getRegionCount(S); 349 350 BreakContinueStack.push_back(BreakContinue()); 351 // The count doesn't include the fallthrough from the parent scope. Add it. 352 uint64_t BodyCount = setCount(LoopCount + CurrentCount); 353 CountMap[S->getBody()] = BodyCount; 354 Visit(S->getBody()); 355 uint64_t BackedgeCount = CurrentCount; 356 357 BreakContinue BC = BreakContinueStack.pop_back_val(); 358 // The count at the start of the condition is equal to the count at the 359 // end of the body, plus any continues. 360 uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount); 361 CountMap[S->getCond()] = CondCount; 362 Visit(S->getCond()); 363 setCount(BC.BreakCount + CondCount - LoopCount); 364 RecordNextStmtCount = true; 365 } 366 367 void VisitForStmt(const ForStmt *S) { 368 RecordStmtCount(S); 369 if (S->getInit()) 370 Visit(S->getInit()); 371 372 uint64_t ParentCount = CurrentCount; 373 374 BreakContinueStack.push_back(BreakContinue()); 375 // Visit the body region first. (This is basically the same as a while 376 // loop; see further comments in VisitWhileStmt.) 377 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 378 CountMap[S->getBody()] = BodyCount; 379 Visit(S->getBody()); 380 uint64_t BackedgeCount = CurrentCount; 381 BreakContinue BC = BreakContinueStack.pop_back_val(); 382 383 // The increment is essentially part of the body but it needs to include 384 // the count for all the continue statements. 385 if (S->getInc()) { 386 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount); 387 CountMap[S->getInc()] = IncCount; 388 Visit(S->getInc()); 389 } 390 391 // ...then go back and propagate counts through the condition. 392 uint64_t CondCount = 393 setCount(ParentCount + BackedgeCount + BC.ContinueCount); 394 if (S->getCond()) { 395 CountMap[S->getCond()] = CondCount; 396 Visit(S->getCond()); 397 } 398 setCount(BC.BreakCount + CondCount - BodyCount); 399 RecordNextStmtCount = true; 400 } 401 402 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) { 403 RecordStmtCount(S); 404 Visit(S->getLoopVarStmt()); 405 Visit(S->getRangeStmt()); 406 Visit(S->getBeginEndStmt()); 407 408 uint64_t ParentCount = CurrentCount; 409 BreakContinueStack.push_back(BreakContinue()); 410 // Visit the body region first. (This is basically the same as a while 411 // loop; see further comments in VisitWhileStmt.) 412 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 413 CountMap[S->getBody()] = BodyCount; 414 Visit(S->getBody()); 415 uint64_t BackedgeCount = CurrentCount; 416 BreakContinue BC = BreakContinueStack.pop_back_val(); 417 418 // The increment is essentially part of the body but it needs to include 419 // the count for all the continue statements. 420 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount); 421 CountMap[S->getInc()] = IncCount; 422 Visit(S->getInc()); 423 424 // ...then go back and propagate counts through the condition. 425 uint64_t CondCount = 426 setCount(ParentCount + BackedgeCount + BC.ContinueCount); 427 CountMap[S->getCond()] = CondCount; 428 Visit(S->getCond()); 429 setCount(BC.BreakCount + CondCount - BodyCount); 430 RecordNextStmtCount = true; 431 } 432 433 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { 434 RecordStmtCount(S); 435 Visit(S->getElement()); 436 uint64_t ParentCount = CurrentCount; 437 BreakContinueStack.push_back(BreakContinue()); 438 // Counter tracks the body of the loop. 439 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 440 CountMap[S->getBody()] = BodyCount; 441 Visit(S->getBody()); 442 uint64_t BackedgeCount = CurrentCount; 443 BreakContinue BC = BreakContinueStack.pop_back_val(); 444 445 setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount - 446 BodyCount); 447 RecordNextStmtCount = true; 448 } 449 450 void VisitSwitchStmt(const SwitchStmt *S) { 451 RecordStmtCount(S); 452 Visit(S->getCond()); 453 CurrentCount = 0; 454 BreakContinueStack.push_back(BreakContinue()); 455 Visit(S->getBody()); 456 // If the switch is inside a loop, add the continue counts. 457 BreakContinue BC = BreakContinueStack.pop_back_val(); 458 if (!BreakContinueStack.empty()) 459 BreakContinueStack.back().ContinueCount += BC.ContinueCount; 460 // Counter tracks the exit block of the switch. 461 setCount(PGO.getRegionCount(S)); 462 RecordNextStmtCount = true; 463 } 464 465 void VisitSwitchCase(const SwitchCase *S) { 466 RecordNextStmtCount = false; 467 // Counter for this particular case. This counts only jumps from the 468 // switch header and does not include fallthrough from the case before 469 // this one. 470 uint64_t CaseCount = PGO.getRegionCount(S); 471 setCount(CurrentCount + CaseCount); 472 // We need the count without fallthrough in the mapping, so it's more useful 473 // for branch probabilities. 474 CountMap[S] = CaseCount; 475 RecordNextStmtCount = true; 476 Visit(S->getSubStmt()); 477 } 478 479 void VisitIfStmt(const IfStmt *S) { 480 RecordStmtCount(S); 481 uint64_t ParentCount = CurrentCount; 482 Visit(S->getCond()); 483 484 // Counter tracks the "then" part of an if statement. The count for 485 // the "else" part, if it exists, will be calculated from this counter. 486 uint64_t ThenCount = setCount(PGO.getRegionCount(S)); 487 CountMap[S->getThen()] = ThenCount; 488 Visit(S->getThen()); 489 uint64_t OutCount = CurrentCount; 490 491 uint64_t ElseCount = ParentCount - ThenCount; 492 if (S->getElse()) { 493 setCount(ElseCount); 494 CountMap[S->getElse()] = ElseCount; 495 Visit(S->getElse()); 496 OutCount += CurrentCount; 497 } else 498 OutCount += ElseCount; 499 setCount(OutCount); 500 RecordNextStmtCount = true; 501 } 502 503 void VisitCXXTryStmt(const CXXTryStmt *S) { 504 RecordStmtCount(S); 505 Visit(S->getTryBlock()); 506 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I) 507 Visit(S->getHandler(I)); 508 // Counter tracks the continuation block of the try statement. 509 setCount(PGO.getRegionCount(S)); 510 RecordNextStmtCount = true; 511 } 512 513 void VisitCXXCatchStmt(const CXXCatchStmt *S) { 514 RecordNextStmtCount = false; 515 // Counter tracks the catch statement's handler block. 516 uint64_t CatchCount = setCount(PGO.getRegionCount(S)); 517 CountMap[S] = CatchCount; 518 Visit(S->getHandlerBlock()); 519 } 520 521 void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) { 522 RecordStmtCount(E); 523 uint64_t ParentCount = CurrentCount; 524 Visit(E->getCond()); 525 526 // Counter tracks the "true" part of a conditional operator. The 527 // count in the "false" part will be calculated from this counter. 528 uint64_t TrueCount = setCount(PGO.getRegionCount(E)); 529 CountMap[E->getTrueExpr()] = TrueCount; 530 Visit(E->getTrueExpr()); 531 uint64_t OutCount = CurrentCount; 532 533 uint64_t FalseCount = setCount(ParentCount - TrueCount); 534 CountMap[E->getFalseExpr()] = FalseCount; 535 Visit(E->getFalseExpr()); 536 OutCount += CurrentCount; 537 538 setCount(OutCount); 539 RecordNextStmtCount = true; 540 } 541 542 void VisitBinLAnd(const BinaryOperator *E) { 543 RecordStmtCount(E); 544 uint64_t ParentCount = CurrentCount; 545 Visit(E->getLHS()); 546 // Counter tracks the right hand side of a logical and operator. 547 uint64_t RHSCount = setCount(PGO.getRegionCount(E)); 548 CountMap[E->getRHS()] = RHSCount; 549 Visit(E->getRHS()); 550 setCount(ParentCount + RHSCount - CurrentCount); 551 RecordNextStmtCount = true; 552 } 553 554 void VisitBinLOr(const BinaryOperator *E) { 555 RecordStmtCount(E); 556 uint64_t ParentCount = CurrentCount; 557 Visit(E->getLHS()); 558 // Counter tracks the right hand side of a logical or operator. 559 uint64_t RHSCount = setCount(PGO.getRegionCount(E)); 560 CountMap[E->getRHS()] = RHSCount; 561 Visit(E->getRHS()); 562 setCount(ParentCount + RHSCount - CurrentCount); 563 RecordNextStmtCount = true; 564 } 565 }; 566 } // end anonymous namespace 567 568 void PGOHash::combine(HashType Type) { 569 // Check that we never combine 0 and only have six bits. 570 assert(Type && "Hash is invalid: unexpected type 0"); 571 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types"); 572 573 // Pass through MD5 if enough work has built up. 574 if (Count && Count % NumTypesPerWord == 0) { 575 using namespace llvm::support; 576 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working); 577 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped))); 578 Working = 0; 579 } 580 581 // Accumulate the current type. 582 ++Count; 583 Working = Working << NumBitsPerType | Type; 584 } 585 586 uint64_t PGOHash::finalize() { 587 // Use Working as the hash directly if we never used MD5. 588 if (Count <= NumTypesPerWord) 589 // No need to byte swap here, since none of the math was endian-dependent. 590 // This number will be byte-swapped as required on endianness transitions, 591 // so we will see the same value on the other side. 592 return Working; 593 594 // Check for remaining work in Working. 595 if (Working) 596 MD5.update(Working); 597 598 // Finalize the MD5 and return the hash. 599 llvm::MD5::MD5Result Result; 600 MD5.final(Result); 601 using namespace llvm::support; 602 return endian::read<uint64_t, little, unaligned>(Result); 603 } 604 605 void CodeGenPGO::checkGlobalDecl(GlobalDecl GD) { 606 // Make sure we only emit coverage mapping for one constructor/destructor. 607 // Clang emits several functions for the constructor and the destructor of 608 // a class. Every function is instrumented, but we only want to provide 609 // coverage for one of them. Because of that we only emit the coverage mapping 610 // for the base constructor/destructor. 611 if ((isa<CXXConstructorDecl>(GD.getDecl()) && 612 GD.getCtorType() != Ctor_Base) || 613 (isa<CXXDestructorDecl>(GD.getDecl()) && 614 GD.getDtorType() != Dtor_Base)) { 615 SkipCoverageMapping = true; 616 } 617 } 618 619 void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) { 620 bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate; 621 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 622 if (!InstrumentRegions && !PGOReader) 623 return; 624 if (D->isImplicit()) 625 return; 626 CGM.ClearUnusedCoverageMapping(D); 627 setFuncName(Fn); 628 629 mapRegionCounters(D); 630 if (CGM.getCodeGenOpts().CoverageMapping) 631 emitCounterRegionMapping(D); 632 if (PGOReader) { 633 SourceManager &SM = CGM.getContext().getSourceManager(); 634 loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation())); 635 computeRegionCounts(D); 636 applyFunctionAttributes(PGOReader, Fn); 637 } 638 } 639 640 void CodeGenPGO::mapRegionCounters(const Decl *D) { 641 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>); 642 MapRegionCounters Walker(*RegionCounterMap); 643 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 644 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD)); 645 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D)) 646 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD)); 647 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D)) 648 Walker.TraverseDecl(const_cast<BlockDecl *>(BD)); 649 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D)) 650 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD)); 651 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl"); 652 NumRegionCounters = Walker.NextCounter; 653 FunctionHash = Walker.Hash.finalize(); 654 } 655 656 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) { 657 if (SkipCoverageMapping) 658 return; 659 // Don't map the functions inside the system headers 660 auto Loc = D->getBody()->getLocStart(); 661 if (CGM.getContext().getSourceManager().isInSystemHeader(Loc)) 662 return; 663 664 std::string CoverageMapping; 665 llvm::raw_string_ostream OS(CoverageMapping); 666 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(), 667 CGM.getContext().getSourceManager(), 668 CGM.getLangOpts(), RegionCounterMap.get()); 669 MappingGen.emitCounterMapping(D, OS); 670 OS.flush(); 671 672 if (CoverageMapping.empty()) 673 return; 674 675 CGM.getCoverageMapping()->addFunctionMappingRecord( 676 FuncNameVar, FuncName, FunctionHash, CoverageMapping); 677 } 678 679 void 680 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name, 681 llvm::GlobalValue::LinkageTypes Linkage) { 682 if (SkipCoverageMapping) 683 return; 684 // Don't map the functions inside the system headers 685 auto Loc = D->getBody()->getLocStart(); 686 if (CGM.getContext().getSourceManager().isInSystemHeader(Loc)) 687 return; 688 689 std::string CoverageMapping; 690 llvm::raw_string_ostream OS(CoverageMapping); 691 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(), 692 CGM.getContext().getSourceManager(), 693 CGM.getLangOpts()); 694 MappingGen.emitEmptyMapping(D, OS); 695 OS.flush(); 696 697 if (CoverageMapping.empty()) 698 return; 699 700 setFuncName(Name, Linkage); 701 CGM.getCoverageMapping()->addFunctionMappingRecord( 702 FuncNameVar, FuncName, FunctionHash, CoverageMapping); 703 } 704 705 void CodeGenPGO::computeRegionCounts(const Decl *D) { 706 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>); 707 ComputeRegionCounts Walker(*StmtCountMap, *this); 708 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 709 Walker.VisitFunctionDecl(FD); 710 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D)) 711 Walker.VisitObjCMethodDecl(MD); 712 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D)) 713 Walker.VisitBlockDecl(BD); 714 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D)) 715 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD)); 716 } 717 718 void 719 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader, 720 llvm::Function *Fn) { 721 if (!haveRegionCounts()) 722 return; 723 724 uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount(); 725 uint64_t FunctionCount = getRegionCount(nullptr); 726 if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount)) 727 // Turn on InlineHint attribute for hot functions. 728 // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal. 729 Fn->addFnAttr(llvm::Attribute::InlineHint); 730 else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount)) 731 // Turn on Cold attribute for cold functions. 732 // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal. 733 Fn->addFnAttr(llvm::Attribute::Cold); 734 735 Fn->setEntryCount(FunctionCount); 736 } 737 738 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S) { 739 if (!CGM.getCodeGenOpts().ProfileInstrGenerate || !RegionCounterMap) 740 return; 741 if (!Builder.GetInsertBlock()) 742 return; 743 744 unsigned Counter = (*RegionCounterMap)[S]; 745 auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext()); 746 Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment), 747 {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy), 748 Builder.getInt64(FunctionHash), 749 Builder.getInt32(NumRegionCounters), 750 Builder.getInt32(Counter)}); 751 } 752 753 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader, 754 bool IsInMainFile) { 755 CGM.getPGOStats().addVisited(IsInMainFile); 756 RegionCounts.clear(); 757 if (std::error_code EC = 758 PGOReader->getFunctionCounts(FuncName, FunctionHash, RegionCounts)) { 759 if (EC == llvm::instrprof_error::unknown_function) 760 CGM.getPGOStats().addMissing(IsInMainFile); 761 else if (EC == llvm::instrprof_error::hash_mismatch) 762 CGM.getPGOStats().addMismatched(IsInMainFile); 763 else if (EC == llvm::instrprof_error::malformed) 764 // TODO: Consider a more specific warning for this case. 765 CGM.getPGOStats().addMismatched(IsInMainFile); 766 RegionCounts.clear(); 767 } 768 } 769 770 /// \brief Calculate what to divide by to scale weights. 771 /// 772 /// Given the maximum weight, calculate a divisor that will scale all the 773 /// weights to strictly less than UINT32_MAX. 774 static uint64_t calculateWeightScale(uint64_t MaxWeight) { 775 return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1; 776 } 777 778 /// \brief Scale an individual branch weight (and add 1). 779 /// 780 /// Scale a 64-bit weight down to 32-bits using \c Scale. 781 /// 782 /// According to Laplace's Rule of Succession, it is better to compute the 783 /// weight based on the count plus 1, so universally add 1 to the value. 784 /// 785 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no 786 /// greater than \c Weight. 787 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) { 788 assert(Scale && "scale by 0?"); 789 uint64_t Scaled = Weight / Scale + 1; 790 assert(Scaled <= UINT32_MAX && "overflow 32-bits"); 791 return Scaled; 792 } 793 794 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount, 795 uint64_t FalseCount) { 796 // Check for empty weights. 797 if (!TrueCount && !FalseCount) 798 return nullptr; 799 800 // Calculate how to scale down to 32-bits. 801 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount)); 802 803 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 804 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale), 805 scaleBranchWeight(FalseCount, Scale)); 806 } 807 808 llvm::MDNode * 809 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) { 810 // We need at least two elements to create meaningful weights. 811 if (Weights.size() < 2) 812 return nullptr; 813 814 // Check for empty weights. 815 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end()); 816 if (MaxWeight == 0) 817 return nullptr; 818 819 // Calculate how to scale down to 32-bits. 820 uint64_t Scale = calculateWeightScale(MaxWeight); 821 822 SmallVector<uint32_t, 16> ScaledWeights; 823 ScaledWeights.reserve(Weights.size()); 824 for (uint64_t W : Weights) 825 ScaledWeights.push_back(scaleBranchWeight(W, Scale)); 826 827 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 828 return MDHelper.createBranchWeights(ScaledWeights); 829 } 830 831 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond, 832 uint64_t LoopCount) { 833 if (!PGO.haveRegionCounts()) 834 return nullptr; 835 Optional<uint64_t> CondCount = PGO.getStmtCount(Cond); 836 assert(CondCount.hasValue() && "missing expected loop condition count"); 837 if (*CondCount == 0) 838 return nullptr; 839 return createProfileWeights(LoopCount, 840 std::max(*CondCount, LoopCount) - LoopCount); 841 } 842