1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // Instrumentation-based profile-guided optimization 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "CodeGenPGO.h" 15 #include "CodeGenFunction.h" 16 #include "CoverageMappingGen.h" 17 #include "clang/AST/RecursiveASTVisitor.h" 18 #include "clang/AST/StmtVisitor.h" 19 #include "llvm/IR/Intrinsics.h" 20 #include "llvm/IR/MDBuilder.h" 21 #include "llvm/ProfileData/InstrProfReader.h" 22 #include "llvm/Support/Endian.h" 23 #include "llvm/Support/FileSystem.h" 24 #include "llvm/Support/MD5.h" 25 26 using namespace clang; 27 using namespace CodeGen; 28 29 void CodeGenPGO::setFuncName(StringRef Name, 30 llvm::GlobalValue::LinkageTypes Linkage) { 31 StringRef RawFuncName = Name; 32 33 // Function names may be prefixed with a binary '1' to indicate 34 // that the backend should not modify the symbols due to any platform 35 // naming convention. Do not include that '1' in the PGO profile name. 36 if (RawFuncName[0] == '\1') 37 RawFuncName = RawFuncName.substr(1); 38 39 FuncName = RawFuncName; 40 if (llvm::GlobalValue::isLocalLinkage(Linkage)) { 41 // For local symbols, prepend the main file name to distinguish them. 42 // Do not include the full path in the file name since there's no guarantee 43 // that it will stay the same, e.g., if the files are checked out from 44 // version control in different locations. 45 if (CGM.getCodeGenOpts().MainFileName.empty()) 46 FuncName = FuncName.insert(0, "<unknown>:"); 47 else 48 FuncName = FuncName.insert(0, CGM.getCodeGenOpts().MainFileName + ":"); 49 } 50 51 // If we're generating a profile, create a variable for the name. 52 if (CGM.getCodeGenOpts().ProfileInstrGenerate) 53 createFuncNameVar(Linkage); 54 } 55 56 void CodeGenPGO::setFuncName(llvm::Function *Fn) { 57 setFuncName(Fn->getName(), Fn->getLinkage()); 58 } 59 60 void CodeGenPGO::createFuncNameVar(llvm::GlobalValue::LinkageTypes Linkage) { 61 // We generally want to match the function's linkage, but available_externally 62 // and extern_weak both have the wrong semantics, and anything that doesn't 63 // need to link across compilation units doesn't need to be visible at all. 64 if (Linkage == llvm::GlobalValue::ExternalWeakLinkage) 65 Linkage = llvm::GlobalValue::LinkOnceAnyLinkage; 66 else if (Linkage == llvm::GlobalValue::AvailableExternallyLinkage) 67 Linkage = llvm::GlobalValue::LinkOnceODRLinkage; 68 else if (Linkage == llvm::GlobalValue::InternalLinkage || 69 Linkage == llvm::GlobalValue::ExternalLinkage) 70 Linkage = llvm::GlobalValue::PrivateLinkage; 71 72 auto *Value = 73 llvm::ConstantDataArray::getString(CGM.getLLVMContext(), FuncName, false); 74 FuncNameVar = 75 new llvm::GlobalVariable(CGM.getModule(), Value->getType(), true, Linkage, 76 Value, "__llvm_profile_name_" + FuncName); 77 78 // Hide the symbol so that we correctly get a copy for each executable. 79 if (!llvm::GlobalValue::isLocalLinkage(FuncNameVar->getLinkage())) 80 FuncNameVar->setVisibility(llvm::GlobalValue::HiddenVisibility); 81 } 82 83 namespace { 84 /// \brief Stable hasher for PGO region counters. 85 /// 86 /// PGOHash produces a stable hash of a given function's control flow. 87 /// 88 /// Changing the output of this hash will invalidate all previously generated 89 /// profiles -- i.e., don't do it. 90 /// 91 /// \note When this hash does eventually change (years?), we still need to 92 /// support old hashes. We'll need to pull in the version number from the 93 /// profile data format and use the matching hash function. 94 class PGOHash { 95 uint64_t Working; 96 unsigned Count; 97 llvm::MD5 MD5; 98 99 static const int NumBitsPerType = 6; 100 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType; 101 static const unsigned TooBig = 1u << NumBitsPerType; 102 103 public: 104 /// \brief Hash values for AST nodes. 105 /// 106 /// Distinct values for AST nodes that have region counters attached. 107 /// 108 /// These values must be stable. All new members must be added at the end, 109 /// and no members should be removed. Changing the enumeration value for an 110 /// AST node will affect the hash of every function that contains that node. 111 enum HashType : unsigned char { 112 None = 0, 113 LabelStmt = 1, 114 WhileStmt, 115 DoStmt, 116 ForStmt, 117 CXXForRangeStmt, 118 ObjCForCollectionStmt, 119 SwitchStmt, 120 CaseStmt, 121 DefaultStmt, 122 IfStmt, 123 CXXTryStmt, 124 CXXCatchStmt, 125 ConditionalOperator, 126 BinaryOperatorLAnd, 127 BinaryOperatorLOr, 128 BinaryConditionalOperator, 129 130 // Keep this last. It's for the static assert that follows. 131 LastHashType 132 }; 133 static_assert(LastHashType <= TooBig, "Too many types in HashType"); 134 135 // TODO: When this format changes, take in a version number here, and use the 136 // old hash calculation for file formats that used the old hash. 137 PGOHash() : Working(0), Count(0) {} 138 void combine(HashType Type); 139 uint64_t finalize(); 140 }; 141 const int PGOHash::NumBitsPerType; 142 const unsigned PGOHash::NumTypesPerWord; 143 const unsigned PGOHash::TooBig; 144 145 /// A RecursiveASTVisitor that fills a map of statements to PGO counters. 146 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> { 147 /// The next counter value to assign. 148 unsigned NextCounter; 149 /// The function hash. 150 PGOHash Hash; 151 /// The map of statements to counters. 152 llvm::DenseMap<const Stmt *, unsigned> &CounterMap; 153 154 MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap) 155 : NextCounter(0), CounterMap(CounterMap) {} 156 157 // Blocks and lambdas are handled as separate functions, so we need not 158 // traverse them in the parent context. 159 bool TraverseBlockExpr(BlockExpr *BE) { return true; } 160 bool TraverseLambdaBody(LambdaExpr *LE) { return true; } 161 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; } 162 163 bool VisitDecl(const Decl *D) { 164 switch (D->getKind()) { 165 default: 166 break; 167 case Decl::Function: 168 case Decl::CXXMethod: 169 case Decl::CXXConstructor: 170 case Decl::CXXDestructor: 171 case Decl::CXXConversion: 172 case Decl::ObjCMethod: 173 case Decl::Block: 174 case Decl::Captured: 175 CounterMap[D->getBody()] = NextCounter++; 176 break; 177 } 178 return true; 179 } 180 181 bool VisitStmt(const Stmt *S) { 182 auto Type = getHashType(S); 183 if (Type == PGOHash::None) 184 return true; 185 186 CounterMap[S] = NextCounter++; 187 Hash.combine(Type); 188 return true; 189 } 190 PGOHash::HashType getHashType(const Stmt *S) { 191 switch (S->getStmtClass()) { 192 default: 193 break; 194 case Stmt::LabelStmtClass: 195 return PGOHash::LabelStmt; 196 case Stmt::WhileStmtClass: 197 return PGOHash::WhileStmt; 198 case Stmt::DoStmtClass: 199 return PGOHash::DoStmt; 200 case Stmt::ForStmtClass: 201 return PGOHash::ForStmt; 202 case Stmt::CXXForRangeStmtClass: 203 return PGOHash::CXXForRangeStmt; 204 case Stmt::ObjCForCollectionStmtClass: 205 return PGOHash::ObjCForCollectionStmt; 206 case Stmt::SwitchStmtClass: 207 return PGOHash::SwitchStmt; 208 case Stmt::CaseStmtClass: 209 return PGOHash::CaseStmt; 210 case Stmt::DefaultStmtClass: 211 return PGOHash::DefaultStmt; 212 case Stmt::IfStmtClass: 213 return PGOHash::IfStmt; 214 case Stmt::CXXTryStmtClass: 215 return PGOHash::CXXTryStmt; 216 case Stmt::CXXCatchStmtClass: 217 return PGOHash::CXXCatchStmt; 218 case Stmt::ConditionalOperatorClass: 219 return PGOHash::ConditionalOperator; 220 case Stmt::BinaryConditionalOperatorClass: 221 return PGOHash::BinaryConditionalOperator; 222 case Stmt::BinaryOperatorClass: { 223 const BinaryOperator *BO = cast<BinaryOperator>(S); 224 if (BO->getOpcode() == BO_LAnd) 225 return PGOHash::BinaryOperatorLAnd; 226 if (BO->getOpcode() == BO_LOr) 227 return PGOHash::BinaryOperatorLOr; 228 break; 229 } 230 } 231 return PGOHash::None; 232 } 233 }; 234 235 /// A StmtVisitor that propagates the raw counts through the AST and 236 /// records the count at statements where the value may change. 237 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { 238 /// PGO state. 239 CodeGenPGO &PGO; 240 241 /// A flag that is set when the current count should be recorded on the 242 /// next statement, such as at the exit of a loop. 243 bool RecordNextStmtCount; 244 245 /// The count at the current location in the traversal. 246 uint64_t CurrentCount; 247 248 /// The map of statements to count values. 249 llvm::DenseMap<const Stmt *, uint64_t> &CountMap; 250 251 /// BreakContinueStack - Keep counts of breaks and continues inside loops. 252 struct BreakContinue { 253 uint64_t BreakCount; 254 uint64_t ContinueCount; 255 BreakContinue() : BreakCount(0), ContinueCount(0) {} 256 }; 257 SmallVector<BreakContinue, 8> BreakContinueStack; 258 259 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap, 260 CodeGenPGO &PGO) 261 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {} 262 263 void RecordStmtCount(const Stmt *S) { 264 if (RecordNextStmtCount) { 265 CountMap[S] = CurrentCount; 266 RecordNextStmtCount = false; 267 } 268 } 269 270 /// Set and return the current count. 271 uint64_t setCount(uint64_t Count) { 272 CurrentCount = Count; 273 return Count; 274 } 275 276 void VisitStmt(const Stmt *S) { 277 RecordStmtCount(S); 278 for (Stmt::const_child_range I = S->children(); I; ++I) { 279 if (*I) 280 this->Visit(*I); 281 } 282 } 283 284 void VisitFunctionDecl(const FunctionDecl *D) { 285 // Counter tracks entry to the function body. 286 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 287 CountMap[D->getBody()] = BodyCount; 288 Visit(D->getBody()); 289 } 290 291 // Skip lambda expressions. We visit these as FunctionDecls when we're 292 // generating them and aren't interested in the body when generating a 293 // parent context. 294 void VisitLambdaExpr(const LambdaExpr *LE) {} 295 296 void VisitCapturedDecl(const CapturedDecl *D) { 297 // Counter tracks entry to the capture body. 298 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 299 CountMap[D->getBody()] = BodyCount; 300 Visit(D->getBody()); 301 } 302 303 void VisitObjCMethodDecl(const ObjCMethodDecl *D) { 304 // Counter tracks entry to the method body. 305 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 306 CountMap[D->getBody()] = BodyCount; 307 Visit(D->getBody()); 308 } 309 310 void VisitBlockDecl(const BlockDecl *D) { 311 // Counter tracks entry to the block body. 312 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 313 CountMap[D->getBody()] = BodyCount; 314 Visit(D->getBody()); 315 } 316 317 void VisitReturnStmt(const ReturnStmt *S) { 318 RecordStmtCount(S); 319 if (S->getRetValue()) 320 Visit(S->getRetValue()); 321 CurrentCount = 0; 322 RecordNextStmtCount = true; 323 } 324 325 void VisitCXXThrowExpr(const CXXThrowExpr *E) { 326 RecordStmtCount(E); 327 if (E->getSubExpr()) 328 Visit(E->getSubExpr()); 329 CurrentCount = 0; 330 RecordNextStmtCount = true; 331 } 332 333 void VisitGotoStmt(const GotoStmt *S) { 334 RecordStmtCount(S); 335 CurrentCount = 0; 336 RecordNextStmtCount = true; 337 } 338 339 void VisitLabelStmt(const LabelStmt *S) { 340 RecordNextStmtCount = false; 341 // Counter tracks the block following the label. 342 uint64_t BlockCount = setCount(PGO.getRegionCount(S)); 343 CountMap[S] = BlockCount; 344 Visit(S->getSubStmt()); 345 } 346 347 void VisitBreakStmt(const BreakStmt *S) { 348 RecordStmtCount(S); 349 assert(!BreakContinueStack.empty() && "break not in a loop or switch!"); 350 BreakContinueStack.back().BreakCount += CurrentCount; 351 CurrentCount = 0; 352 RecordNextStmtCount = true; 353 } 354 355 void VisitContinueStmt(const ContinueStmt *S) { 356 RecordStmtCount(S); 357 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!"); 358 BreakContinueStack.back().ContinueCount += CurrentCount; 359 CurrentCount = 0; 360 RecordNextStmtCount = true; 361 } 362 363 void VisitWhileStmt(const WhileStmt *S) { 364 RecordStmtCount(S); 365 uint64_t ParentCount = CurrentCount; 366 367 BreakContinueStack.push_back(BreakContinue()); 368 // Visit the body region first so the break/continue adjustments can be 369 // included when visiting the condition. 370 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 371 CountMap[S->getBody()] = CurrentCount; 372 Visit(S->getBody()); 373 uint64_t BackedgeCount = CurrentCount; 374 375 // ...then go back and propagate counts through the condition. The count 376 // at the start of the condition is the sum of the incoming edges, 377 // the backedge from the end of the loop body, and the edges from 378 // continue statements. 379 BreakContinue BC = BreakContinueStack.pop_back_val(); 380 uint64_t CondCount = 381 setCount(ParentCount + BackedgeCount + BC.ContinueCount); 382 CountMap[S->getCond()] = CondCount; 383 Visit(S->getCond()); 384 setCount(BC.BreakCount + CondCount - BodyCount); 385 RecordNextStmtCount = true; 386 } 387 388 void VisitDoStmt(const DoStmt *S) { 389 RecordStmtCount(S); 390 uint64_t LoopCount = PGO.getRegionCount(S); 391 392 BreakContinueStack.push_back(BreakContinue()); 393 // The count doesn't include the fallthrough from the parent scope. Add it. 394 uint64_t BodyCount = setCount(LoopCount + CurrentCount); 395 CountMap[S->getBody()] = BodyCount; 396 Visit(S->getBody()); 397 uint64_t BackedgeCount = CurrentCount; 398 399 BreakContinue BC = BreakContinueStack.pop_back_val(); 400 // The count at the start of the condition is equal to the count at the 401 // end of the body, plus any continues. 402 uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount); 403 CountMap[S->getCond()] = CondCount; 404 Visit(S->getCond()); 405 setCount(BC.BreakCount + CondCount - LoopCount); 406 RecordNextStmtCount = true; 407 } 408 409 void VisitForStmt(const ForStmt *S) { 410 RecordStmtCount(S); 411 if (S->getInit()) 412 Visit(S->getInit()); 413 414 uint64_t ParentCount = CurrentCount; 415 416 BreakContinueStack.push_back(BreakContinue()); 417 // Visit the body region first. (This is basically the same as a while 418 // loop; see further comments in VisitWhileStmt.) 419 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 420 CountMap[S->getBody()] = BodyCount; 421 Visit(S->getBody()); 422 uint64_t BackedgeCount = CurrentCount; 423 BreakContinue BC = BreakContinueStack.pop_back_val(); 424 425 // The increment is essentially part of the body but it needs to include 426 // the count for all the continue statements. 427 if (S->getInc()) { 428 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount); 429 CountMap[S->getInc()] = IncCount; 430 Visit(S->getInc()); 431 } 432 433 // ...then go back and propagate counts through the condition. 434 uint64_t CondCount = 435 setCount(ParentCount + BackedgeCount + BC.ContinueCount); 436 if (S->getCond()) { 437 CountMap[S->getCond()] = CondCount; 438 Visit(S->getCond()); 439 } 440 setCount(BC.BreakCount + CondCount - BodyCount); 441 RecordNextStmtCount = true; 442 } 443 444 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) { 445 RecordStmtCount(S); 446 Visit(S->getLoopVarStmt()); 447 Visit(S->getRangeStmt()); 448 Visit(S->getBeginEndStmt()); 449 450 uint64_t ParentCount = CurrentCount; 451 BreakContinueStack.push_back(BreakContinue()); 452 // Visit the body region first. (This is basically the same as a while 453 // loop; see further comments in VisitWhileStmt.) 454 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 455 CountMap[S->getBody()] = BodyCount; 456 Visit(S->getBody()); 457 uint64_t BackedgeCount = CurrentCount; 458 BreakContinue BC = BreakContinueStack.pop_back_val(); 459 460 // The increment is essentially part of the body but it needs to include 461 // the count for all the continue statements. 462 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount); 463 CountMap[S->getInc()] = IncCount; 464 Visit(S->getInc()); 465 466 // ...then go back and propagate counts through the condition. 467 uint64_t CondCount = 468 setCount(ParentCount + BackedgeCount + BC.ContinueCount); 469 CountMap[S->getCond()] = CondCount; 470 Visit(S->getCond()); 471 setCount(BC.BreakCount + CondCount - BodyCount); 472 RecordNextStmtCount = true; 473 } 474 475 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { 476 RecordStmtCount(S); 477 Visit(S->getElement()); 478 uint64_t ParentCount = CurrentCount; 479 BreakContinueStack.push_back(BreakContinue()); 480 // Counter tracks the body of the loop. 481 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 482 CountMap[S->getBody()] = BodyCount; 483 Visit(S->getBody()); 484 uint64_t BackedgeCount = CurrentCount; 485 BreakContinue BC = BreakContinueStack.pop_back_val(); 486 487 setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount - 488 BodyCount); 489 RecordNextStmtCount = true; 490 } 491 492 void VisitSwitchStmt(const SwitchStmt *S) { 493 RecordStmtCount(S); 494 Visit(S->getCond()); 495 CurrentCount = 0; 496 BreakContinueStack.push_back(BreakContinue()); 497 Visit(S->getBody()); 498 // If the switch is inside a loop, add the continue counts. 499 BreakContinue BC = BreakContinueStack.pop_back_val(); 500 if (!BreakContinueStack.empty()) 501 BreakContinueStack.back().ContinueCount += BC.ContinueCount; 502 // Counter tracks the exit block of the switch. 503 setCount(PGO.getRegionCount(S)); 504 RecordNextStmtCount = true; 505 } 506 507 void VisitSwitchCase(const SwitchCase *S) { 508 RecordNextStmtCount = false; 509 // Counter for this particular case. This counts only jumps from the 510 // switch header and does not include fallthrough from the case before 511 // this one. 512 uint64_t CaseCount = PGO.getRegionCount(S); 513 setCount(CurrentCount + CaseCount); 514 // We need the count without fallthrough in the mapping, so it's more useful 515 // for branch probabilities. 516 CountMap[S] = CaseCount; 517 RecordNextStmtCount = true; 518 Visit(S->getSubStmt()); 519 } 520 521 void VisitIfStmt(const IfStmt *S) { 522 RecordStmtCount(S); 523 uint64_t ParentCount = CurrentCount; 524 Visit(S->getCond()); 525 526 // Counter tracks the "then" part of an if statement. The count for 527 // the "else" part, if it exists, will be calculated from this counter. 528 uint64_t ThenCount = setCount(PGO.getRegionCount(S)); 529 CountMap[S->getThen()] = ThenCount; 530 Visit(S->getThen()); 531 uint64_t OutCount = CurrentCount; 532 533 uint64_t ElseCount = ParentCount - ThenCount; 534 if (S->getElse()) { 535 setCount(ElseCount); 536 CountMap[S->getElse()] = ElseCount; 537 Visit(S->getElse()); 538 OutCount += CurrentCount; 539 } else 540 OutCount += ElseCount; 541 setCount(OutCount); 542 RecordNextStmtCount = true; 543 } 544 545 void VisitCXXTryStmt(const CXXTryStmt *S) { 546 RecordStmtCount(S); 547 Visit(S->getTryBlock()); 548 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I) 549 Visit(S->getHandler(I)); 550 // Counter tracks the continuation block of the try statement. 551 setCount(PGO.getRegionCount(S)); 552 RecordNextStmtCount = true; 553 } 554 555 void VisitCXXCatchStmt(const CXXCatchStmt *S) { 556 RecordNextStmtCount = false; 557 // Counter tracks the catch statement's handler block. 558 uint64_t CatchCount = setCount(PGO.getRegionCount(S)); 559 CountMap[S] = CatchCount; 560 Visit(S->getHandlerBlock()); 561 } 562 563 void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) { 564 RecordStmtCount(E); 565 uint64_t ParentCount = CurrentCount; 566 Visit(E->getCond()); 567 568 // Counter tracks the "true" part of a conditional operator. The 569 // count in the "false" part will be calculated from this counter. 570 uint64_t TrueCount = setCount(PGO.getRegionCount(E)); 571 CountMap[E->getTrueExpr()] = TrueCount; 572 Visit(E->getTrueExpr()); 573 uint64_t OutCount = CurrentCount; 574 575 uint64_t FalseCount = setCount(ParentCount - TrueCount); 576 CountMap[E->getFalseExpr()] = FalseCount; 577 Visit(E->getFalseExpr()); 578 OutCount += CurrentCount; 579 580 setCount(OutCount); 581 RecordNextStmtCount = true; 582 } 583 584 void VisitBinLAnd(const BinaryOperator *E) { 585 RecordStmtCount(E); 586 uint64_t ParentCount = CurrentCount; 587 Visit(E->getLHS()); 588 // Counter tracks the right hand side of a logical and operator. 589 uint64_t RHSCount = setCount(PGO.getRegionCount(E)); 590 CountMap[E->getRHS()] = RHSCount; 591 Visit(E->getRHS()); 592 setCount(ParentCount + RHSCount - CurrentCount); 593 RecordNextStmtCount = true; 594 } 595 596 void VisitBinLOr(const BinaryOperator *E) { 597 RecordStmtCount(E); 598 uint64_t ParentCount = CurrentCount; 599 Visit(E->getLHS()); 600 // Counter tracks the right hand side of a logical or operator. 601 uint64_t RHSCount = setCount(PGO.getRegionCount(E)); 602 CountMap[E->getRHS()] = RHSCount; 603 Visit(E->getRHS()); 604 setCount(ParentCount + RHSCount - CurrentCount); 605 RecordNextStmtCount = true; 606 } 607 }; 608 } 609 610 void PGOHash::combine(HashType Type) { 611 // Check that we never combine 0 and only have six bits. 612 assert(Type && "Hash is invalid: unexpected type 0"); 613 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types"); 614 615 // Pass through MD5 if enough work has built up. 616 if (Count && Count % NumTypesPerWord == 0) { 617 using namespace llvm::support; 618 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working); 619 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped))); 620 Working = 0; 621 } 622 623 // Accumulate the current type. 624 ++Count; 625 Working = Working << NumBitsPerType | Type; 626 } 627 628 uint64_t PGOHash::finalize() { 629 // Use Working as the hash directly if we never used MD5. 630 if (Count <= NumTypesPerWord) 631 // No need to byte swap here, since none of the math was endian-dependent. 632 // This number will be byte-swapped as required on endianness transitions, 633 // so we will see the same value on the other side. 634 return Working; 635 636 // Check for remaining work in Working. 637 if (Working) 638 MD5.update(Working); 639 640 // Finalize the MD5 and return the hash. 641 llvm::MD5::MD5Result Result; 642 MD5.final(Result); 643 using namespace llvm::support; 644 return endian::read<uint64_t, little, unaligned>(Result); 645 } 646 647 void CodeGenPGO::checkGlobalDecl(GlobalDecl GD) { 648 // Make sure we only emit coverage mapping for one constructor/destructor. 649 // Clang emits several functions for the constructor and the destructor of 650 // a class. Every function is instrumented, but we only want to provide 651 // coverage for one of them. Because of that we only emit the coverage mapping 652 // for the base constructor/destructor. 653 if ((isa<CXXConstructorDecl>(GD.getDecl()) && 654 GD.getCtorType() != Ctor_Base) || 655 (isa<CXXDestructorDecl>(GD.getDecl()) && 656 GD.getDtorType() != Dtor_Base)) { 657 SkipCoverageMapping = true; 658 } 659 } 660 661 void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) { 662 bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate; 663 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 664 if (!InstrumentRegions && !PGOReader) 665 return; 666 if (D->isImplicit()) 667 return; 668 CGM.ClearUnusedCoverageMapping(D); 669 setFuncName(Fn); 670 671 mapRegionCounters(D); 672 if (CGM.getCodeGenOpts().CoverageMapping) 673 emitCounterRegionMapping(D); 674 if (PGOReader) { 675 SourceManager &SM = CGM.getContext().getSourceManager(); 676 loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation())); 677 computeRegionCounts(D); 678 applyFunctionAttributes(PGOReader, Fn); 679 } 680 } 681 682 void CodeGenPGO::mapRegionCounters(const Decl *D) { 683 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>); 684 MapRegionCounters Walker(*RegionCounterMap); 685 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 686 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD)); 687 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D)) 688 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD)); 689 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D)) 690 Walker.TraverseDecl(const_cast<BlockDecl *>(BD)); 691 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D)) 692 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD)); 693 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl"); 694 NumRegionCounters = Walker.NextCounter; 695 FunctionHash = Walker.Hash.finalize(); 696 } 697 698 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) { 699 if (SkipCoverageMapping) 700 return; 701 // Don't map the functions inside the system headers 702 auto Loc = D->getBody()->getLocStart(); 703 if (CGM.getContext().getSourceManager().isInSystemHeader(Loc)) 704 return; 705 706 std::string CoverageMapping; 707 llvm::raw_string_ostream OS(CoverageMapping); 708 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(), 709 CGM.getContext().getSourceManager(), 710 CGM.getLangOpts(), RegionCounterMap.get()); 711 MappingGen.emitCounterMapping(D, OS); 712 OS.flush(); 713 714 if (CoverageMapping.empty()) 715 return; 716 717 CGM.getCoverageMapping()->addFunctionMappingRecord( 718 FuncNameVar, FuncName, FunctionHash, CoverageMapping); 719 } 720 721 void 722 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name, 723 llvm::GlobalValue::LinkageTypes Linkage) { 724 if (SkipCoverageMapping) 725 return; 726 // Don't map the functions inside the system headers 727 auto Loc = D->getBody()->getLocStart(); 728 if (CGM.getContext().getSourceManager().isInSystemHeader(Loc)) 729 return; 730 731 std::string CoverageMapping; 732 llvm::raw_string_ostream OS(CoverageMapping); 733 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(), 734 CGM.getContext().getSourceManager(), 735 CGM.getLangOpts()); 736 MappingGen.emitEmptyMapping(D, OS); 737 OS.flush(); 738 739 if (CoverageMapping.empty()) 740 return; 741 742 setFuncName(Name, Linkage); 743 CGM.getCoverageMapping()->addFunctionMappingRecord( 744 FuncNameVar, FuncName, FunctionHash, CoverageMapping); 745 } 746 747 void CodeGenPGO::computeRegionCounts(const Decl *D) { 748 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>); 749 ComputeRegionCounts Walker(*StmtCountMap, *this); 750 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 751 Walker.VisitFunctionDecl(FD); 752 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D)) 753 Walker.VisitObjCMethodDecl(MD); 754 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D)) 755 Walker.VisitBlockDecl(BD); 756 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D)) 757 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD)); 758 } 759 760 void 761 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader, 762 llvm::Function *Fn) { 763 if (!haveRegionCounts()) 764 return; 765 766 uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount(); 767 uint64_t FunctionCount = getRegionCount(0); 768 if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount)) 769 // Turn on InlineHint attribute for hot functions. 770 // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal. 771 Fn->addFnAttr(llvm::Attribute::InlineHint); 772 else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount)) 773 // Turn on Cold attribute for cold functions. 774 // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal. 775 Fn->addFnAttr(llvm::Attribute::Cold); 776 } 777 778 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S) { 779 if (!CGM.getCodeGenOpts().ProfileInstrGenerate || !RegionCounterMap) 780 return; 781 if (!Builder.GetInsertPoint()) 782 return; 783 784 unsigned Counter = (*RegionCounterMap)[S]; 785 auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext()); 786 Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment), 787 {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy), 788 Builder.getInt64(FunctionHash), 789 Builder.getInt32(NumRegionCounters), 790 Builder.getInt32(Counter)}); 791 } 792 793 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader, 794 bool IsInMainFile) { 795 CGM.getPGOStats().addVisited(IsInMainFile); 796 RegionCounts.clear(); 797 if (std::error_code EC = 798 PGOReader->getFunctionCounts(FuncName, FunctionHash, RegionCounts)) { 799 if (EC == llvm::instrprof_error::unknown_function) 800 CGM.getPGOStats().addMissing(IsInMainFile); 801 else if (EC == llvm::instrprof_error::hash_mismatch) 802 CGM.getPGOStats().addMismatched(IsInMainFile); 803 else if (EC == llvm::instrprof_error::malformed) 804 // TODO: Consider a more specific warning for this case. 805 CGM.getPGOStats().addMismatched(IsInMainFile); 806 RegionCounts.clear(); 807 } 808 } 809 810 /// \brief Calculate what to divide by to scale weights. 811 /// 812 /// Given the maximum weight, calculate a divisor that will scale all the 813 /// weights to strictly less than UINT32_MAX. 814 static uint64_t calculateWeightScale(uint64_t MaxWeight) { 815 return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1; 816 } 817 818 /// \brief Scale an individual branch weight (and add 1). 819 /// 820 /// Scale a 64-bit weight down to 32-bits using \c Scale. 821 /// 822 /// According to Laplace's Rule of Succession, it is better to compute the 823 /// weight based on the count plus 1, so universally add 1 to the value. 824 /// 825 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no 826 /// greater than \c Weight. 827 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) { 828 assert(Scale && "scale by 0?"); 829 uint64_t Scaled = Weight / Scale + 1; 830 assert(Scaled <= UINT32_MAX && "overflow 32-bits"); 831 return Scaled; 832 } 833 834 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount, 835 uint64_t FalseCount) { 836 // Check for empty weights. 837 if (!TrueCount && !FalseCount) 838 return nullptr; 839 840 // Calculate how to scale down to 32-bits. 841 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount)); 842 843 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 844 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale), 845 scaleBranchWeight(FalseCount, Scale)); 846 } 847 848 llvm::MDNode * 849 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) { 850 // We need at least two elements to create meaningful weights. 851 if (Weights.size() < 2) 852 return nullptr; 853 854 // Check for empty weights. 855 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end()); 856 if (MaxWeight == 0) 857 return nullptr; 858 859 // Calculate how to scale down to 32-bits. 860 uint64_t Scale = calculateWeightScale(MaxWeight); 861 862 SmallVector<uint32_t, 16> ScaledWeights; 863 ScaledWeights.reserve(Weights.size()); 864 for (uint64_t W : Weights) 865 ScaledWeights.push_back(scaleBranchWeight(W, Scale)); 866 867 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 868 return MDHelper.createBranchWeights(ScaledWeights); 869 } 870 871 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond, 872 uint64_t LoopCount) { 873 if (!PGO.haveRegionCounts()) 874 return nullptr; 875 Optional<uint64_t> CondCount = PGO.getStmtCount(Cond); 876 assert(CondCount.hasValue() && "missing expected loop condition count"); 877 if (*CondCount == 0) 878 return nullptr; 879 return createProfileWeights(LoopCount, 880 std::max(*CondCount, LoopCount) - LoopCount); 881 } 882