1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // Instrumentation-based profile-guided optimization 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "CodeGenPGO.h" 15 #include "CodeGenFunction.h" 16 #include "clang/AST/RecursiveASTVisitor.h" 17 #include "clang/AST/StmtVisitor.h" 18 #include "llvm/Config/config.h" // for strtoull()/strtoul() define 19 #include "llvm/IR/MDBuilder.h" 20 #include "llvm/ProfileData/InstrProfReader.h" 21 #include "llvm/Support/Endian.h" 22 #include "llvm/Support/FileSystem.h" 23 #include "llvm/Support/MD5.h" 24 25 using namespace clang; 26 using namespace CodeGen; 27 28 void CodeGenPGO::setFuncName(llvm::Function *Fn) { 29 RawFuncName = Fn->getName(); 30 31 // Function names may be prefixed with a binary '1' to indicate 32 // that the backend should not modify the symbols due to any platform 33 // naming convention. Do not include that '1' in the PGO profile name. 34 if (RawFuncName[0] == '\1') 35 RawFuncName = RawFuncName.substr(1); 36 37 if (!Fn->hasLocalLinkage()) { 38 PrefixedFuncName.reset(new std::string(RawFuncName)); 39 return; 40 } 41 42 // For local symbols, prepend the main file name to distinguish them. 43 // Do not include the full path in the file name since there's no guarantee 44 // that it will stay the same, e.g., if the files are checked out from 45 // version control in different locations. 46 PrefixedFuncName.reset(new std::string(CGM.getCodeGenOpts().MainFileName)); 47 if (PrefixedFuncName->empty()) 48 PrefixedFuncName->assign("<unknown>"); 49 PrefixedFuncName->append(":"); 50 PrefixedFuncName->append(RawFuncName); 51 } 52 53 static llvm::Function *getRegisterFunc(CodeGenModule &CGM) { 54 return CGM.getModule().getFunction("__llvm_profile_register_functions"); 55 } 56 57 static llvm::BasicBlock *getOrInsertRegisterBB(CodeGenModule &CGM) { 58 // Don't do this for Darwin. compiler-rt uses linker magic. 59 if (CGM.getTarget().getTriple().isOSDarwin()) 60 return nullptr; 61 62 // Only need to insert this once per module. 63 if (llvm::Function *RegisterF = getRegisterFunc(CGM)) 64 return &RegisterF->getEntryBlock(); 65 66 // Construct the function. 67 auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext()); 68 auto *RegisterFTy = llvm::FunctionType::get(VoidTy, false); 69 auto *RegisterF = llvm::Function::Create(RegisterFTy, 70 llvm::GlobalValue::InternalLinkage, 71 "__llvm_profile_register_functions", 72 &CGM.getModule()); 73 RegisterF->setUnnamedAddr(true); 74 if (CGM.getCodeGenOpts().DisableRedZone) 75 RegisterF->addFnAttr(llvm::Attribute::NoRedZone); 76 77 // Construct and return the entry block. 78 auto *BB = llvm::BasicBlock::Create(CGM.getLLVMContext(), "", RegisterF); 79 CGBuilderTy Builder(BB); 80 Builder.CreateRetVoid(); 81 return BB; 82 } 83 84 static llvm::Constant *getOrInsertRuntimeRegister(CodeGenModule &CGM) { 85 auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext()); 86 auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext()); 87 auto *RuntimeRegisterTy = llvm::FunctionType::get(VoidTy, VoidPtrTy, false); 88 return CGM.getModule().getOrInsertFunction("__llvm_profile_register_function", 89 RuntimeRegisterTy); 90 } 91 92 static bool isMachO(const CodeGenModule &CGM) { 93 return CGM.getTarget().getTriple().isOSBinFormatMachO(); 94 } 95 96 static StringRef getCountersSection(const CodeGenModule &CGM) { 97 return isMachO(CGM) ? "__DATA,__llvm_prf_cnts" : "__llvm_prf_cnts"; 98 } 99 100 static StringRef getNameSection(const CodeGenModule &CGM) { 101 return isMachO(CGM) ? "__DATA,__llvm_prf_names" : "__llvm_prf_names"; 102 } 103 104 static StringRef getDataSection(const CodeGenModule &CGM) { 105 return isMachO(CGM) ? "__DATA,__llvm_prf_data" : "__llvm_prf_data"; 106 } 107 108 llvm::GlobalVariable *CodeGenPGO::buildDataVar() { 109 // Create name variable. 110 llvm::LLVMContext &Ctx = CGM.getLLVMContext(); 111 auto *VarName = llvm::ConstantDataArray::getString(Ctx, getFuncName(), 112 false); 113 auto *Name = new llvm::GlobalVariable(CGM.getModule(), VarName->getType(), 114 true, VarLinkage, VarName, 115 getFuncVarName("name")); 116 Name->setSection(getNameSection(CGM)); 117 Name->setAlignment(1); 118 119 // Create data variable. 120 auto *Int32Ty = llvm::Type::getInt32Ty(Ctx); 121 auto *Int64Ty = llvm::Type::getInt64Ty(Ctx); 122 auto *Int8PtrTy = llvm::Type::getInt8PtrTy(Ctx); 123 auto *Int64PtrTy = llvm::Type::getInt64PtrTy(Ctx); 124 llvm::Type *DataTypes[] = { 125 Int32Ty, Int32Ty, Int64Ty, Int8PtrTy, Int64PtrTy 126 }; 127 auto *DataTy = llvm::StructType::get(Ctx, makeArrayRef(DataTypes)); 128 llvm::Constant *DataVals[] = { 129 llvm::ConstantInt::get(Int32Ty, getFuncName().size()), 130 llvm::ConstantInt::get(Int32Ty, NumRegionCounters), 131 llvm::ConstantInt::get(Int64Ty, FunctionHash), 132 llvm::ConstantExpr::getBitCast(Name, Int8PtrTy), 133 llvm::ConstantExpr::getBitCast(RegionCounters, Int64PtrTy) 134 }; 135 auto *Data = 136 new llvm::GlobalVariable(CGM.getModule(), DataTy, true, VarLinkage, 137 llvm::ConstantStruct::get(DataTy, DataVals), 138 getFuncVarName("data")); 139 140 // All the data should be packed into an array in its own section. 141 Data->setSection(getDataSection(CGM)); 142 Data->setAlignment(8); 143 144 // Make sure the data doesn't get deleted. 145 CGM.addUsedGlobal(Data); 146 return Data; 147 } 148 149 void CodeGenPGO::emitInstrumentationData() { 150 if (!CGM.getCodeGenOpts().ProfileInstrGenerate) 151 return; 152 153 // Build the data. 154 auto *Data = buildDataVar(); 155 156 // Register the data. 157 auto *RegisterBB = getOrInsertRegisterBB(CGM); 158 if (!RegisterBB) 159 return; 160 CGBuilderTy Builder(RegisterBB->getTerminator()); 161 auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext()); 162 Builder.CreateCall(getOrInsertRuntimeRegister(CGM), 163 Builder.CreateBitCast(Data, VoidPtrTy)); 164 } 165 166 llvm::Function *CodeGenPGO::emitInitialization(CodeGenModule &CGM) { 167 if (!CGM.getCodeGenOpts().ProfileInstrGenerate) 168 return nullptr; 169 170 assert(CGM.getModule().getFunction("__llvm_profile_init") == nullptr && 171 "profile initialization already emitted"); 172 173 // Get the function to call at initialization. 174 llvm::Constant *RegisterF = getRegisterFunc(CGM); 175 if (!RegisterF) 176 return nullptr; 177 178 // Create the initialization function. 179 auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext()); 180 auto *F = llvm::Function::Create(llvm::FunctionType::get(VoidTy, false), 181 llvm::GlobalValue::InternalLinkage, 182 "__llvm_profile_init", &CGM.getModule()); 183 F->setUnnamedAddr(true); 184 F->addFnAttr(llvm::Attribute::NoInline); 185 if (CGM.getCodeGenOpts().DisableRedZone) 186 F->addFnAttr(llvm::Attribute::NoRedZone); 187 188 // Add the basic block and the necessary calls. 189 CGBuilderTy Builder(llvm::BasicBlock::Create(CGM.getLLVMContext(), "", F)); 190 Builder.CreateCall(RegisterF); 191 Builder.CreateRetVoid(); 192 193 return F; 194 } 195 196 namespace { 197 /// \brief Stable hasher for PGO region counters. 198 /// 199 /// PGOHash produces a stable hash of a given function's control flow. 200 /// 201 /// Changing the output of this hash will invalidate all previously generated 202 /// profiles -- i.e., don't do it. 203 /// 204 /// \note When this hash does eventually change (years?), we still need to 205 /// support old hashes. We'll need to pull in the version number from the 206 /// profile data format and use the matching hash function. 207 class PGOHash { 208 uint64_t Working; 209 unsigned Count; 210 llvm::MD5 MD5; 211 212 static const int NumBitsPerType = 6; 213 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType; 214 static const unsigned TooBig = 1u << NumBitsPerType; 215 216 public: 217 /// \brief Hash values for AST nodes. 218 /// 219 /// Distinct values for AST nodes that have region counters attached. 220 /// 221 /// These values must be stable. All new members must be added at the end, 222 /// and no members should be removed. Changing the enumeration value for an 223 /// AST node will affect the hash of every function that contains that node. 224 enum HashType : unsigned char { 225 None = 0, 226 LabelStmt = 1, 227 WhileStmt, 228 DoStmt, 229 ForStmt, 230 CXXForRangeStmt, 231 ObjCForCollectionStmt, 232 SwitchStmt, 233 CaseStmt, 234 DefaultStmt, 235 IfStmt, 236 CXXTryStmt, 237 CXXCatchStmt, 238 ConditionalOperator, 239 BinaryOperatorLAnd, 240 BinaryOperatorLOr, 241 BinaryConditionalOperator, 242 243 // Keep this last. It's for the static assert that follows. 244 LastHashType 245 }; 246 static_assert(LastHashType <= TooBig, "Too many types in HashType"); 247 248 // TODO: When this format changes, take in a version number here, and use the 249 // old hash calculation for file formats that used the old hash. 250 PGOHash() : Working(0), Count(0) {} 251 void combine(HashType Type); 252 uint64_t finalize(); 253 }; 254 const int PGOHash::NumBitsPerType; 255 const unsigned PGOHash::NumTypesPerWord; 256 const unsigned PGOHash::TooBig; 257 258 /// A RecursiveASTVisitor that fills a map of statements to PGO counters. 259 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> { 260 /// The next counter value to assign. 261 unsigned NextCounter; 262 /// The function hash. 263 PGOHash Hash; 264 /// The map of statements to counters. 265 llvm::DenseMap<const Stmt *, unsigned> &CounterMap; 266 267 MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap) 268 : NextCounter(0), CounterMap(CounterMap) {} 269 270 // Blocks and lambdas are handled as separate functions, so we need not 271 // traverse them in the parent context. 272 bool TraverseBlockExpr(BlockExpr *BE) { return true; } 273 bool TraverseLambdaBody(LambdaExpr *LE) { return true; } 274 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; } 275 276 bool VisitDecl(const Decl *D) { 277 switch (D->getKind()) { 278 default: 279 break; 280 case Decl::Function: 281 case Decl::CXXMethod: 282 case Decl::CXXConstructor: 283 case Decl::CXXDestructor: 284 case Decl::CXXConversion: 285 case Decl::ObjCMethod: 286 case Decl::Block: 287 case Decl::Captured: 288 CounterMap[D->getBody()] = NextCounter++; 289 break; 290 } 291 return true; 292 } 293 294 bool VisitStmt(const Stmt *S) { 295 auto Type = getHashType(S); 296 if (Type == PGOHash::None) 297 return true; 298 299 CounterMap[S] = NextCounter++; 300 Hash.combine(Type); 301 return true; 302 } 303 PGOHash::HashType getHashType(const Stmt *S) { 304 switch (S->getStmtClass()) { 305 default: 306 break; 307 case Stmt::LabelStmtClass: 308 return PGOHash::LabelStmt; 309 case Stmt::WhileStmtClass: 310 return PGOHash::WhileStmt; 311 case Stmt::DoStmtClass: 312 return PGOHash::DoStmt; 313 case Stmt::ForStmtClass: 314 return PGOHash::ForStmt; 315 case Stmt::CXXForRangeStmtClass: 316 return PGOHash::CXXForRangeStmt; 317 case Stmt::ObjCForCollectionStmtClass: 318 return PGOHash::ObjCForCollectionStmt; 319 case Stmt::SwitchStmtClass: 320 return PGOHash::SwitchStmt; 321 case Stmt::CaseStmtClass: 322 return PGOHash::CaseStmt; 323 case Stmt::DefaultStmtClass: 324 return PGOHash::DefaultStmt; 325 case Stmt::IfStmtClass: 326 return PGOHash::IfStmt; 327 case Stmt::CXXTryStmtClass: 328 return PGOHash::CXXTryStmt; 329 case Stmt::CXXCatchStmtClass: 330 return PGOHash::CXXCatchStmt; 331 case Stmt::ConditionalOperatorClass: 332 return PGOHash::ConditionalOperator; 333 case Stmt::BinaryConditionalOperatorClass: 334 return PGOHash::BinaryConditionalOperator; 335 case Stmt::BinaryOperatorClass: { 336 const BinaryOperator *BO = cast<BinaryOperator>(S); 337 if (BO->getOpcode() == BO_LAnd) 338 return PGOHash::BinaryOperatorLAnd; 339 if (BO->getOpcode() == BO_LOr) 340 return PGOHash::BinaryOperatorLOr; 341 break; 342 } 343 } 344 return PGOHash::None; 345 } 346 }; 347 348 /// A StmtVisitor that propagates the raw counts through the AST and 349 /// records the count at statements where the value may change. 350 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { 351 /// PGO state. 352 CodeGenPGO &PGO; 353 354 /// A flag that is set when the current count should be recorded on the 355 /// next statement, such as at the exit of a loop. 356 bool RecordNextStmtCount; 357 358 /// The map of statements to count values. 359 llvm::DenseMap<const Stmt *, uint64_t> &CountMap; 360 361 /// BreakContinueStack - Keep counts of breaks and continues inside loops. 362 struct BreakContinue { 363 uint64_t BreakCount; 364 uint64_t ContinueCount; 365 BreakContinue() : BreakCount(0), ContinueCount(0) {} 366 }; 367 SmallVector<BreakContinue, 8> BreakContinueStack; 368 369 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap, 370 CodeGenPGO &PGO) 371 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {} 372 373 void RecordStmtCount(const Stmt *S) { 374 if (RecordNextStmtCount) { 375 CountMap[S] = PGO.getCurrentRegionCount(); 376 RecordNextStmtCount = false; 377 } 378 } 379 380 void VisitStmt(const Stmt *S) { 381 RecordStmtCount(S); 382 for (Stmt::const_child_range I = S->children(); I; ++I) { 383 if (*I) 384 this->Visit(*I); 385 } 386 } 387 388 void VisitFunctionDecl(const FunctionDecl *D) { 389 // Counter tracks entry to the function body. 390 RegionCounter Cnt(PGO, D->getBody()); 391 Cnt.beginRegion(); 392 CountMap[D->getBody()] = PGO.getCurrentRegionCount(); 393 Visit(D->getBody()); 394 } 395 396 // Skip lambda expressions. We visit these as FunctionDecls when we're 397 // generating them and aren't interested in the body when generating a 398 // parent context. 399 void VisitLambdaExpr(const LambdaExpr *LE) {} 400 401 void VisitCapturedDecl(const CapturedDecl *D) { 402 // Counter tracks entry to the capture body. 403 RegionCounter Cnt(PGO, D->getBody()); 404 Cnt.beginRegion(); 405 CountMap[D->getBody()] = PGO.getCurrentRegionCount(); 406 Visit(D->getBody()); 407 } 408 409 void VisitObjCMethodDecl(const ObjCMethodDecl *D) { 410 // Counter tracks entry to the method body. 411 RegionCounter Cnt(PGO, D->getBody()); 412 Cnt.beginRegion(); 413 CountMap[D->getBody()] = PGO.getCurrentRegionCount(); 414 Visit(D->getBody()); 415 } 416 417 void VisitBlockDecl(const BlockDecl *D) { 418 // Counter tracks entry to the block body. 419 RegionCounter Cnt(PGO, D->getBody()); 420 Cnt.beginRegion(); 421 CountMap[D->getBody()] = PGO.getCurrentRegionCount(); 422 Visit(D->getBody()); 423 } 424 425 void VisitReturnStmt(const ReturnStmt *S) { 426 RecordStmtCount(S); 427 if (S->getRetValue()) 428 Visit(S->getRetValue()); 429 PGO.setCurrentRegionUnreachable(); 430 RecordNextStmtCount = true; 431 } 432 433 void VisitGotoStmt(const GotoStmt *S) { 434 RecordStmtCount(S); 435 PGO.setCurrentRegionUnreachable(); 436 RecordNextStmtCount = true; 437 } 438 439 void VisitLabelStmt(const LabelStmt *S) { 440 RecordNextStmtCount = false; 441 // Counter tracks the block following the label. 442 RegionCounter Cnt(PGO, S); 443 Cnt.beginRegion(); 444 CountMap[S] = PGO.getCurrentRegionCount(); 445 Visit(S->getSubStmt()); 446 } 447 448 void VisitBreakStmt(const BreakStmt *S) { 449 RecordStmtCount(S); 450 assert(!BreakContinueStack.empty() && "break not in a loop or switch!"); 451 BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount(); 452 PGO.setCurrentRegionUnreachable(); 453 RecordNextStmtCount = true; 454 } 455 456 void VisitContinueStmt(const ContinueStmt *S) { 457 RecordStmtCount(S); 458 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!"); 459 BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount(); 460 PGO.setCurrentRegionUnreachable(); 461 RecordNextStmtCount = true; 462 } 463 464 void VisitWhileStmt(const WhileStmt *S) { 465 RecordStmtCount(S); 466 // Counter tracks the body of the loop. 467 RegionCounter Cnt(PGO, S); 468 BreakContinueStack.push_back(BreakContinue()); 469 // Visit the body region first so the break/continue adjustments can be 470 // included when visiting the condition. 471 Cnt.beginRegion(); 472 CountMap[S->getBody()] = PGO.getCurrentRegionCount(); 473 Visit(S->getBody()); 474 Cnt.adjustForControlFlow(); 475 476 // ...then go back and propagate counts through the condition. The count 477 // at the start of the condition is the sum of the incoming edges, 478 // the backedge from the end of the loop body, and the edges from 479 // continue statements. 480 BreakContinue BC = BreakContinueStack.pop_back_val(); 481 Cnt.setCurrentRegionCount(Cnt.getParentCount() + 482 Cnt.getAdjustedCount() + BC.ContinueCount); 483 CountMap[S->getCond()] = PGO.getCurrentRegionCount(); 484 Visit(S->getCond()); 485 Cnt.adjustForControlFlow(); 486 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 487 RecordNextStmtCount = true; 488 } 489 490 void VisitDoStmt(const DoStmt *S) { 491 RecordStmtCount(S); 492 // Counter tracks the body of the loop. 493 RegionCounter Cnt(PGO, S); 494 BreakContinueStack.push_back(BreakContinue()); 495 Cnt.beginRegion(/*AddIncomingFallThrough=*/true); 496 CountMap[S->getBody()] = PGO.getCurrentRegionCount(); 497 Visit(S->getBody()); 498 Cnt.adjustForControlFlow(); 499 500 BreakContinue BC = BreakContinueStack.pop_back_val(); 501 // The count at the start of the condition is equal to the count at the 502 // end of the body. The adjusted count does not include either the 503 // fall-through count coming into the loop or the continue count, so add 504 // both of those separately. This is coincidentally the same equation as 505 // with while loops but for different reasons. 506 Cnt.setCurrentRegionCount(Cnt.getParentCount() + 507 Cnt.getAdjustedCount() + BC.ContinueCount); 508 CountMap[S->getCond()] = PGO.getCurrentRegionCount(); 509 Visit(S->getCond()); 510 Cnt.adjustForControlFlow(); 511 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 512 RecordNextStmtCount = true; 513 } 514 515 void VisitForStmt(const ForStmt *S) { 516 RecordStmtCount(S); 517 if (S->getInit()) 518 Visit(S->getInit()); 519 // Counter tracks the body of the loop. 520 RegionCounter Cnt(PGO, S); 521 BreakContinueStack.push_back(BreakContinue()); 522 // Visit the body region first. (This is basically the same as a while 523 // loop; see further comments in VisitWhileStmt.) 524 Cnt.beginRegion(); 525 CountMap[S->getBody()] = PGO.getCurrentRegionCount(); 526 Visit(S->getBody()); 527 Cnt.adjustForControlFlow(); 528 529 // The increment is essentially part of the body but it needs to include 530 // the count for all the continue statements. 531 if (S->getInc()) { 532 Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() + 533 BreakContinueStack.back().ContinueCount); 534 CountMap[S->getInc()] = PGO.getCurrentRegionCount(); 535 Visit(S->getInc()); 536 Cnt.adjustForControlFlow(); 537 } 538 539 BreakContinue BC = BreakContinueStack.pop_back_val(); 540 541 // ...then go back and propagate counts through the condition. 542 if (S->getCond()) { 543 Cnt.setCurrentRegionCount(Cnt.getParentCount() + 544 Cnt.getAdjustedCount() + 545 BC.ContinueCount); 546 CountMap[S->getCond()] = PGO.getCurrentRegionCount(); 547 Visit(S->getCond()); 548 Cnt.adjustForControlFlow(); 549 } 550 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 551 RecordNextStmtCount = true; 552 } 553 554 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) { 555 RecordStmtCount(S); 556 Visit(S->getRangeStmt()); 557 Visit(S->getBeginEndStmt()); 558 // Counter tracks the body of the loop. 559 RegionCounter Cnt(PGO, S); 560 BreakContinueStack.push_back(BreakContinue()); 561 // Visit the body region first. (This is basically the same as a while 562 // loop; see further comments in VisitWhileStmt.) 563 Cnt.beginRegion(); 564 CountMap[S->getLoopVarStmt()] = PGO.getCurrentRegionCount(); 565 Visit(S->getLoopVarStmt()); 566 Visit(S->getBody()); 567 Cnt.adjustForControlFlow(); 568 569 // The increment is essentially part of the body but it needs to include 570 // the count for all the continue statements. 571 Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() + 572 BreakContinueStack.back().ContinueCount); 573 CountMap[S->getInc()] = PGO.getCurrentRegionCount(); 574 Visit(S->getInc()); 575 Cnt.adjustForControlFlow(); 576 577 BreakContinue BC = BreakContinueStack.pop_back_val(); 578 579 // ...then go back and propagate counts through the condition. 580 Cnt.setCurrentRegionCount(Cnt.getParentCount() + 581 Cnt.getAdjustedCount() + 582 BC.ContinueCount); 583 CountMap[S->getCond()] = PGO.getCurrentRegionCount(); 584 Visit(S->getCond()); 585 Cnt.adjustForControlFlow(); 586 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 587 RecordNextStmtCount = true; 588 } 589 590 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { 591 RecordStmtCount(S); 592 Visit(S->getElement()); 593 // Counter tracks the body of the loop. 594 RegionCounter Cnt(PGO, S); 595 BreakContinueStack.push_back(BreakContinue()); 596 Cnt.beginRegion(); 597 CountMap[S->getBody()] = PGO.getCurrentRegionCount(); 598 Visit(S->getBody()); 599 BreakContinue BC = BreakContinueStack.pop_back_val(); 600 Cnt.adjustForControlFlow(); 601 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 602 RecordNextStmtCount = true; 603 } 604 605 void VisitSwitchStmt(const SwitchStmt *S) { 606 RecordStmtCount(S); 607 Visit(S->getCond()); 608 PGO.setCurrentRegionUnreachable(); 609 BreakContinueStack.push_back(BreakContinue()); 610 Visit(S->getBody()); 611 // If the switch is inside a loop, add the continue counts. 612 BreakContinue BC = BreakContinueStack.pop_back_val(); 613 if (!BreakContinueStack.empty()) 614 BreakContinueStack.back().ContinueCount += BC.ContinueCount; 615 // Counter tracks the exit block of the switch. 616 RegionCounter ExitCnt(PGO, S); 617 ExitCnt.beginRegion(); 618 RecordNextStmtCount = true; 619 } 620 621 void VisitCaseStmt(const CaseStmt *S) { 622 RecordNextStmtCount = false; 623 // Counter for this particular case. This counts only jumps from the 624 // switch header and does not include fallthrough from the case before 625 // this one. 626 RegionCounter Cnt(PGO, S); 627 Cnt.beginRegion(/*AddIncomingFallThrough=*/true); 628 CountMap[S] = Cnt.getCount(); 629 RecordNextStmtCount = true; 630 Visit(S->getSubStmt()); 631 } 632 633 void VisitDefaultStmt(const DefaultStmt *S) { 634 RecordNextStmtCount = false; 635 // Counter for this default case. This does not include fallthrough from 636 // the previous case. 637 RegionCounter Cnt(PGO, S); 638 Cnt.beginRegion(/*AddIncomingFallThrough=*/true); 639 CountMap[S] = Cnt.getCount(); 640 RecordNextStmtCount = true; 641 Visit(S->getSubStmt()); 642 } 643 644 void VisitIfStmt(const IfStmt *S) { 645 RecordStmtCount(S); 646 // Counter tracks the "then" part of an if statement. The count for 647 // the "else" part, if it exists, will be calculated from this counter. 648 RegionCounter Cnt(PGO, S); 649 Visit(S->getCond()); 650 651 Cnt.beginRegion(); 652 CountMap[S->getThen()] = PGO.getCurrentRegionCount(); 653 Visit(S->getThen()); 654 Cnt.adjustForControlFlow(); 655 656 if (S->getElse()) { 657 Cnt.beginElseRegion(); 658 CountMap[S->getElse()] = PGO.getCurrentRegionCount(); 659 Visit(S->getElse()); 660 Cnt.adjustForControlFlow(); 661 } 662 Cnt.applyAdjustmentsToRegion(0); 663 RecordNextStmtCount = true; 664 } 665 666 void VisitCXXTryStmt(const CXXTryStmt *S) { 667 RecordStmtCount(S); 668 Visit(S->getTryBlock()); 669 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I) 670 Visit(S->getHandler(I)); 671 // Counter tracks the continuation block of the try statement. 672 RegionCounter Cnt(PGO, S); 673 Cnt.beginRegion(); 674 RecordNextStmtCount = true; 675 } 676 677 void VisitCXXCatchStmt(const CXXCatchStmt *S) { 678 RecordNextStmtCount = false; 679 // Counter tracks the catch statement's handler block. 680 RegionCounter Cnt(PGO, S); 681 Cnt.beginRegion(); 682 CountMap[S] = PGO.getCurrentRegionCount(); 683 Visit(S->getHandlerBlock()); 684 } 685 686 void VisitAbstractConditionalOperator( 687 const AbstractConditionalOperator *E) { 688 RecordStmtCount(E); 689 // Counter tracks the "true" part of a conditional operator. The 690 // count in the "false" part will be calculated from this counter. 691 RegionCounter Cnt(PGO, E); 692 Visit(E->getCond()); 693 694 Cnt.beginRegion(); 695 CountMap[E->getTrueExpr()] = PGO.getCurrentRegionCount(); 696 Visit(E->getTrueExpr()); 697 Cnt.adjustForControlFlow(); 698 699 Cnt.beginElseRegion(); 700 CountMap[E->getFalseExpr()] = PGO.getCurrentRegionCount(); 701 Visit(E->getFalseExpr()); 702 Cnt.adjustForControlFlow(); 703 704 Cnt.applyAdjustmentsToRegion(0); 705 RecordNextStmtCount = true; 706 } 707 708 void VisitBinLAnd(const BinaryOperator *E) { 709 RecordStmtCount(E); 710 // Counter tracks the right hand side of a logical and operator. 711 RegionCounter Cnt(PGO, E); 712 Visit(E->getLHS()); 713 Cnt.beginRegion(); 714 CountMap[E->getRHS()] = PGO.getCurrentRegionCount(); 715 Visit(E->getRHS()); 716 Cnt.adjustForControlFlow(); 717 Cnt.applyAdjustmentsToRegion(0); 718 RecordNextStmtCount = true; 719 } 720 721 void VisitBinLOr(const BinaryOperator *E) { 722 RecordStmtCount(E); 723 // Counter tracks the right hand side of a logical or operator. 724 RegionCounter Cnt(PGO, E); 725 Visit(E->getLHS()); 726 Cnt.beginRegion(); 727 CountMap[E->getRHS()] = PGO.getCurrentRegionCount(); 728 Visit(E->getRHS()); 729 Cnt.adjustForControlFlow(); 730 Cnt.applyAdjustmentsToRegion(0); 731 RecordNextStmtCount = true; 732 } 733 }; 734 } 735 736 void PGOHash::combine(HashType Type) { 737 // Check that we never combine 0 and only have six bits. 738 assert(Type && "Hash is invalid: unexpected type 0"); 739 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types"); 740 741 // Pass through MD5 if enough work has built up. 742 if (Count && Count % NumTypesPerWord == 0) { 743 using namespace llvm::support; 744 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working); 745 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped))); 746 Working = 0; 747 } 748 749 // Accumulate the current type. 750 ++Count; 751 Working = Working << NumBitsPerType | Type; 752 } 753 754 uint64_t PGOHash::finalize() { 755 // Use Working as the hash directly if we never used MD5. 756 if (Count <= NumTypesPerWord) 757 // No need to byte swap here, since none of the math was endian-dependent. 758 // This number will be byte-swapped as required on endianness transitions, 759 // so we will see the same value on the other side. 760 return Working; 761 762 // Check for remaining work in Working. 763 if (Working) 764 MD5.update(Working); 765 766 // Finalize the MD5 and return the hash. 767 llvm::MD5::MD5Result Result; 768 MD5.final(Result); 769 using namespace llvm::support; 770 return endian::read<uint64_t, little, unaligned>(Result); 771 } 772 773 static void emitRuntimeHook(CodeGenModule &CGM) { 774 const char *const RuntimeVarName = "__llvm_profile_runtime"; 775 const char *const RuntimeUserName = "__llvm_profile_runtime_user"; 776 if (CGM.getModule().getGlobalVariable(RuntimeVarName)) 777 return; 778 779 // Declare the runtime hook. 780 llvm::LLVMContext &Ctx = CGM.getLLVMContext(); 781 auto *Int32Ty = llvm::Type::getInt32Ty(Ctx); 782 auto *Var = new llvm::GlobalVariable(CGM.getModule(), Int32Ty, false, 783 llvm::GlobalValue::ExternalLinkage, 784 nullptr, RuntimeVarName); 785 786 // Make a function that uses it. 787 auto *User = llvm::Function::Create(llvm::FunctionType::get(Int32Ty, false), 788 llvm::GlobalValue::LinkOnceODRLinkage, 789 RuntimeUserName, &CGM.getModule()); 790 User->addFnAttr(llvm::Attribute::NoInline); 791 if (CGM.getCodeGenOpts().DisableRedZone) 792 User->addFnAttr(llvm::Attribute::NoRedZone); 793 CGBuilderTy Builder(llvm::BasicBlock::Create(CGM.getLLVMContext(), "", User)); 794 auto *Load = Builder.CreateLoad(Var); 795 Builder.CreateRet(Load); 796 797 // Create a use of the function. Now the definition of the runtime variable 798 // should get pulled in, along with any static initializears. 799 CGM.addUsedGlobal(User); 800 } 801 802 void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) { 803 bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate; 804 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 805 if (!InstrumentRegions && !PGOReader) 806 return; 807 if (!D) 808 return; 809 setFuncName(Fn); 810 811 // Set the linkage for variables based on the function linkage. Usually, we 812 // want to match it, but available_externally and extern_weak both have the 813 // wrong semantics. 814 VarLinkage = Fn->getLinkage(); 815 switch (VarLinkage) { 816 case llvm::GlobalValue::ExternalWeakLinkage: 817 VarLinkage = llvm::GlobalValue::LinkOnceAnyLinkage; 818 break; 819 case llvm::GlobalValue::AvailableExternallyLinkage: 820 VarLinkage = llvm::GlobalValue::LinkOnceODRLinkage; 821 break; 822 default: 823 break; 824 } 825 826 mapRegionCounters(D); 827 if (InstrumentRegions) { 828 emitRuntimeHook(CGM); 829 emitCounterVariables(); 830 } 831 if (PGOReader) { 832 loadRegionCounts(PGOReader); 833 computeRegionCounts(D); 834 applyFunctionAttributes(PGOReader, Fn); 835 } 836 } 837 838 void CodeGenPGO::mapRegionCounters(const Decl *D) { 839 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>); 840 MapRegionCounters Walker(*RegionCounterMap); 841 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 842 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD)); 843 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D)) 844 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD)); 845 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D)) 846 Walker.TraverseDecl(const_cast<BlockDecl *>(BD)); 847 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D)) 848 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD)); 849 NumRegionCounters = Walker.NextCounter; 850 FunctionHash = Walker.Hash.finalize(); 851 } 852 853 void CodeGenPGO::computeRegionCounts(const Decl *D) { 854 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>); 855 ComputeRegionCounts Walker(*StmtCountMap, *this); 856 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 857 Walker.VisitFunctionDecl(FD); 858 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D)) 859 Walker.VisitObjCMethodDecl(MD); 860 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D)) 861 Walker.VisitBlockDecl(BD); 862 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D)) 863 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD)); 864 } 865 866 void 867 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader, 868 llvm::Function *Fn) { 869 if (!haveRegionCounts()) 870 return; 871 872 uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount(); 873 uint64_t FunctionCount = getRegionCount(0); 874 if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount)) 875 // Turn on InlineHint attribute for hot functions. 876 // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal. 877 Fn->addFnAttr(llvm::Attribute::InlineHint); 878 else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount)) 879 // Turn on Cold attribute for cold functions. 880 // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal. 881 Fn->addFnAttr(llvm::Attribute::Cold); 882 } 883 884 void CodeGenPGO::emitCounterVariables() { 885 llvm::LLVMContext &Ctx = CGM.getLLVMContext(); 886 llvm::ArrayType *CounterTy = llvm::ArrayType::get(llvm::Type::getInt64Ty(Ctx), 887 NumRegionCounters); 888 RegionCounters = 889 new llvm::GlobalVariable(CGM.getModule(), CounterTy, false, VarLinkage, 890 llvm::Constant::getNullValue(CounterTy), 891 getFuncVarName("counters")); 892 RegionCounters->setAlignment(8); 893 RegionCounters->setSection(getCountersSection(CGM)); 894 } 895 896 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) { 897 if (!RegionCounters) 898 return; 899 llvm::Value *Addr = 900 Builder.CreateConstInBoundsGEP2_64(RegionCounters, 0, Counter); 901 llvm::Value *Count = Builder.CreateLoad(Addr, "pgocount"); 902 Count = Builder.CreateAdd(Count, Builder.getInt64(1)); 903 Builder.CreateStore(Count, Addr); 904 } 905 906 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader) { 907 CGM.getPGOStats().Visited++; 908 RegionCounts.reset(new std::vector<uint64_t>); 909 uint64_t Hash; 910 if (PGOReader->getFunctionCounts(getFuncName(), Hash, *RegionCounts)) { 911 CGM.getPGOStats().Missing++; 912 RegionCounts.reset(); 913 } else if (Hash != FunctionHash || 914 RegionCounts->size() != NumRegionCounters) { 915 CGM.getPGOStats().Mismatched++; 916 RegionCounts.reset(); 917 } 918 } 919 920 void CodeGenPGO::destroyRegionCounters() { 921 RegionCounterMap.reset(); 922 StmtCountMap.reset(); 923 RegionCounts.reset(); 924 } 925 926 /// \brief Calculate what to divide by to scale weights. 927 /// 928 /// Given the maximum weight, calculate a divisor that will scale all the 929 /// weights to strictly less than UINT32_MAX. 930 static uint64_t calculateWeightScale(uint64_t MaxWeight) { 931 return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1; 932 } 933 934 /// \brief Scale an individual branch weight (and add 1). 935 /// 936 /// Scale a 64-bit weight down to 32-bits using \c Scale. 937 /// 938 /// According to Laplace's Rule of Succession, it is better to compute the 939 /// weight based on the count plus 1, so universally add 1 to the value. 940 /// 941 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no 942 /// greater than \c Weight. 943 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) { 944 assert(Scale && "scale by 0?"); 945 uint64_t Scaled = Weight / Scale + 1; 946 assert(Scaled <= UINT32_MAX && "overflow 32-bits"); 947 return Scaled; 948 } 949 950 llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount, 951 uint64_t FalseCount) { 952 // Check for empty weights. 953 if (!TrueCount && !FalseCount) 954 return nullptr; 955 956 // Calculate how to scale down to 32-bits. 957 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount)); 958 959 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 960 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale), 961 scaleBranchWeight(FalseCount, Scale)); 962 } 963 964 llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) { 965 // We need at least two elements to create meaningful weights. 966 if (Weights.size() < 2) 967 return nullptr; 968 969 // Check for empty weights. 970 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end()); 971 if (MaxWeight == 0) 972 return nullptr; 973 974 // Calculate how to scale down to 32-bits. 975 uint64_t Scale = calculateWeightScale(MaxWeight); 976 977 SmallVector<uint32_t, 16> ScaledWeights; 978 ScaledWeights.reserve(Weights.size()); 979 for (uint64_t W : Weights) 980 ScaledWeights.push_back(scaleBranchWeight(W, Scale)); 981 982 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 983 return MDHelper.createBranchWeights(ScaledWeights); 984 } 985 986 llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond, 987 RegionCounter &Cnt) { 988 if (!haveRegionCounts()) 989 return nullptr; 990 uint64_t LoopCount = Cnt.getCount(); 991 uint64_t CondCount = 0; 992 bool Found = getStmtCount(Cond, CondCount); 993 assert(Found && "missing expected loop condition count"); 994 (void)Found; 995 if (CondCount == 0) 996 return nullptr; 997 return createBranchWeights(LoopCount, 998 std::max(CondCount, LoopCount) - LoopCount); 999 } 1000