1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // Instrumentation-based profile-guided optimization 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "CodeGenPGO.h" 15 #include "CodeGenFunction.h" 16 #include "CoverageMappingGen.h" 17 #include "clang/AST/RecursiveASTVisitor.h" 18 #include "clang/AST/StmtVisitor.h" 19 #include "llvm/IR/MDBuilder.h" 20 #include "llvm/ProfileData/InstrProfReader.h" 21 #include "llvm/Support/Endian.h" 22 #include "llvm/Support/FileSystem.h" 23 #include "llvm/Support/MD5.h" 24 25 using namespace clang; 26 using namespace CodeGen; 27 28 void CodeGenPGO::setFuncName(StringRef Name, 29 llvm::GlobalValue::LinkageTypes Linkage) { 30 RawFuncName = Name; 31 32 // Function names may be prefixed with a binary '1' to indicate 33 // that the backend should not modify the symbols due to any platform 34 // naming convention. Do not include that '1' in the PGO profile name. 35 if (RawFuncName[0] == '\1') 36 RawFuncName = RawFuncName.substr(1); 37 38 if (!llvm::GlobalValue::isLocalLinkage(Linkage)) { 39 PrefixedFuncName.reset(new std::string(RawFuncName)); 40 return; 41 } 42 43 // For local symbols, prepend the main file name to distinguish them. 44 // Do not include the full path in the file name since there's no guarantee 45 // that it will stay the same, e.g., if the files are checked out from 46 // version control in different locations. 47 PrefixedFuncName.reset(new std::string(CGM.getCodeGenOpts().MainFileName)); 48 if (PrefixedFuncName->empty()) 49 PrefixedFuncName->assign("<unknown>"); 50 PrefixedFuncName->append(":"); 51 PrefixedFuncName->append(RawFuncName); 52 } 53 54 void CodeGenPGO::setFuncName(llvm::Function *Fn) { 55 setFuncName(Fn->getName(), Fn->getLinkage()); 56 } 57 58 void CodeGenPGO::setVarLinkage(llvm::GlobalValue::LinkageTypes Linkage) { 59 // Set the linkage for variables based on the function linkage. Usually, we 60 // want to match it, but available_externally and extern_weak both have the 61 // wrong semantics. 62 VarLinkage = Linkage; 63 switch (VarLinkage) { 64 case llvm::GlobalValue::ExternalWeakLinkage: 65 VarLinkage = llvm::GlobalValue::LinkOnceAnyLinkage; 66 break; 67 case llvm::GlobalValue::AvailableExternallyLinkage: 68 VarLinkage = llvm::GlobalValue::LinkOnceODRLinkage; 69 break; 70 default: 71 break; 72 } 73 } 74 75 static llvm::Function *getRegisterFunc(CodeGenModule &CGM) { 76 return CGM.getModule().getFunction("__llvm_profile_register_functions"); 77 } 78 79 static llvm::BasicBlock *getOrInsertRegisterBB(CodeGenModule &CGM) { 80 // Don't do this for Darwin. compiler-rt uses linker magic. 81 if (CGM.getTarget().getTriple().isOSDarwin()) 82 return nullptr; 83 84 // Only need to insert this once per module. 85 if (llvm::Function *RegisterF = getRegisterFunc(CGM)) 86 return &RegisterF->getEntryBlock(); 87 88 // Construct the function. 89 auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext()); 90 auto *RegisterFTy = llvm::FunctionType::get(VoidTy, false); 91 auto *RegisterF = llvm::Function::Create(RegisterFTy, 92 llvm::GlobalValue::InternalLinkage, 93 "__llvm_profile_register_functions", 94 &CGM.getModule()); 95 RegisterF->setUnnamedAddr(true); 96 if (CGM.getCodeGenOpts().DisableRedZone) 97 RegisterF->addFnAttr(llvm::Attribute::NoRedZone); 98 99 // Construct and return the entry block. 100 auto *BB = llvm::BasicBlock::Create(CGM.getLLVMContext(), "", RegisterF); 101 CGBuilderTy Builder(BB); 102 Builder.CreateRetVoid(); 103 return BB; 104 } 105 106 static llvm::Constant *getOrInsertRuntimeRegister(CodeGenModule &CGM) { 107 auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext()); 108 auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext()); 109 auto *RuntimeRegisterTy = llvm::FunctionType::get(VoidTy, VoidPtrTy, false); 110 return CGM.getModule().getOrInsertFunction("__llvm_profile_register_function", 111 RuntimeRegisterTy); 112 } 113 114 static bool isMachO(const CodeGenModule &CGM) { 115 return CGM.getTarget().getTriple().isOSBinFormatMachO(); 116 } 117 118 static StringRef getCountersSection(const CodeGenModule &CGM) { 119 return isMachO(CGM) ? "__DATA,__llvm_prf_cnts" : "__llvm_prf_cnts"; 120 } 121 122 static StringRef getNameSection(const CodeGenModule &CGM) { 123 return isMachO(CGM) ? "__DATA,__llvm_prf_names" : "__llvm_prf_names"; 124 } 125 126 static StringRef getDataSection(const CodeGenModule &CGM) { 127 return isMachO(CGM) ? "__DATA,__llvm_prf_data" : "__llvm_prf_data"; 128 } 129 130 llvm::GlobalVariable *CodeGenPGO::buildDataVar() { 131 // Create name variable. 132 llvm::LLVMContext &Ctx = CGM.getLLVMContext(); 133 auto *VarName = llvm::ConstantDataArray::getString(Ctx, getFuncName(), 134 false); 135 auto *Name = new llvm::GlobalVariable(CGM.getModule(), VarName->getType(), 136 true, VarLinkage, VarName, 137 getFuncVarName("name")); 138 Name->setSection(getNameSection(CGM)); 139 Name->setAlignment(1); 140 141 // Create data variable. 142 auto *Int32Ty = llvm::Type::getInt32Ty(Ctx); 143 auto *Int64Ty = llvm::Type::getInt64Ty(Ctx); 144 auto *Int8PtrTy = llvm::Type::getInt8PtrTy(Ctx); 145 auto *Int64PtrTy = llvm::Type::getInt64PtrTy(Ctx); 146 llvm::GlobalVariable *Data = nullptr; 147 if (RegionCounters) { 148 llvm::Type *DataTypes[] = { 149 Int32Ty, Int32Ty, Int64Ty, Int8PtrTy, Int64PtrTy 150 }; 151 auto *DataTy = llvm::StructType::get(Ctx, makeArrayRef(DataTypes)); 152 llvm::Constant *DataVals[] = { 153 llvm::ConstantInt::get(Int32Ty, getFuncName().size()), 154 llvm::ConstantInt::get(Int32Ty, NumRegionCounters), 155 llvm::ConstantInt::get(Int64Ty, FunctionHash), 156 llvm::ConstantExpr::getBitCast(Name, Int8PtrTy), 157 llvm::ConstantExpr::getBitCast(RegionCounters, Int64PtrTy) 158 }; 159 Data = 160 new llvm::GlobalVariable(CGM.getModule(), DataTy, true, VarLinkage, 161 llvm::ConstantStruct::get(DataTy, DataVals), 162 getFuncVarName("data")); 163 164 // All the data should be packed into an array in its own section. 165 Data->setSection(getDataSection(CGM)); 166 Data->setAlignment(8); 167 } 168 169 // Create coverage mapping data variable. 170 if (!CoverageMapping.empty()) 171 CGM.getCoverageMapping()->addFunctionMappingRecord(Name, getFuncName(), 172 FunctionHash, 173 CoverageMapping); 174 175 // Hide all these symbols so that we correctly get a copy for each 176 // executable. The profile format expects names and counters to be 177 // contiguous, so references into shared objects would be invalid. 178 if (!llvm::GlobalValue::isLocalLinkage(VarLinkage)) { 179 Name->setVisibility(llvm::GlobalValue::HiddenVisibility); 180 if (Data) { 181 Data->setVisibility(llvm::GlobalValue::HiddenVisibility); 182 RegionCounters->setVisibility(llvm::GlobalValue::HiddenVisibility); 183 } 184 } 185 186 // Make sure the data doesn't get deleted. 187 if (Data) CGM.addUsedGlobal(Data); 188 return Data; 189 } 190 191 void CodeGenPGO::emitInstrumentationData() { 192 if (!RegionCounters) 193 return; 194 195 // Build the data. 196 auto *Data = buildDataVar(); 197 198 // Register the data. 199 auto *RegisterBB = getOrInsertRegisterBB(CGM); 200 if (!RegisterBB) 201 return; 202 CGBuilderTy Builder(RegisterBB->getTerminator()); 203 auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext()); 204 Builder.CreateCall(getOrInsertRuntimeRegister(CGM), 205 Builder.CreateBitCast(Data, VoidPtrTy)); 206 } 207 208 llvm::Function *CodeGenPGO::emitInitialization(CodeGenModule &CGM) { 209 if (!CGM.getCodeGenOpts().ProfileInstrGenerate) 210 return nullptr; 211 212 assert(CGM.getModule().getFunction("__llvm_profile_init") == nullptr && 213 "profile initialization already emitted"); 214 215 // Get the function to call at initialization. 216 llvm::Constant *RegisterF = getRegisterFunc(CGM); 217 if (!RegisterF) 218 return nullptr; 219 220 // Create the initialization function. 221 auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext()); 222 auto *F = llvm::Function::Create(llvm::FunctionType::get(VoidTy, false), 223 llvm::GlobalValue::InternalLinkage, 224 "__llvm_profile_init", &CGM.getModule()); 225 F->setUnnamedAddr(true); 226 F->addFnAttr(llvm::Attribute::NoInline); 227 if (CGM.getCodeGenOpts().DisableRedZone) 228 F->addFnAttr(llvm::Attribute::NoRedZone); 229 230 // Add the basic block and the necessary calls. 231 CGBuilderTy Builder(llvm::BasicBlock::Create(CGM.getLLVMContext(), "", F)); 232 Builder.CreateCall(RegisterF); 233 Builder.CreateRetVoid(); 234 235 return F; 236 } 237 238 namespace { 239 /// \brief Stable hasher for PGO region counters. 240 /// 241 /// PGOHash produces a stable hash of a given function's control flow. 242 /// 243 /// Changing the output of this hash will invalidate all previously generated 244 /// profiles -- i.e., don't do it. 245 /// 246 /// \note When this hash does eventually change (years?), we still need to 247 /// support old hashes. We'll need to pull in the version number from the 248 /// profile data format and use the matching hash function. 249 class PGOHash { 250 uint64_t Working; 251 unsigned Count; 252 llvm::MD5 MD5; 253 254 static const int NumBitsPerType = 6; 255 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType; 256 static const unsigned TooBig = 1u << NumBitsPerType; 257 258 public: 259 /// \brief Hash values for AST nodes. 260 /// 261 /// Distinct values for AST nodes that have region counters attached. 262 /// 263 /// These values must be stable. All new members must be added at the end, 264 /// and no members should be removed. Changing the enumeration value for an 265 /// AST node will affect the hash of every function that contains that node. 266 enum HashType : unsigned char { 267 None = 0, 268 LabelStmt = 1, 269 WhileStmt, 270 DoStmt, 271 ForStmt, 272 CXXForRangeStmt, 273 ObjCForCollectionStmt, 274 SwitchStmt, 275 CaseStmt, 276 DefaultStmt, 277 IfStmt, 278 CXXTryStmt, 279 CXXCatchStmt, 280 ConditionalOperator, 281 BinaryOperatorLAnd, 282 BinaryOperatorLOr, 283 BinaryConditionalOperator, 284 285 // Keep this last. It's for the static assert that follows. 286 LastHashType 287 }; 288 static_assert(LastHashType <= TooBig, "Too many types in HashType"); 289 290 // TODO: When this format changes, take in a version number here, and use the 291 // old hash calculation for file formats that used the old hash. 292 PGOHash() : Working(0), Count(0) {} 293 void combine(HashType Type); 294 uint64_t finalize(); 295 }; 296 const int PGOHash::NumBitsPerType; 297 const unsigned PGOHash::NumTypesPerWord; 298 const unsigned PGOHash::TooBig; 299 300 /// A RecursiveASTVisitor that fills a map of statements to PGO counters. 301 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> { 302 /// The next counter value to assign. 303 unsigned NextCounter; 304 /// The function hash. 305 PGOHash Hash; 306 /// The map of statements to counters. 307 llvm::DenseMap<const Stmt *, unsigned> &CounterMap; 308 309 MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap) 310 : NextCounter(0), CounterMap(CounterMap) {} 311 312 // Blocks and lambdas are handled as separate functions, so we need not 313 // traverse them in the parent context. 314 bool TraverseBlockExpr(BlockExpr *BE) { return true; } 315 bool TraverseLambdaBody(LambdaExpr *LE) { return true; } 316 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; } 317 318 bool VisitDecl(const Decl *D) { 319 switch (D->getKind()) { 320 default: 321 break; 322 case Decl::Function: 323 case Decl::CXXMethod: 324 case Decl::CXXConstructor: 325 case Decl::CXXDestructor: 326 case Decl::CXXConversion: 327 case Decl::ObjCMethod: 328 case Decl::Block: 329 case Decl::Captured: 330 CounterMap[D->getBody()] = NextCounter++; 331 break; 332 } 333 return true; 334 } 335 336 bool VisitStmt(const Stmt *S) { 337 auto Type = getHashType(S); 338 if (Type == PGOHash::None) 339 return true; 340 341 CounterMap[S] = NextCounter++; 342 Hash.combine(Type); 343 return true; 344 } 345 PGOHash::HashType getHashType(const Stmt *S) { 346 switch (S->getStmtClass()) { 347 default: 348 break; 349 case Stmt::LabelStmtClass: 350 return PGOHash::LabelStmt; 351 case Stmt::WhileStmtClass: 352 return PGOHash::WhileStmt; 353 case Stmt::DoStmtClass: 354 return PGOHash::DoStmt; 355 case Stmt::ForStmtClass: 356 return PGOHash::ForStmt; 357 case Stmt::CXXForRangeStmtClass: 358 return PGOHash::CXXForRangeStmt; 359 case Stmt::ObjCForCollectionStmtClass: 360 return PGOHash::ObjCForCollectionStmt; 361 case Stmt::SwitchStmtClass: 362 return PGOHash::SwitchStmt; 363 case Stmt::CaseStmtClass: 364 return PGOHash::CaseStmt; 365 case Stmt::DefaultStmtClass: 366 return PGOHash::DefaultStmt; 367 case Stmt::IfStmtClass: 368 return PGOHash::IfStmt; 369 case Stmt::CXXTryStmtClass: 370 return PGOHash::CXXTryStmt; 371 case Stmt::CXXCatchStmtClass: 372 return PGOHash::CXXCatchStmt; 373 case Stmt::ConditionalOperatorClass: 374 return PGOHash::ConditionalOperator; 375 case Stmt::BinaryConditionalOperatorClass: 376 return PGOHash::BinaryConditionalOperator; 377 case Stmt::BinaryOperatorClass: { 378 const BinaryOperator *BO = cast<BinaryOperator>(S); 379 if (BO->getOpcode() == BO_LAnd) 380 return PGOHash::BinaryOperatorLAnd; 381 if (BO->getOpcode() == BO_LOr) 382 return PGOHash::BinaryOperatorLOr; 383 break; 384 } 385 } 386 return PGOHash::None; 387 } 388 }; 389 390 /// A StmtVisitor that propagates the raw counts through the AST and 391 /// records the count at statements where the value may change. 392 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { 393 /// PGO state. 394 CodeGenPGO &PGO; 395 396 /// A flag that is set when the current count should be recorded on the 397 /// next statement, such as at the exit of a loop. 398 bool RecordNextStmtCount; 399 400 /// The map of statements to count values. 401 llvm::DenseMap<const Stmt *, uint64_t> &CountMap; 402 403 /// BreakContinueStack - Keep counts of breaks and continues inside loops. 404 struct BreakContinue { 405 uint64_t BreakCount; 406 uint64_t ContinueCount; 407 BreakContinue() : BreakCount(0), ContinueCount(0) {} 408 }; 409 SmallVector<BreakContinue, 8> BreakContinueStack; 410 411 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap, 412 CodeGenPGO &PGO) 413 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {} 414 415 void RecordStmtCount(const Stmt *S) { 416 if (RecordNextStmtCount) { 417 CountMap[S] = PGO.getCurrentRegionCount(); 418 RecordNextStmtCount = false; 419 } 420 } 421 422 void VisitStmt(const Stmt *S) { 423 RecordStmtCount(S); 424 for (Stmt::const_child_range I = S->children(); I; ++I) { 425 if (*I) 426 this->Visit(*I); 427 } 428 } 429 430 void VisitFunctionDecl(const FunctionDecl *D) { 431 // Counter tracks entry to the function body. 432 RegionCounter Cnt(PGO, D->getBody()); 433 Cnt.beginRegion(); 434 CountMap[D->getBody()] = PGO.getCurrentRegionCount(); 435 Visit(D->getBody()); 436 } 437 438 // Skip lambda expressions. We visit these as FunctionDecls when we're 439 // generating them and aren't interested in the body when generating a 440 // parent context. 441 void VisitLambdaExpr(const LambdaExpr *LE) {} 442 443 void VisitCapturedDecl(const CapturedDecl *D) { 444 // Counter tracks entry to the capture body. 445 RegionCounter Cnt(PGO, D->getBody()); 446 Cnt.beginRegion(); 447 CountMap[D->getBody()] = PGO.getCurrentRegionCount(); 448 Visit(D->getBody()); 449 } 450 451 void VisitObjCMethodDecl(const ObjCMethodDecl *D) { 452 // Counter tracks entry to the method body. 453 RegionCounter Cnt(PGO, D->getBody()); 454 Cnt.beginRegion(); 455 CountMap[D->getBody()] = PGO.getCurrentRegionCount(); 456 Visit(D->getBody()); 457 } 458 459 void VisitBlockDecl(const BlockDecl *D) { 460 // Counter tracks entry to the block body. 461 RegionCounter Cnt(PGO, D->getBody()); 462 Cnt.beginRegion(); 463 CountMap[D->getBody()] = PGO.getCurrentRegionCount(); 464 Visit(D->getBody()); 465 } 466 467 void VisitReturnStmt(const ReturnStmt *S) { 468 RecordStmtCount(S); 469 if (S->getRetValue()) 470 Visit(S->getRetValue()); 471 PGO.setCurrentRegionUnreachable(); 472 RecordNextStmtCount = true; 473 } 474 475 void VisitGotoStmt(const GotoStmt *S) { 476 RecordStmtCount(S); 477 PGO.setCurrentRegionUnreachable(); 478 RecordNextStmtCount = true; 479 } 480 481 void VisitLabelStmt(const LabelStmt *S) { 482 RecordNextStmtCount = false; 483 // Counter tracks the block following the label. 484 RegionCounter Cnt(PGO, S); 485 Cnt.beginRegion(); 486 CountMap[S] = PGO.getCurrentRegionCount(); 487 Visit(S->getSubStmt()); 488 } 489 490 void VisitBreakStmt(const BreakStmt *S) { 491 RecordStmtCount(S); 492 assert(!BreakContinueStack.empty() && "break not in a loop or switch!"); 493 BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount(); 494 PGO.setCurrentRegionUnreachable(); 495 RecordNextStmtCount = true; 496 } 497 498 void VisitContinueStmt(const ContinueStmt *S) { 499 RecordStmtCount(S); 500 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!"); 501 BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount(); 502 PGO.setCurrentRegionUnreachable(); 503 RecordNextStmtCount = true; 504 } 505 506 void VisitWhileStmt(const WhileStmt *S) { 507 RecordStmtCount(S); 508 // Counter tracks the body of the loop. 509 RegionCounter Cnt(PGO, S); 510 BreakContinueStack.push_back(BreakContinue()); 511 // Visit the body region first so the break/continue adjustments can be 512 // included when visiting the condition. 513 Cnt.beginRegion(); 514 CountMap[S->getBody()] = PGO.getCurrentRegionCount(); 515 Visit(S->getBody()); 516 Cnt.adjustForControlFlow(); 517 518 // ...then go back and propagate counts through the condition. The count 519 // at the start of the condition is the sum of the incoming edges, 520 // the backedge from the end of the loop body, and the edges from 521 // continue statements. 522 BreakContinue BC = BreakContinueStack.pop_back_val(); 523 Cnt.setCurrentRegionCount(Cnt.getParentCount() + 524 Cnt.getAdjustedCount() + BC.ContinueCount); 525 CountMap[S->getCond()] = PGO.getCurrentRegionCount(); 526 Visit(S->getCond()); 527 Cnt.adjustForControlFlow(); 528 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 529 RecordNextStmtCount = true; 530 } 531 532 void VisitDoStmt(const DoStmt *S) { 533 RecordStmtCount(S); 534 // Counter tracks the body of the loop. 535 RegionCounter Cnt(PGO, S); 536 BreakContinueStack.push_back(BreakContinue()); 537 Cnt.beginRegion(/*AddIncomingFallThrough=*/true); 538 CountMap[S->getBody()] = PGO.getCurrentRegionCount(); 539 Visit(S->getBody()); 540 Cnt.adjustForControlFlow(); 541 542 BreakContinue BC = BreakContinueStack.pop_back_val(); 543 // The count at the start of the condition is equal to the count at the 544 // end of the body. The adjusted count does not include either the 545 // fall-through count coming into the loop or the continue count, so add 546 // both of those separately. This is coincidentally the same equation as 547 // with while loops but for different reasons. 548 Cnt.setCurrentRegionCount(Cnt.getParentCount() + 549 Cnt.getAdjustedCount() + BC.ContinueCount); 550 CountMap[S->getCond()] = PGO.getCurrentRegionCount(); 551 Visit(S->getCond()); 552 Cnt.adjustForControlFlow(); 553 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 554 RecordNextStmtCount = true; 555 } 556 557 void VisitForStmt(const ForStmt *S) { 558 RecordStmtCount(S); 559 if (S->getInit()) 560 Visit(S->getInit()); 561 // Counter tracks the body of the loop. 562 RegionCounter Cnt(PGO, S); 563 BreakContinueStack.push_back(BreakContinue()); 564 // Visit the body region first. (This is basically the same as a while 565 // loop; see further comments in VisitWhileStmt.) 566 Cnt.beginRegion(); 567 CountMap[S->getBody()] = PGO.getCurrentRegionCount(); 568 Visit(S->getBody()); 569 Cnt.adjustForControlFlow(); 570 571 // The increment is essentially part of the body but it needs to include 572 // the count for all the continue statements. 573 if (S->getInc()) { 574 Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() + 575 BreakContinueStack.back().ContinueCount); 576 CountMap[S->getInc()] = PGO.getCurrentRegionCount(); 577 Visit(S->getInc()); 578 Cnt.adjustForControlFlow(); 579 } 580 581 BreakContinue BC = BreakContinueStack.pop_back_val(); 582 583 // ...then go back and propagate counts through the condition. 584 if (S->getCond()) { 585 Cnt.setCurrentRegionCount(Cnt.getParentCount() + 586 Cnt.getAdjustedCount() + 587 BC.ContinueCount); 588 CountMap[S->getCond()] = PGO.getCurrentRegionCount(); 589 Visit(S->getCond()); 590 Cnt.adjustForControlFlow(); 591 } 592 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 593 RecordNextStmtCount = true; 594 } 595 596 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) { 597 RecordStmtCount(S); 598 Visit(S->getRangeStmt()); 599 Visit(S->getBeginEndStmt()); 600 // Counter tracks the body of the loop. 601 RegionCounter Cnt(PGO, S); 602 BreakContinueStack.push_back(BreakContinue()); 603 // Visit the body region first. (This is basically the same as a while 604 // loop; see further comments in VisitWhileStmt.) 605 Cnt.beginRegion(); 606 CountMap[S->getLoopVarStmt()] = PGO.getCurrentRegionCount(); 607 Visit(S->getLoopVarStmt()); 608 Visit(S->getBody()); 609 Cnt.adjustForControlFlow(); 610 611 // The increment is essentially part of the body but it needs to include 612 // the count for all the continue statements. 613 Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() + 614 BreakContinueStack.back().ContinueCount); 615 CountMap[S->getInc()] = PGO.getCurrentRegionCount(); 616 Visit(S->getInc()); 617 Cnt.adjustForControlFlow(); 618 619 BreakContinue BC = BreakContinueStack.pop_back_val(); 620 621 // ...then go back and propagate counts through the condition. 622 Cnt.setCurrentRegionCount(Cnt.getParentCount() + 623 Cnt.getAdjustedCount() + 624 BC.ContinueCount); 625 CountMap[S->getCond()] = PGO.getCurrentRegionCount(); 626 Visit(S->getCond()); 627 Cnt.adjustForControlFlow(); 628 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 629 RecordNextStmtCount = true; 630 } 631 632 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { 633 RecordStmtCount(S); 634 Visit(S->getElement()); 635 // Counter tracks the body of the loop. 636 RegionCounter Cnt(PGO, S); 637 BreakContinueStack.push_back(BreakContinue()); 638 Cnt.beginRegion(); 639 CountMap[S->getBody()] = PGO.getCurrentRegionCount(); 640 Visit(S->getBody()); 641 BreakContinue BC = BreakContinueStack.pop_back_val(); 642 Cnt.adjustForControlFlow(); 643 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 644 RecordNextStmtCount = true; 645 } 646 647 void VisitSwitchStmt(const SwitchStmt *S) { 648 RecordStmtCount(S); 649 Visit(S->getCond()); 650 PGO.setCurrentRegionUnreachable(); 651 BreakContinueStack.push_back(BreakContinue()); 652 Visit(S->getBody()); 653 // If the switch is inside a loop, add the continue counts. 654 BreakContinue BC = BreakContinueStack.pop_back_val(); 655 if (!BreakContinueStack.empty()) 656 BreakContinueStack.back().ContinueCount += BC.ContinueCount; 657 // Counter tracks the exit block of the switch. 658 RegionCounter ExitCnt(PGO, S); 659 ExitCnt.beginRegion(); 660 RecordNextStmtCount = true; 661 } 662 663 void VisitCaseStmt(const CaseStmt *S) { 664 RecordNextStmtCount = false; 665 // Counter for this particular case. This counts only jumps from the 666 // switch header and does not include fallthrough from the case before 667 // this one. 668 RegionCounter Cnt(PGO, S); 669 Cnt.beginRegion(/*AddIncomingFallThrough=*/true); 670 CountMap[S] = Cnt.getCount(); 671 RecordNextStmtCount = true; 672 Visit(S->getSubStmt()); 673 } 674 675 void VisitDefaultStmt(const DefaultStmt *S) { 676 RecordNextStmtCount = false; 677 // Counter for this default case. This does not include fallthrough from 678 // the previous case. 679 RegionCounter Cnt(PGO, S); 680 Cnt.beginRegion(/*AddIncomingFallThrough=*/true); 681 CountMap[S] = Cnt.getCount(); 682 RecordNextStmtCount = true; 683 Visit(S->getSubStmt()); 684 } 685 686 void VisitIfStmt(const IfStmt *S) { 687 RecordStmtCount(S); 688 // Counter tracks the "then" part of an if statement. The count for 689 // the "else" part, if it exists, will be calculated from this counter. 690 RegionCounter Cnt(PGO, S); 691 Visit(S->getCond()); 692 693 Cnt.beginRegion(); 694 CountMap[S->getThen()] = PGO.getCurrentRegionCount(); 695 Visit(S->getThen()); 696 Cnt.adjustForControlFlow(); 697 698 if (S->getElse()) { 699 Cnt.beginElseRegion(); 700 CountMap[S->getElse()] = PGO.getCurrentRegionCount(); 701 Visit(S->getElse()); 702 Cnt.adjustForControlFlow(); 703 } 704 Cnt.applyAdjustmentsToRegion(0); 705 RecordNextStmtCount = true; 706 } 707 708 void VisitCXXTryStmt(const CXXTryStmt *S) { 709 RecordStmtCount(S); 710 Visit(S->getTryBlock()); 711 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I) 712 Visit(S->getHandler(I)); 713 // Counter tracks the continuation block of the try statement. 714 RegionCounter Cnt(PGO, S); 715 Cnt.beginRegion(); 716 RecordNextStmtCount = true; 717 } 718 719 void VisitCXXCatchStmt(const CXXCatchStmt *S) { 720 RecordNextStmtCount = false; 721 // Counter tracks the catch statement's handler block. 722 RegionCounter Cnt(PGO, S); 723 Cnt.beginRegion(); 724 CountMap[S] = PGO.getCurrentRegionCount(); 725 Visit(S->getHandlerBlock()); 726 } 727 728 void VisitAbstractConditionalOperator( 729 const AbstractConditionalOperator *E) { 730 RecordStmtCount(E); 731 // Counter tracks the "true" part of a conditional operator. The 732 // count in the "false" part will be calculated from this counter. 733 RegionCounter Cnt(PGO, E); 734 Visit(E->getCond()); 735 736 Cnt.beginRegion(); 737 CountMap[E->getTrueExpr()] = PGO.getCurrentRegionCount(); 738 Visit(E->getTrueExpr()); 739 Cnt.adjustForControlFlow(); 740 741 Cnt.beginElseRegion(); 742 CountMap[E->getFalseExpr()] = PGO.getCurrentRegionCount(); 743 Visit(E->getFalseExpr()); 744 Cnt.adjustForControlFlow(); 745 746 Cnt.applyAdjustmentsToRegion(0); 747 RecordNextStmtCount = true; 748 } 749 750 void VisitBinLAnd(const BinaryOperator *E) { 751 RecordStmtCount(E); 752 // Counter tracks the right hand side of a logical and operator. 753 RegionCounter Cnt(PGO, E); 754 Visit(E->getLHS()); 755 Cnt.beginRegion(); 756 CountMap[E->getRHS()] = PGO.getCurrentRegionCount(); 757 Visit(E->getRHS()); 758 Cnt.adjustForControlFlow(); 759 Cnt.applyAdjustmentsToRegion(0); 760 RecordNextStmtCount = true; 761 } 762 763 void VisitBinLOr(const BinaryOperator *E) { 764 RecordStmtCount(E); 765 // Counter tracks the right hand side of a logical or operator. 766 RegionCounter Cnt(PGO, E); 767 Visit(E->getLHS()); 768 Cnt.beginRegion(); 769 CountMap[E->getRHS()] = PGO.getCurrentRegionCount(); 770 Visit(E->getRHS()); 771 Cnt.adjustForControlFlow(); 772 Cnt.applyAdjustmentsToRegion(0); 773 RecordNextStmtCount = true; 774 } 775 }; 776 } 777 778 void PGOHash::combine(HashType Type) { 779 // Check that we never combine 0 and only have six bits. 780 assert(Type && "Hash is invalid: unexpected type 0"); 781 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types"); 782 783 // Pass through MD5 if enough work has built up. 784 if (Count && Count % NumTypesPerWord == 0) { 785 using namespace llvm::support; 786 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working); 787 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped))); 788 Working = 0; 789 } 790 791 // Accumulate the current type. 792 ++Count; 793 Working = Working << NumBitsPerType | Type; 794 } 795 796 uint64_t PGOHash::finalize() { 797 // Use Working as the hash directly if we never used MD5. 798 if (Count <= NumTypesPerWord) 799 // No need to byte swap here, since none of the math was endian-dependent. 800 // This number will be byte-swapped as required on endianness transitions, 801 // so we will see the same value on the other side. 802 return Working; 803 804 // Check for remaining work in Working. 805 if (Working) 806 MD5.update(Working); 807 808 // Finalize the MD5 and return the hash. 809 llvm::MD5::MD5Result Result; 810 MD5.final(Result); 811 using namespace llvm::support; 812 return endian::read<uint64_t, little, unaligned>(Result); 813 } 814 815 static void emitRuntimeHook(CodeGenModule &CGM) { 816 const char *const RuntimeVarName = "__llvm_profile_runtime"; 817 const char *const RuntimeUserName = "__llvm_profile_runtime_user"; 818 if (CGM.getModule().getGlobalVariable(RuntimeVarName)) 819 return; 820 821 // Declare the runtime hook. 822 llvm::LLVMContext &Ctx = CGM.getLLVMContext(); 823 auto *Int32Ty = llvm::Type::getInt32Ty(Ctx); 824 auto *Var = new llvm::GlobalVariable(CGM.getModule(), Int32Ty, false, 825 llvm::GlobalValue::ExternalLinkage, 826 nullptr, RuntimeVarName); 827 828 // Make a function that uses it. 829 auto *User = llvm::Function::Create(llvm::FunctionType::get(Int32Ty, false), 830 llvm::GlobalValue::LinkOnceODRLinkage, 831 RuntimeUserName, &CGM.getModule()); 832 User->addFnAttr(llvm::Attribute::NoInline); 833 if (CGM.getCodeGenOpts().DisableRedZone) 834 User->addFnAttr(llvm::Attribute::NoRedZone); 835 CGBuilderTy Builder(llvm::BasicBlock::Create(CGM.getLLVMContext(), "", User)); 836 auto *Load = Builder.CreateLoad(Var); 837 Builder.CreateRet(Load); 838 839 // Create a use of the function. Now the definition of the runtime variable 840 // should get pulled in, along with any static initializears. 841 CGM.addUsedGlobal(User); 842 } 843 844 void CodeGenPGO::checkGlobalDecl(GlobalDecl GD) { 845 // Make sure we only emit coverage mapping for one constructor/destructor. 846 // Clang emits several functions for the constructor and the destructor of 847 // a class. Every function is instrumented, but we only want to provide 848 // coverage for one of them. Because of that we only emit the coverage mapping 849 // for the base constructor/destructor. 850 if ((isa<CXXConstructorDecl>(GD.getDecl()) && 851 GD.getCtorType() != Ctor_Base) || 852 (isa<CXXDestructorDecl>(GD.getDecl()) && 853 GD.getDtorType() != Dtor_Base)) { 854 SkipCoverageMapping = true; 855 } 856 } 857 858 void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) { 859 bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate; 860 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 861 if (!InstrumentRegions && !PGOReader) 862 return; 863 if (D->isImplicit()) 864 return; 865 CGM.ClearUnusedCoverageMapping(D); 866 setFuncName(Fn); 867 setVarLinkage(Fn->getLinkage()); 868 869 mapRegionCounters(D); 870 if (InstrumentRegions) { 871 emitRuntimeHook(CGM); 872 emitCounterVariables(); 873 if (CGM.getCodeGenOpts().CoverageMapping) 874 emitCounterRegionMapping(D); 875 } 876 if (PGOReader) { 877 SourceManager &SM = CGM.getContext().getSourceManager(); 878 loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation())); 879 computeRegionCounts(D); 880 applyFunctionAttributes(PGOReader, Fn); 881 } 882 } 883 884 void CodeGenPGO::mapRegionCounters(const Decl *D) { 885 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>); 886 MapRegionCounters Walker(*RegionCounterMap); 887 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 888 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD)); 889 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D)) 890 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD)); 891 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D)) 892 Walker.TraverseDecl(const_cast<BlockDecl *>(BD)); 893 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D)) 894 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD)); 895 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl"); 896 NumRegionCounters = Walker.NextCounter; 897 FunctionHash = Walker.Hash.finalize(); 898 } 899 900 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) { 901 if (SkipCoverageMapping) 902 return; 903 // Don't map the functions inside the system headers 904 auto Loc = D->getBody()->getLocStart(); 905 if (CGM.getContext().getSourceManager().isInSystemHeader(Loc)) 906 return; 907 908 llvm::raw_string_ostream OS(CoverageMapping); 909 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(), 910 CGM.getContext().getSourceManager(), 911 CGM.getLangOpts(), RegionCounterMap.get()); 912 MappingGen.emitCounterMapping(D, OS); 913 OS.flush(); 914 } 915 916 void 917 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef FuncName, 918 llvm::GlobalValue::LinkageTypes Linkage) { 919 if (SkipCoverageMapping) 920 return; 921 setFuncName(FuncName, Linkage); 922 setVarLinkage(Linkage); 923 924 // Don't map the functions inside the system headers 925 auto Loc = D->getBody()->getLocStart(); 926 if (CGM.getContext().getSourceManager().isInSystemHeader(Loc)) 927 return; 928 929 llvm::raw_string_ostream OS(CoverageMapping); 930 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(), 931 CGM.getContext().getSourceManager(), 932 CGM.getLangOpts()); 933 MappingGen.emitEmptyMapping(D, OS); 934 OS.flush(); 935 buildDataVar(); 936 } 937 938 void CodeGenPGO::computeRegionCounts(const Decl *D) { 939 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>); 940 ComputeRegionCounts Walker(*StmtCountMap, *this); 941 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 942 Walker.VisitFunctionDecl(FD); 943 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D)) 944 Walker.VisitObjCMethodDecl(MD); 945 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D)) 946 Walker.VisitBlockDecl(BD); 947 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D)) 948 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD)); 949 } 950 951 void 952 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader, 953 llvm::Function *Fn) { 954 if (!haveRegionCounts()) 955 return; 956 957 uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount(); 958 uint64_t FunctionCount = getRegionCount(0); 959 if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount)) 960 // Turn on InlineHint attribute for hot functions. 961 // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal. 962 Fn->addFnAttr(llvm::Attribute::InlineHint); 963 else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount)) 964 // Turn on Cold attribute for cold functions. 965 // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal. 966 Fn->addFnAttr(llvm::Attribute::Cold); 967 } 968 969 void CodeGenPGO::emitCounterVariables() { 970 llvm::LLVMContext &Ctx = CGM.getLLVMContext(); 971 llvm::ArrayType *CounterTy = llvm::ArrayType::get(llvm::Type::getInt64Ty(Ctx), 972 NumRegionCounters); 973 RegionCounters = 974 new llvm::GlobalVariable(CGM.getModule(), CounterTy, false, VarLinkage, 975 llvm::Constant::getNullValue(CounterTy), 976 getFuncVarName("counters")); 977 RegionCounters->setAlignment(8); 978 RegionCounters->setSection(getCountersSection(CGM)); 979 } 980 981 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) { 982 if (!RegionCounters) 983 return; 984 llvm::Value *Addr = 985 Builder.CreateConstInBoundsGEP2_64(RegionCounters, 0, Counter); 986 llvm::Value *Count = Builder.CreateLoad(Addr, "pgocount"); 987 Count = Builder.CreateAdd(Count, Builder.getInt64(1)); 988 Builder.CreateStore(Count, Addr); 989 } 990 991 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader, 992 bool IsInMainFile) { 993 CGM.getPGOStats().addVisited(IsInMainFile); 994 RegionCounts.reset(new std::vector<uint64_t>); 995 if (std::error_code EC = PGOReader->getFunctionCounts( 996 getFuncName(), FunctionHash, *RegionCounts)) { 997 if (EC == llvm::instrprof_error::unknown_function) 998 CGM.getPGOStats().addMissing(IsInMainFile); 999 else if (EC == llvm::instrprof_error::hash_mismatch) 1000 CGM.getPGOStats().addMismatched(IsInMainFile); 1001 else if (EC == llvm::instrprof_error::malformed) 1002 // TODO: Consider a more specific warning for this case. 1003 CGM.getPGOStats().addMismatched(IsInMainFile); 1004 RegionCounts.reset(); 1005 } 1006 } 1007 1008 void CodeGenPGO::destroyRegionCounters() { 1009 RegionCounterMap.reset(); 1010 StmtCountMap.reset(); 1011 RegionCounts.reset(); 1012 RegionCounters = nullptr; 1013 } 1014 1015 /// \brief Calculate what to divide by to scale weights. 1016 /// 1017 /// Given the maximum weight, calculate a divisor that will scale all the 1018 /// weights to strictly less than UINT32_MAX. 1019 static uint64_t calculateWeightScale(uint64_t MaxWeight) { 1020 return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1; 1021 } 1022 1023 /// \brief Scale an individual branch weight (and add 1). 1024 /// 1025 /// Scale a 64-bit weight down to 32-bits using \c Scale. 1026 /// 1027 /// According to Laplace's Rule of Succession, it is better to compute the 1028 /// weight based on the count plus 1, so universally add 1 to the value. 1029 /// 1030 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no 1031 /// greater than \c Weight. 1032 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) { 1033 assert(Scale && "scale by 0?"); 1034 uint64_t Scaled = Weight / Scale + 1; 1035 assert(Scaled <= UINT32_MAX && "overflow 32-bits"); 1036 return Scaled; 1037 } 1038 1039 llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount, 1040 uint64_t FalseCount) { 1041 // Check for empty weights. 1042 if (!TrueCount && !FalseCount) 1043 return nullptr; 1044 1045 // Calculate how to scale down to 32-bits. 1046 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount)); 1047 1048 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 1049 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale), 1050 scaleBranchWeight(FalseCount, Scale)); 1051 } 1052 1053 llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) { 1054 // We need at least two elements to create meaningful weights. 1055 if (Weights.size() < 2) 1056 return nullptr; 1057 1058 // Check for empty weights. 1059 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end()); 1060 if (MaxWeight == 0) 1061 return nullptr; 1062 1063 // Calculate how to scale down to 32-bits. 1064 uint64_t Scale = calculateWeightScale(MaxWeight); 1065 1066 SmallVector<uint32_t, 16> ScaledWeights; 1067 ScaledWeights.reserve(Weights.size()); 1068 for (uint64_t W : Weights) 1069 ScaledWeights.push_back(scaleBranchWeight(W, Scale)); 1070 1071 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 1072 return MDHelper.createBranchWeights(ScaledWeights); 1073 } 1074 1075 llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond, 1076 RegionCounter &Cnt) { 1077 if (!haveRegionCounts()) 1078 return nullptr; 1079 uint64_t LoopCount = Cnt.getCount(); 1080 uint64_t CondCount = 0; 1081 bool Found = getStmtCount(Cond, CondCount); 1082 assert(Found && "missing expected loop condition count"); 1083 (void)Found; 1084 if (CondCount == 0) 1085 return nullptr; 1086 return createBranchWeights(LoopCount, 1087 std::max(CondCount, LoopCount) - LoopCount); 1088 } 1089