1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // Instrumentation-based profile-guided optimization 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "CodeGenPGO.h" 15 #include "CodeGenFunction.h" 16 #include "clang/AST/RecursiveASTVisitor.h" 17 #include "clang/AST/StmtVisitor.h" 18 #include "llvm/IR/MDBuilder.h" 19 #include "llvm/ProfileData/InstrProfReader.h" 20 #include "llvm/Support/Endian.h" 21 #include "llvm/Support/FileSystem.h" 22 #include "llvm/Support/MD5.h" 23 24 using namespace clang; 25 using namespace CodeGen; 26 27 void CodeGenPGO::setFuncName(llvm::Function *Fn) { 28 RawFuncName = Fn->getName(); 29 30 // Function names may be prefixed with a binary '1' to indicate 31 // that the backend should not modify the symbols due to any platform 32 // naming convention. Do not include that '1' in the PGO profile name. 33 if (RawFuncName[0] == '\1') 34 RawFuncName = RawFuncName.substr(1); 35 36 if (!Fn->hasLocalLinkage()) { 37 PrefixedFuncName.reset(new std::string(RawFuncName)); 38 return; 39 } 40 41 // For local symbols, prepend the main file name to distinguish them. 42 // Do not include the full path in the file name since there's no guarantee 43 // that it will stay the same, e.g., if the files are checked out from 44 // version control in different locations. 45 PrefixedFuncName.reset(new std::string(CGM.getCodeGenOpts().MainFileName)); 46 if (PrefixedFuncName->empty()) 47 PrefixedFuncName->assign("<unknown>"); 48 PrefixedFuncName->append(":"); 49 PrefixedFuncName->append(RawFuncName); 50 } 51 52 static llvm::Function *getRegisterFunc(CodeGenModule &CGM) { 53 return CGM.getModule().getFunction("__llvm_profile_register_functions"); 54 } 55 56 static llvm::BasicBlock *getOrInsertRegisterBB(CodeGenModule &CGM) { 57 // Don't do this for Darwin. compiler-rt uses linker magic. 58 if (CGM.getTarget().getTriple().isOSDarwin()) 59 return nullptr; 60 61 // Only need to insert this once per module. 62 if (llvm::Function *RegisterF = getRegisterFunc(CGM)) 63 return &RegisterF->getEntryBlock(); 64 65 // Construct the function. 66 auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext()); 67 auto *RegisterFTy = llvm::FunctionType::get(VoidTy, false); 68 auto *RegisterF = llvm::Function::Create(RegisterFTy, 69 llvm::GlobalValue::InternalLinkage, 70 "__llvm_profile_register_functions", 71 &CGM.getModule()); 72 RegisterF->setUnnamedAddr(true); 73 if (CGM.getCodeGenOpts().DisableRedZone) 74 RegisterF->addFnAttr(llvm::Attribute::NoRedZone); 75 76 // Construct and return the entry block. 77 auto *BB = llvm::BasicBlock::Create(CGM.getLLVMContext(), "", RegisterF); 78 CGBuilderTy Builder(BB); 79 Builder.CreateRetVoid(); 80 return BB; 81 } 82 83 static llvm::Constant *getOrInsertRuntimeRegister(CodeGenModule &CGM) { 84 auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext()); 85 auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext()); 86 auto *RuntimeRegisterTy = llvm::FunctionType::get(VoidTy, VoidPtrTy, false); 87 return CGM.getModule().getOrInsertFunction("__llvm_profile_register_function", 88 RuntimeRegisterTy); 89 } 90 91 static bool isMachO(const CodeGenModule &CGM) { 92 return CGM.getTarget().getTriple().isOSBinFormatMachO(); 93 } 94 95 static StringRef getCountersSection(const CodeGenModule &CGM) { 96 return isMachO(CGM) ? "__DATA,__llvm_prf_cnts" : "__llvm_prf_cnts"; 97 } 98 99 static StringRef getNameSection(const CodeGenModule &CGM) { 100 return isMachO(CGM) ? "__DATA,__llvm_prf_names" : "__llvm_prf_names"; 101 } 102 103 static StringRef getDataSection(const CodeGenModule &CGM) { 104 return isMachO(CGM) ? "__DATA,__llvm_prf_data" : "__llvm_prf_data"; 105 } 106 107 llvm::GlobalVariable *CodeGenPGO::buildDataVar() { 108 // Create name variable. 109 llvm::LLVMContext &Ctx = CGM.getLLVMContext(); 110 auto *VarName = llvm::ConstantDataArray::getString(Ctx, getFuncName(), 111 false); 112 auto *Name = new llvm::GlobalVariable(CGM.getModule(), VarName->getType(), 113 true, VarLinkage, VarName, 114 getFuncVarName("name")); 115 Name->setSection(getNameSection(CGM)); 116 Name->setAlignment(1); 117 118 // Create data variable. 119 auto *Int32Ty = llvm::Type::getInt32Ty(Ctx); 120 auto *Int64Ty = llvm::Type::getInt64Ty(Ctx); 121 auto *Int8PtrTy = llvm::Type::getInt8PtrTy(Ctx); 122 auto *Int64PtrTy = llvm::Type::getInt64PtrTy(Ctx); 123 llvm::Type *DataTypes[] = { 124 Int32Ty, Int32Ty, Int64Ty, Int8PtrTy, Int64PtrTy 125 }; 126 auto *DataTy = llvm::StructType::get(Ctx, makeArrayRef(DataTypes)); 127 llvm::Constant *DataVals[] = { 128 llvm::ConstantInt::get(Int32Ty, getFuncName().size()), 129 llvm::ConstantInt::get(Int32Ty, NumRegionCounters), 130 llvm::ConstantInt::get(Int64Ty, FunctionHash), 131 llvm::ConstantExpr::getBitCast(Name, Int8PtrTy), 132 llvm::ConstantExpr::getBitCast(RegionCounters, Int64PtrTy) 133 }; 134 auto *Data = 135 new llvm::GlobalVariable(CGM.getModule(), DataTy, true, VarLinkage, 136 llvm::ConstantStruct::get(DataTy, DataVals), 137 getFuncVarName("data")); 138 139 // All the data should be packed into an array in its own section. 140 Data->setSection(getDataSection(CGM)); 141 Data->setAlignment(8); 142 143 // Hide all these symbols so that we correctly get a copy for each 144 // executable. The profile format expects names and counters to be 145 // contiguous, so references into shared objects would be invalid. 146 if (!llvm::GlobalValue::isLocalLinkage(VarLinkage)) { 147 Name->setVisibility(llvm::GlobalValue::HiddenVisibility); 148 Data->setVisibility(llvm::GlobalValue::HiddenVisibility); 149 RegionCounters->setVisibility(llvm::GlobalValue::HiddenVisibility); 150 } 151 152 // Make sure the data doesn't get deleted. 153 CGM.addUsedGlobal(Data); 154 return Data; 155 } 156 157 void CodeGenPGO::emitInstrumentationData() { 158 if (!RegionCounters) 159 return; 160 161 // Build the data. 162 auto *Data = buildDataVar(); 163 164 // Register the data. 165 auto *RegisterBB = getOrInsertRegisterBB(CGM); 166 if (!RegisterBB) 167 return; 168 CGBuilderTy Builder(RegisterBB->getTerminator()); 169 auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext()); 170 Builder.CreateCall(getOrInsertRuntimeRegister(CGM), 171 Builder.CreateBitCast(Data, VoidPtrTy)); 172 } 173 174 llvm::Function *CodeGenPGO::emitInitialization(CodeGenModule &CGM) { 175 if (!CGM.getCodeGenOpts().ProfileInstrGenerate) 176 return nullptr; 177 178 assert(CGM.getModule().getFunction("__llvm_profile_init") == nullptr && 179 "profile initialization already emitted"); 180 181 // Get the function to call at initialization. 182 llvm::Constant *RegisterF = getRegisterFunc(CGM); 183 if (!RegisterF) 184 return nullptr; 185 186 // Create the initialization function. 187 auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext()); 188 auto *F = llvm::Function::Create(llvm::FunctionType::get(VoidTy, false), 189 llvm::GlobalValue::InternalLinkage, 190 "__llvm_profile_init", &CGM.getModule()); 191 F->setUnnamedAddr(true); 192 F->addFnAttr(llvm::Attribute::NoInline); 193 if (CGM.getCodeGenOpts().DisableRedZone) 194 F->addFnAttr(llvm::Attribute::NoRedZone); 195 196 // Add the basic block and the necessary calls. 197 CGBuilderTy Builder(llvm::BasicBlock::Create(CGM.getLLVMContext(), "", F)); 198 Builder.CreateCall(RegisterF); 199 Builder.CreateRetVoid(); 200 201 return F; 202 } 203 204 namespace { 205 /// \brief Stable hasher for PGO region counters. 206 /// 207 /// PGOHash produces a stable hash of a given function's control flow. 208 /// 209 /// Changing the output of this hash will invalidate all previously generated 210 /// profiles -- i.e., don't do it. 211 /// 212 /// \note When this hash does eventually change (years?), we still need to 213 /// support old hashes. We'll need to pull in the version number from the 214 /// profile data format and use the matching hash function. 215 class PGOHash { 216 uint64_t Working; 217 unsigned Count; 218 llvm::MD5 MD5; 219 220 static const int NumBitsPerType = 6; 221 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType; 222 static const unsigned TooBig = 1u << NumBitsPerType; 223 224 public: 225 /// \brief Hash values for AST nodes. 226 /// 227 /// Distinct values for AST nodes that have region counters attached. 228 /// 229 /// These values must be stable. All new members must be added at the end, 230 /// and no members should be removed. Changing the enumeration value for an 231 /// AST node will affect the hash of every function that contains that node. 232 enum HashType : unsigned char { 233 None = 0, 234 LabelStmt = 1, 235 WhileStmt, 236 DoStmt, 237 ForStmt, 238 CXXForRangeStmt, 239 ObjCForCollectionStmt, 240 SwitchStmt, 241 CaseStmt, 242 DefaultStmt, 243 IfStmt, 244 CXXTryStmt, 245 CXXCatchStmt, 246 ConditionalOperator, 247 BinaryOperatorLAnd, 248 BinaryOperatorLOr, 249 BinaryConditionalOperator, 250 251 // Keep this last. It's for the static assert that follows. 252 LastHashType 253 }; 254 static_assert(LastHashType <= TooBig, "Too many types in HashType"); 255 256 // TODO: When this format changes, take in a version number here, and use the 257 // old hash calculation for file formats that used the old hash. 258 PGOHash() : Working(0), Count(0) {} 259 void combine(HashType Type); 260 uint64_t finalize(); 261 }; 262 const int PGOHash::NumBitsPerType; 263 const unsigned PGOHash::NumTypesPerWord; 264 const unsigned PGOHash::TooBig; 265 266 /// A RecursiveASTVisitor that fills a map of statements to PGO counters. 267 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> { 268 /// The next counter value to assign. 269 unsigned NextCounter; 270 /// The function hash. 271 PGOHash Hash; 272 /// The map of statements to counters. 273 llvm::DenseMap<const Stmt *, unsigned> &CounterMap; 274 275 MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap) 276 : NextCounter(0), CounterMap(CounterMap) {} 277 278 // Blocks and lambdas are handled as separate functions, so we need not 279 // traverse them in the parent context. 280 bool TraverseBlockExpr(BlockExpr *BE) { return true; } 281 bool TraverseLambdaBody(LambdaExpr *LE) { return true; } 282 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; } 283 284 bool VisitDecl(const Decl *D) { 285 switch (D->getKind()) { 286 default: 287 break; 288 case Decl::Function: 289 case Decl::CXXMethod: 290 case Decl::CXXConstructor: 291 case Decl::CXXDestructor: 292 case Decl::CXXConversion: 293 case Decl::ObjCMethod: 294 case Decl::Block: 295 case Decl::Captured: 296 CounterMap[D->getBody()] = NextCounter++; 297 break; 298 } 299 return true; 300 } 301 302 bool VisitStmt(const Stmt *S) { 303 auto Type = getHashType(S); 304 if (Type == PGOHash::None) 305 return true; 306 307 CounterMap[S] = NextCounter++; 308 Hash.combine(Type); 309 return true; 310 } 311 PGOHash::HashType getHashType(const Stmt *S) { 312 switch (S->getStmtClass()) { 313 default: 314 break; 315 case Stmt::LabelStmtClass: 316 return PGOHash::LabelStmt; 317 case Stmt::WhileStmtClass: 318 return PGOHash::WhileStmt; 319 case Stmt::DoStmtClass: 320 return PGOHash::DoStmt; 321 case Stmt::ForStmtClass: 322 return PGOHash::ForStmt; 323 case Stmt::CXXForRangeStmtClass: 324 return PGOHash::CXXForRangeStmt; 325 case Stmt::ObjCForCollectionStmtClass: 326 return PGOHash::ObjCForCollectionStmt; 327 case Stmt::SwitchStmtClass: 328 return PGOHash::SwitchStmt; 329 case Stmt::CaseStmtClass: 330 return PGOHash::CaseStmt; 331 case Stmt::DefaultStmtClass: 332 return PGOHash::DefaultStmt; 333 case Stmt::IfStmtClass: 334 return PGOHash::IfStmt; 335 case Stmt::CXXTryStmtClass: 336 return PGOHash::CXXTryStmt; 337 case Stmt::CXXCatchStmtClass: 338 return PGOHash::CXXCatchStmt; 339 case Stmt::ConditionalOperatorClass: 340 return PGOHash::ConditionalOperator; 341 case Stmt::BinaryConditionalOperatorClass: 342 return PGOHash::BinaryConditionalOperator; 343 case Stmt::BinaryOperatorClass: { 344 const BinaryOperator *BO = cast<BinaryOperator>(S); 345 if (BO->getOpcode() == BO_LAnd) 346 return PGOHash::BinaryOperatorLAnd; 347 if (BO->getOpcode() == BO_LOr) 348 return PGOHash::BinaryOperatorLOr; 349 break; 350 } 351 } 352 return PGOHash::None; 353 } 354 }; 355 356 /// A StmtVisitor that propagates the raw counts through the AST and 357 /// records the count at statements where the value may change. 358 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { 359 /// PGO state. 360 CodeGenPGO &PGO; 361 362 /// A flag that is set when the current count should be recorded on the 363 /// next statement, such as at the exit of a loop. 364 bool RecordNextStmtCount; 365 366 /// The map of statements to count values. 367 llvm::DenseMap<const Stmt *, uint64_t> &CountMap; 368 369 /// BreakContinueStack - Keep counts of breaks and continues inside loops. 370 struct BreakContinue { 371 uint64_t BreakCount; 372 uint64_t ContinueCount; 373 BreakContinue() : BreakCount(0), ContinueCount(0) {} 374 }; 375 SmallVector<BreakContinue, 8> BreakContinueStack; 376 377 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap, 378 CodeGenPGO &PGO) 379 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {} 380 381 void RecordStmtCount(const Stmt *S) { 382 if (RecordNextStmtCount) { 383 CountMap[S] = PGO.getCurrentRegionCount(); 384 RecordNextStmtCount = false; 385 } 386 } 387 388 void VisitStmt(const Stmt *S) { 389 RecordStmtCount(S); 390 for (Stmt::const_child_range I = S->children(); I; ++I) { 391 if (*I) 392 this->Visit(*I); 393 } 394 } 395 396 void VisitFunctionDecl(const FunctionDecl *D) { 397 // Counter tracks entry to the function body. 398 RegionCounter Cnt(PGO, D->getBody()); 399 Cnt.beginRegion(); 400 CountMap[D->getBody()] = PGO.getCurrentRegionCount(); 401 Visit(D->getBody()); 402 } 403 404 // Skip lambda expressions. We visit these as FunctionDecls when we're 405 // generating them and aren't interested in the body when generating a 406 // parent context. 407 void VisitLambdaExpr(const LambdaExpr *LE) {} 408 409 void VisitCapturedDecl(const CapturedDecl *D) { 410 // Counter tracks entry to the capture body. 411 RegionCounter Cnt(PGO, D->getBody()); 412 Cnt.beginRegion(); 413 CountMap[D->getBody()] = PGO.getCurrentRegionCount(); 414 Visit(D->getBody()); 415 } 416 417 void VisitObjCMethodDecl(const ObjCMethodDecl *D) { 418 // Counter tracks entry to the method body. 419 RegionCounter Cnt(PGO, D->getBody()); 420 Cnt.beginRegion(); 421 CountMap[D->getBody()] = PGO.getCurrentRegionCount(); 422 Visit(D->getBody()); 423 } 424 425 void VisitBlockDecl(const BlockDecl *D) { 426 // Counter tracks entry to the block body. 427 RegionCounter Cnt(PGO, D->getBody()); 428 Cnt.beginRegion(); 429 CountMap[D->getBody()] = PGO.getCurrentRegionCount(); 430 Visit(D->getBody()); 431 } 432 433 void VisitReturnStmt(const ReturnStmt *S) { 434 RecordStmtCount(S); 435 if (S->getRetValue()) 436 Visit(S->getRetValue()); 437 PGO.setCurrentRegionUnreachable(); 438 RecordNextStmtCount = true; 439 } 440 441 void VisitGotoStmt(const GotoStmt *S) { 442 RecordStmtCount(S); 443 PGO.setCurrentRegionUnreachable(); 444 RecordNextStmtCount = true; 445 } 446 447 void VisitLabelStmt(const LabelStmt *S) { 448 RecordNextStmtCount = false; 449 // Counter tracks the block following the label. 450 RegionCounter Cnt(PGO, S); 451 Cnt.beginRegion(); 452 CountMap[S] = PGO.getCurrentRegionCount(); 453 Visit(S->getSubStmt()); 454 } 455 456 void VisitBreakStmt(const BreakStmt *S) { 457 RecordStmtCount(S); 458 assert(!BreakContinueStack.empty() && "break not in a loop or switch!"); 459 BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount(); 460 PGO.setCurrentRegionUnreachable(); 461 RecordNextStmtCount = true; 462 } 463 464 void VisitContinueStmt(const ContinueStmt *S) { 465 RecordStmtCount(S); 466 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!"); 467 BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount(); 468 PGO.setCurrentRegionUnreachable(); 469 RecordNextStmtCount = true; 470 } 471 472 void VisitWhileStmt(const WhileStmt *S) { 473 RecordStmtCount(S); 474 // Counter tracks the body of the loop. 475 RegionCounter Cnt(PGO, S); 476 BreakContinueStack.push_back(BreakContinue()); 477 // Visit the body region first so the break/continue adjustments can be 478 // included when visiting the condition. 479 Cnt.beginRegion(); 480 CountMap[S->getBody()] = PGO.getCurrentRegionCount(); 481 Visit(S->getBody()); 482 Cnt.adjustForControlFlow(); 483 484 // ...then go back and propagate counts through the condition. The count 485 // at the start of the condition is the sum of the incoming edges, 486 // the backedge from the end of the loop body, and the edges from 487 // continue statements. 488 BreakContinue BC = BreakContinueStack.pop_back_val(); 489 Cnt.setCurrentRegionCount(Cnt.getParentCount() + 490 Cnt.getAdjustedCount() + BC.ContinueCount); 491 CountMap[S->getCond()] = PGO.getCurrentRegionCount(); 492 Visit(S->getCond()); 493 Cnt.adjustForControlFlow(); 494 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 495 RecordNextStmtCount = true; 496 } 497 498 void VisitDoStmt(const DoStmt *S) { 499 RecordStmtCount(S); 500 // Counter tracks the body of the loop. 501 RegionCounter Cnt(PGO, S); 502 BreakContinueStack.push_back(BreakContinue()); 503 Cnt.beginRegion(/*AddIncomingFallThrough=*/true); 504 CountMap[S->getBody()] = PGO.getCurrentRegionCount(); 505 Visit(S->getBody()); 506 Cnt.adjustForControlFlow(); 507 508 BreakContinue BC = BreakContinueStack.pop_back_val(); 509 // The count at the start of the condition is equal to the count at the 510 // end of the body. The adjusted count does not include either the 511 // fall-through count coming into the loop or the continue count, so add 512 // both of those separately. This is coincidentally the same equation as 513 // with while loops but for different reasons. 514 Cnt.setCurrentRegionCount(Cnt.getParentCount() + 515 Cnt.getAdjustedCount() + BC.ContinueCount); 516 CountMap[S->getCond()] = PGO.getCurrentRegionCount(); 517 Visit(S->getCond()); 518 Cnt.adjustForControlFlow(); 519 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 520 RecordNextStmtCount = true; 521 } 522 523 void VisitForStmt(const ForStmt *S) { 524 RecordStmtCount(S); 525 if (S->getInit()) 526 Visit(S->getInit()); 527 // Counter tracks the body of the loop. 528 RegionCounter Cnt(PGO, S); 529 BreakContinueStack.push_back(BreakContinue()); 530 // Visit the body region first. (This is basically the same as a while 531 // loop; see further comments in VisitWhileStmt.) 532 Cnt.beginRegion(); 533 CountMap[S->getBody()] = PGO.getCurrentRegionCount(); 534 Visit(S->getBody()); 535 Cnt.adjustForControlFlow(); 536 537 // The increment is essentially part of the body but it needs to include 538 // the count for all the continue statements. 539 if (S->getInc()) { 540 Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() + 541 BreakContinueStack.back().ContinueCount); 542 CountMap[S->getInc()] = PGO.getCurrentRegionCount(); 543 Visit(S->getInc()); 544 Cnt.adjustForControlFlow(); 545 } 546 547 BreakContinue BC = BreakContinueStack.pop_back_val(); 548 549 // ...then go back and propagate counts through the condition. 550 if (S->getCond()) { 551 Cnt.setCurrentRegionCount(Cnt.getParentCount() + 552 Cnt.getAdjustedCount() + 553 BC.ContinueCount); 554 CountMap[S->getCond()] = PGO.getCurrentRegionCount(); 555 Visit(S->getCond()); 556 Cnt.adjustForControlFlow(); 557 } 558 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 559 RecordNextStmtCount = true; 560 } 561 562 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) { 563 RecordStmtCount(S); 564 Visit(S->getRangeStmt()); 565 Visit(S->getBeginEndStmt()); 566 // Counter tracks the body of the loop. 567 RegionCounter Cnt(PGO, S); 568 BreakContinueStack.push_back(BreakContinue()); 569 // Visit the body region first. (This is basically the same as a while 570 // loop; see further comments in VisitWhileStmt.) 571 Cnt.beginRegion(); 572 CountMap[S->getLoopVarStmt()] = PGO.getCurrentRegionCount(); 573 Visit(S->getLoopVarStmt()); 574 Visit(S->getBody()); 575 Cnt.adjustForControlFlow(); 576 577 // The increment is essentially part of the body but it needs to include 578 // the count for all the continue statements. 579 Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() + 580 BreakContinueStack.back().ContinueCount); 581 CountMap[S->getInc()] = PGO.getCurrentRegionCount(); 582 Visit(S->getInc()); 583 Cnt.adjustForControlFlow(); 584 585 BreakContinue BC = BreakContinueStack.pop_back_val(); 586 587 // ...then go back and propagate counts through the condition. 588 Cnt.setCurrentRegionCount(Cnt.getParentCount() + 589 Cnt.getAdjustedCount() + 590 BC.ContinueCount); 591 CountMap[S->getCond()] = PGO.getCurrentRegionCount(); 592 Visit(S->getCond()); 593 Cnt.adjustForControlFlow(); 594 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 595 RecordNextStmtCount = true; 596 } 597 598 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { 599 RecordStmtCount(S); 600 Visit(S->getElement()); 601 // Counter tracks the body of the loop. 602 RegionCounter Cnt(PGO, S); 603 BreakContinueStack.push_back(BreakContinue()); 604 Cnt.beginRegion(); 605 CountMap[S->getBody()] = PGO.getCurrentRegionCount(); 606 Visit(S->getBody()); 607 BreakContinue BC = BreakContinueStack.pop_back_val(); 608 Cnt.adjustForControlFlow(); 609 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 610 RecordNextStmtCount = true; 611 } 612 613 void VisitSwitchStmt(const SwitchStmt *S) { 614 RecordStmtCount(S); 615 Visit(S->getCond()); 616 PGO.setCurrentRegionUnreachable(); 617 BreakContinueStack.push_back(BreakContinue()); 618 Visit(S->getBody()); 619 // If the switch is inside a loop, add the continue counts. 620 BreakContinue BC = BreakContinueStack.pop_back_val(); 621 if (!BreakContinueStack.empty()) 622 BreakContinueStack.back().ContinueCount += BC.ContinueCount; 623 // Counter tracks the exit block of the switch. 624 RegionCounter ExitCnt(PGO, S); 625 ExitCnt.beginRegion(); 626 RecordNextStmtCount = true; 627 } 628 629 void VisitCaseStmt(const CaseStmt *S) { 630 RecordNextStmtCount = false; 631 // Counter for this particular case. This counts only jumps from the 632 // switch header and does not include fallthrough from the case before 633 // this one. 634 RegionCounter Cnt(PGO, S); 635 Cnt.beginRegion(/*AddIncomingFallThrough=*/true); 636 CountMap[S] = Cnt.getCount(); 637 RecordNextStmtCount = true; 638 Visit(S->getSubStmt()); 639 } 640 641 void VisitDefaultStmt(const DefaultStmt *S) { 642 RecordNextStmtCount = false; 643 // Counter for this default case. This does not include fallthrough from 644 // the previous case. 645 RegionCounter Cnt(PGO, S); 646 Cnt.beginRegion(/*AddIncomingFallThrough=*/true); 647 CountMap[S] = Cnt.getCount(); 648 RecordNextStmtCount = true; 649 Visit(S->getSubStmt()); 650 } 651 652 void VisitIfStmt(const IfStmt *S) { 653 RecordStmtCount(S); 654 // Counter tracks the "then" part of an if statement. The count for 655 // the "else" part, if it exists, will be calculated from this counter. 656 RegionCounter Cnt(PGO, S); 657 Visit(S->getCond()); 658 659 Cnt.beginRegion(); 660 CountMap[S->getThen()] = PGO.getCurrentRegionCount(); 661 Visit(S->getThen()); 662 Cnt.adjustForControlFlow(); 663 664 if (S->getElse()) { 665 Cnt.beginElseRegion(); 666 CountMap[S->getElse()] = PGO.getCurrentRegionCount(); 667 Visit(S->getElse()); 668 Cnt.adjustForControlFlow(); 669 } 670 Cnt.applyAdjustmentsToRegion(0); 671 RecordNextStmtCount = true; 672 } 673 674 void VisitCXXTryStmt(const CXXTryStmt *S) { 675 RecordStmtCount(S); 676 Visit(S->getTryBlock()); 677 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I) 678 Visit(S->getHandler(I)); 679 // Counter tracks the continuation block of the try statement. 680 RegionCounter Cnt(PGO, S); 681 Cnt.beginRegion(); 682 RecordNextStmtCount = true; 683 } 684 685 void VisitCXXCatchStmt(const CXXCatchStmt *S) { 686 RecordNextStmtCount = false; 687 // Counter tracks the catch statement's handler block. 688 RegionCounter Cnt(PGO, S); 689 Cnt.beginRegion(); 690 CountMap[S] = PGO.getCurrentRegionCount(); 691 Visit(S->getHandlerBlock()); 692 } 693 694 void VisitAbstractConditionalOperator( 695 const AbstractConditionalOperator *E) { 696 RecordStmtCount(E); 697 // Counter tracks the "true" part of a conditional operator. The 698 // count in the "false" part will be calculated from this counter. 699 RegionCounter Cnt(PGO, E); 700 Visit(E->getCond()); 701 702 Cnt.beginRegion(); 703 CountMap[E->getTrueExpr()] = PGO.getCurrentRegionCount(); 704 Visit(E->getTrueExpr()); 705 Cnt.adjustForControlFlow(); 706 707 Cnt.beginElseRegion(); 708 CountMap[E->getFalseExpr()] = PGO.getCurrentRegionCount(); 709 Visit(E->getFalseExpr()); 710 Cnt.adjustForControlFlow(); 711 712 Cnt.applyAdjustmentsToRegion(0); 713 RecordNextStmtCount = true; 714 } 715 716 void VisitBinLAnd(const BinaryOperator *E) { 717 RecordStmtCount(E); 718 // Counter tracks the right hand side of a logical and operator. 719 RegionCounter Cnt(PGO, E); 720 Visit(E->getLHS()); 721 Cnt.beginRegion(); 722 CountMap[E->getRHS()] = PGO.getCurrentRegionCount(); 723 Visit(E->getRHS()); 724 Cnt.adjustForControlFlow(); 725 Cnt.applyAdjustmentsToRegion(0); 726 RecordNextStmtCount = true; 727 } 728 729 void VisitBinLOr(const BinaryOperator *E) { 730 RecordStmtCount(E); 731 // Counter tracks the right hand side of a logical or operator. 732 RegionCounter Cnt(PGO, E); 733 Visit(E->getLHS()); 734 Cnt.beginRegion(); 735 CountMap[E->getRHS()] = PGO.getCurrentRegionCount(); 736 Visit(E->getRHS()); 737 Cnt.adjustForControlFlow(); 738 Cnt.applyAdjustmentsToRegion(0); 739 RecordNextStmtCount = true; 740 } 741 }; 742 } 743 744 void PGOHash::combine(HashType Type) { 745 // Check that we never combine 0 and only have six bits. 746 assert(Type && "Hash is invalid: unexpected type 0"); 747 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types"); 748 749 // Pass through MD5 if enough work has built up. 750 if (Count && Count % NumTypesPerWord == 0) { 751 using namespace llvm::support; 752 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working); 753 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped))); 754 Working = 0; 755 } 756 757 // Accumulate the current type. 758 ++Count; 759 Working = Working << NumBitsPerType | Type; 760 } 761 762 uint64_t PGOHash::finalize() { 763 // Use Working as the hash directly if we never used MD5. 764 if (Count <= NumTypesPerWord) 765 // No need to byte swap here, since none of the math was endian-dependent. 766 // This number will be byte-swapped as required on endianness transitions, 767 // so we will see the same value on the other side. 768 return Working; 769 770 // Check for remaining work in Working. 771 if (Working) 772 MD5.update(Working); 773 774 // Finalize the MD5 and return the hash. 775 llvm::MD5::MD5Result Result; 776 MD5.final(Result); 777 using namespace llvm::support; 778 return endian::read<uint64_t, little, unaligned>(Result); 779 } 780 781 static void emitRuntimeHook(CodeGenModule &CGM) { 782 const char *const RuntimeVarName = "__llvm_profile_runtime"; 783 const char *const RuntimeUserName = "__llvm_profile_runtime_user"; 784 if (CGM.getModule().getGlobalVariable(RuntimeVarName)) 785 return; 786 787 // Declare the runtime hook. 788 llvm::LLVMContext &Ctx = CGM.getLLVMContext(); 789 auto *Int32Ty = llvm::Type::getInt32Ty(Ctx); 790 auto *Var = new llvm::GlobalVariable(CGM.getModule(), Int32Ty, false, 791 llvm::GlobalValue::ExternalLinkage, 792 nullptr, RuntimeVarName); 793 794 // Make a function that uses it. 795 auto *User = llvm::Function::Create(llvm::FunctionType::get(Int32Ty, false), 796 llvm::GlobalValue::LinkOnceODRLinkage, 797 RuntimeUserName, &CGM.getModule()); 798 User->addFnAttr(llvm::Attribute::NoInline); 799 if (CGM.getCodeGenOpts().DisableRedZone) 800 User->addFnAttr(llvm::Attribute::NoRedZone); 801 CGBuilderTy Builder(llvm::BasicBlock::Create(CGM.getLLVMContext(), "", User)); 802 auto *Load = Builder.CreateLoad(Var); 803 Builder.CreateRet(Load); 804 805 // Create a use of the function. Now the definition of the runtime variable 806 // should get pulled in, along with any static initializears. 807 CGM.addUsedGlobal(User); 808 } 809 810 void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) { 811 bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate; 812 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 813 if (!InstrumentRegions && !PGOReader) 814 return; 815 if (D->isImplicit()) 816 return; 817 setFuncName(Fn); 818 819 // Set the linkage for variables based on the function linkage. Usually, we 820 // want to match it, but available_externally and extern_weak both have the 821 // wrong semantics. 822 VarLinkage = Fn->getLinkage(); 823 switch (VarLinkage) { 824 case llvm::GlobalValue::ExternalWeakLinkage: 825 VarLinkage = llvm::GlobalValue::LinkOnceAnyLinkage; 826 break; 827 case llvm::GlobalValue::AvailableExternallyLinkage: 828 VarLinkage = llvm::GlobalValue::LinkOnceODRLinkage; 829 break; 830 default: 831 break; 832 } 833 834 mapRegionCounters(D); 835 if (InstrumentRegions) { 836 emitRuntimeHook(CGM); 837 emitCounterVariables(); 838 } 839 if (PGOReader) { 840 loadRegionCounts(PGOReader); 841 computeRegionCounts(D); 842 applyFunctionAttributes(PGOReader, Fn); 843 } 844 } 845 846 void CodeGenPGO::mapRegionCounters(const Decl *D) { 847 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>); 848 MapRegionCounters Walker(*RegionCounterMap); 849 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 850 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD)); 851 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D)) 852 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD)); 853 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D)) 854 Walker.TraverseDecl(const_cast<BlockDecl *>(BD)); 855 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D)) 856 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD)); 857 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl"); 858 NumRegionCounters = Walker.NextCounter; 859 FunctionHash = Walker.Hash.finalize(); 860 } 861 862 void CodeGenPGO::computeRegionCounts(const Decl *D) { 863 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>); 864 ComputeRegionCounts Walker(*StmtCountMap, *this); 865 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 866 Walker.VisitFunctionDecl(FD); 867 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D)) 868 Walker.VisitObjCMethodDecl(MD); 869 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D)) 870 Walker.VisitBlockDecl(BD); 871 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D)) 872 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD)); 873 } 874 875 void 876 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader, 877 llvm::Function *Fn) { 878 if (!haveRegionCounts()) 879 return; 880 881 uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount(); 882 uint64_t FunctionCount = getRegionCount(0); 883 if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount)) 884 // Turn on InlineHint attribute for hot functions. 885 // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal. 886 Fn->addFnAttr(llvm::Attribute::InlineHint); 887 else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount)) 888 // Turn on Cold attribute for cold functions. 889 // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal. 890 Fn->addFnAttr(llvm::Attribute::Cold); 891 } 892 893 void CodeGenPGO::emitCounterVariables() { 894 llvm::LLVMContext &Ctx = CGM.getLLVMContext(); 895 llvm::ArrayType *CounterTy = llvm::ArrayType::get(llvm::Type::getInt64Ty(Ctx), 896 NumRegionCounters); 897 RegionCounters = 898 new llvm::GlobalVariable(CGM.getModule(), CounterTy, false, VarLinkage, 899 llvm::Constant::getNullValue(CounterTy), 900 getFuncVarName("counters")); 901 RegionCounters->setAlignment(8); 902 RegionCounters->setSection(getCountersSection(CGM)); 903 } 904 905 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) { 906 if (!RegionCounters) 907 return; 908 llvm::Value *Addr = 909 Builder.CreateConstInBoundsGEP2_64(RegionCounters, 0, Counter); 910 llvm::Value *Count = Builder.CreateLoad(Addr, "pgocount"); 911 Count = Builder.CreateAdd(Count, Builder.getInt64(1)); 912 Builder.CreateStore(Count, Addr); 913 } 914 915 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader) { 916 CGM.getPGOStats().Visited++; 917 RegionCounts.reset(new std::vector<uint64_t>); 918 uint64_t Hash; 919 if (PGOReader->getFunctionCounts(getFuncName(), Hash, *RegionCounts)) { 920 CGM.getPGOStats().Missing++; 921 RegionCounts.reset(); 922 } else if (Hash != FunctionHash || 923 RegionCounts->size() != NumRegionCounters) { 924 CGM.getPGOStats().Mismatched++; 925 RegionCounts.reset(); 926 } 927 } 928 929 void CodeGenPGO::destroyRegionCounters() { 930 RegionCounterMap.reset(); 931 StmtCountMap.reset(); 932 RegionCounts.reset(); 933 RegionCounters = nullptr; 934 } 935 936 /// \brief Calculate what to divide by to scale weights. 937 /// 938 /// Given the maximum weight, calculate a divisor that will scale all the 939 /// weights to strictly less than UINT32_MAX. 940 static uint64_t calculateWeightScale(uint64_t MaxWeight) { 941 return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1; 942 } 943 944 /// \brief Scale an individual branch weight (and add 1). 945 /// 946 /// Scale a 64-bit weight down to 32-bits using \c Scale. 947 /// 948 /// According to Laplace's Rule of Succession, it is better to compute the 949 /// weight based on the count plus 1, so universally add 1 to the value. 950 /// 951 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no 952 /// greater than \c Weight. 953 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) { 954 assert(Scale && "scale by 0?"); 955 uint64_t Scaled = Weight / Scale + 1; 956 assert(Scaled <= UINT32_MAX && "overflow 32-bits"); 957 return Scaled; 958 } 959 960 llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount, 961 uint64_t FalseCount) { 962 // Check for empty weights. 963 if (!TrueCount && !FalseCount) 964 return nullptr; 965 966 // Calculate how to scale down to 32-bits. 967 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount)); 968 969 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 970 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale), 971 scaleBranchWeight(FalseCount, Scale)); 972 } 973 974 llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) { 975 // We need at least two elements to create meaningful weights. 976 if (Weights.size() < 2) 977 return nullptr; 978 979 // Check for empty weights. 980 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end()); 981 if (MaxWeight == 0) 982 return nullptr; 983 984 // Calculate how to scale down to 32-bits. 985 uint64_t Scale = calculateWeightScale(MaxWeight); 986 987 SmallVector<uint32_t, 16> ScaledWeights; 988 ScaledWeights.reserve(Weights.size()); 989 for (uint64_t W : Weights) 990 ScaledWeights.push_back(scaleBranchWeight(W, Scale)); 991 992 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 993 return MDHelper.createBranchWeights(ScaledWeights); 994 } 995 996 llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond, 997 RegionCounter &Cnt) { 998 if (!haveRegionCounts()) 999 return nullptr; 1000 uint64_t LoopCount = Cnt.getCount(); 1001 uint64_t CondCount = 0; 1002 bool Found = getStmtCount(Cond, CondCount); 1003 assert(Found && "missing expected loop condition count"); 1004 (void)Found; 1005 if (CondCount == 0) 1006 return nullptr; 1007 return createBranchWeights(LoopCount, 1008 std::max(CondCount, LoopCount) - LoopCount); 1009 } 1010