1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // Instrumentation-based profile-guided optimization 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "CodeGenPGO.h" 15 #include "CodeGenFunction.h" 16 #include "clang/AST/RecursiveASTVisitor.h" 17 #include "clang/AST/StmtVisitor.h" 18 #include "llvm/IR/MDBuilder.h" 19 #include "llvm/ProfileData/InstrProfReader.h" 20 #include "llvm/Support/Endian.h" 21 #include "llvm/Support/FileSystem.h" 22 #include "llvm/Support/MD5.h" 23 24 using namespace clang; 25 using namespace CodeGen; 26 27 void CodeGenPGO::setFuncName(llvm::Function *Fn) { 28 RawFuncName = Fn->getName(); 29 30 // Function names may be prefixed with a binary '1' to indicate 31 // that the backend should not modify the symbols due to any platform 32 // naming convention. Do not include that '1' in the PGO profile name. 33 if (RawFuncName[0] == '\1') 34 RawFuncName = RawFuncName.substr(1); 35 36 if (!Fn->hasLocalLinkage()) { 37 PrefixedFuncName.reset(new std::string(RawFuncName)); 38 return; 39 } 40 41 // For local symbols, prepend the main file name to distinguish them. 42 // Do not include the full path in the file name since there's no guarantee 43 // that it will stay the same, e.g., if the files are checked out from 44 // version control in different locations. 45 PrefixedFuncName.reset(new std::string(CGM.getCodeGenOpts().MainFileName)); 46 if (PrefixedFuncName->empty()) 47 PrefixedFuncName->assign("<unknown>"); 48 PrefixedFuncName->append(":"); 49 PrefixedFuncName->append(RawFuncName); 50 } 51 52 static llvm::Function *getRegisterFunc(CodeGenModule &CGM) { 53 return CGM.getModule().getFunction("__llvm_profile_register_functions"); 54 } 55 56 static llvm::BasicBlock *getOrInsertRegisterBB(CodeGenModule &CGM) { 57 // Don't do this for Darwin. compiler-rt uses linker magic. 58 if (CGM.getTarget().getTriple().isOSDarwin()) 59 return nullptr; 60 61 // Only need to insert this once per module. 62 if (llvm::Function *RegisterF = getRegisterFunc(CGM)) 63 return &RegisterF->getEntryBlock(); 64 65 // Construct the function. 66 auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext()); 67 auto *RegisterFTy = llvm::FunctionType::get(VoidTy, false); 68 auto *RegisterF = llvm::Function::Create(RegisterFTy, 69 llvm::GlobalValue::InternalLinkage, 70 "__llvm_profile_register_functions", 71 &CGM.getModule()); 72 RegisterF->setUnnamedAddr(true); 73 if (CGM.getCodeGenOpts().DisableRedZone) 74 RegisterF->addFnAttr(llvm::Attribute::NoRedZone); 75 76 // Construct and return the entry block. 77 auto *BB = llvm::BasicBlock::Create(CGM.getLLVMContext(), "", RegisterF); 78 CGBuilderTy Builder(BB); 79 Builder.CreateRetVoid(); 80 return BB; 81 } 82 83 static llvm::Constant *getOrInsertRuntimeRegister(CodeGenModule &CGM) { 84 auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext()); 85 auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext()); 86 auto *RuntimeRegisterTy = llvm::FunctionType::get(VoidTy, VoidPtrTy, false); 87 return CGM.getModule().getOrInsertFunction("__llvm_profile_register_function", 88 RuntimeRegisterTy); 89 } 90 91 static bool isMachO(const CodeGenModule &CGM) { 92 return CGM.getTarget().getTriple().isOSBinFormatMachO(); 93 } 94 95 static StringRef getCountersSection(const CodeGenModule &CGM) { 96 return isMachO(CGM) ? "__DATA,__llvm_prf_cnts" : "__llvm_prf_cnts"; 97 } 98 99 static StringRef getNameSection(const CodeGenModule &CGM) { 100 return isMachO(CGM) ? "__DATA,__llvm_prf_names" : "__llvm_prf_names"; 101 } 102 103 static StringRef getDataSection(const CodeGenModule &CGM) { 104 return isMachO(CGM) ? "__DATA,__llvm_prf_data" : "__llvm_prf_data"; 105 } 106 107 llvm::GlobalVariable *CodeGenPGO::buildDataVar() { 108 // Create name variable. 109 llvm::LLVMContext &Ctx = CGM.getLLVMContext(); 110 auto *VarName = llvm::ConstantDataArray::getString(Ctx, getFuncName(), 111 false); 112 auto *Name = new llvm::GlobalVariable(CGM.getModule(), VarName->getType(), 113 true, VarLinkage, VarName, 114 getFuncVarName("name")); 115 Name->setSection(getNameSection(CGM)); 116 Name->setAlignment(1); 117 118 // Create data variable. 119 auto *Int32Ty = llvm::Type::getInt32Ty(Ctx); 120 auto *Int64Ty = llvm::Type::getInt64Ty(Ctx); 121 auto *Int8PtrTy = llvm::Type::getInt8PtrTy(Ctx); 122 auto *Int64PtrTy = llvm::Type::getInt64PtrTy(Ctx); 123 llvm::Type *DataTypes[] = { 124 Int32Ty, Int32Ty, Int64Ty, Int8PtrTy, Int64PtrTy 125 }; 126 auto *DataTy = llvm::StructType::get(Ctx, makeArrayRef(DataTypes)); 127 llvm::Constant *DataVals[] = { 128 llvm::ConstantInt::get(Int32Ty, getFuncName().size()), 129 llvm::ConstantInt::get(Int32Ty, NumRegionCounters), 130 llvm::ConstantInt::get(Int64Ty, FunctionHash), 131 llvm::ConstantExpr::getBitCast(Name, Int8PtrTy), 132 llvm::ConstantExpr::getBitCast(RegionCounters, Int64PtrTy) 133 }; 134 auto *Data = 135 new llvm::GlobalVariable(CGM.getModule(), DataTy, true, VarLinkage, 136 llvm::ConstantStruct::get(DataTy, DataVals), 137 getFuncVarName("data")); 138 139 // All the data should be packed into an array in its own section. 140 Data->setSection(getDataSection(CGM)); 141 Data->setAlignment(8); 142 143 // Make sure the data doesn't get deleted. 144 CGM.addUsedGlobal(Data); 145 return Data; 146 } 147 148 void CodeGenPGO::emitInstrumentationData() { 149 if (!RegionCounters) 150 return; 151 152 // Build the data. 153 auto *Data = buildDataVar(); 154 155 // Register the data. 156 auto *RegisterBB = getOrInsertRegisterBB(CGM); 157 if (!RegisterBB) 158 return; 159 CGBuilderTy Builder(RegisterBB->getTerminator()); 160 auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext()); 161 Builder.CreateCall(getOrInsertRuntimeRegister(CGM), 162 Builder.CreateBitCast(Data, VoidPtrTy)); 163 } 164 165 llvm::Function *CodeGenPGO::emitInitialization(CodeGenModule &CGM) { 166 if (!CGM.getCodeGenOpts().ProfileInstrGenerate) 167 return nullptr; 168 169 assert(CGM.getModule().getFunction("__llvm_profile_init") == nullptr && 170 "profile initialization already emitted"); 171 172 // Get the function to call at initialization. 173 llvm::Constant *RegisterF = getRegisterFunc(CGM); 174 if (!RegisterF) 175 return nullptr; 176 177 // Create the initialization function. 178 auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext()); 179 auto *F = llvm::Function::Create(llvm::FunctionType::get(VoidTy, false), 180 llvm::GlobalValue::InternalLinkage, 181 "__llvm_profile_init", &CGM.getModule()); 182 F->setUnnamedAddr(true); 183 F->addFnAttr(llvm::Attribute::NoInline); 184 if (CGM.getCodeGenOpts().DisableRedZone) 185 F->addFnAttr(llvm::Attribute::NoRedZone); 186 187 // Add the basic block and the necessary calls. 188 CGBuilderTy Builder(llvm::BasicBlock::Create(CGM.getLLVMContext(), "", F)); 189 Builder.CreateCall(RegisterF); 190 Builder.CreateRetVoid(); 191 192 return F; 193 } 194 195 namespace { 196 /// \brief Stable hasher for PGO region counters. 197 /// 198 /// PGOHash produces a stable hash of a given function's control flow. 199 /// 200 /// Changing the output of this hash will invalidate all previously generated 201 /// profiles -- i.e., don't do it. 202 /// 203 /// \note When this hash does eventually change (years?), we still need to 204 /// support old hashes. We'll need to pull in the version number from the 205 /// profile data format and use the matching hash function. 206 class PGOHash { 207 uint64_t Working; 208 unsigned Count; 209 llvm::MD5 MD5; 210 211 static const int NumBitsPerType = 6; 212 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType; 213 static const unsigned TooBig = 1u << NumBitsPerType; 214 215 public: 216 /// \brief Hash values for AST nodes. 217 /// 218 /// Distinct values for AST nodes that have region counters attached. 219 /// 220 /// These values must be stable. All new members must be added at the end, 221 /// and no members should be removed. Changing the enumeration value for an 222 /// AST node will affect the hash of every function that contains that node. 223 enum HashType : unsigned char { 224 None = 0, 225 LabelStmt = 1, 226 WhileStmt, 227 DoStmt, 228 ForStmt, 229 CXXForRangeStmt, 230 ObjCForCollectionStmt, 231 SwitchStmt, 232 CaseStmt, 233 DefaultStmt, 234 IfStmt, 235 CXXTryStmt, 236 CXXCatchStmt, 237 ConditionalOperator, 238 BinaryOperatorLAnd, 239 BinaryOperatorLOr, 240 BinaryConditionalOperator, 241 242 // Keep this last. It's for the static assert that follows. 243 LastHashType 244 }; 245 static_assert(LastHashType <= TooBig, "Too many types in HashType"); 246 247 // TODO: When this format changes, take in a version number here, and use the 248 // old hash calculation for file formats that used the old hash. 249 PGOHash() : Working(0), Count(0) {} 250 void combine(HashType Type); 251 uint64_t finalize(); 252 }; 253 const int PGOHash::NumBitsPerType; 254 const unsigned PGOHash::NumTypesPerWord; 255 const unsigned PGOHash::TooBig; 256 257 /// A RecursiveASTVisitor that fills a map of statements to PGO counters. 258 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> { 259 /// The next counter value to assign. 260 unsigned NextCounter; 261 /// The function hash. 262 PGOHash Hash; 263 /// The map of statements to counters. 264 llvm::DenseMap<const Stmt *, unsigned> &CounterMap; 265 266 MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap) 267 : NextCounter(0), CounterMap(CounterMap) {} 268 269 // Blocks and lambdas are handled as separate functions, so we need not 270 // traverse them in the parent context. 271 bool TraverseBlockExpr(BlockExpr *BE) { return true; } 272 bool TraverseLambdaBody(LambdaExpr *LE) { return true; } 273 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; } 274 275 bool VisitDecl(const Decl *D) { 276 switch (D->getKind()) { 277 default: 278 break; 279 case Decl::Function: 280 case Decl::CXXMethod: 281 case Decl::CXXConstructor: 282 case Decl::CXXDestructor: 283 case Decl::CXXConversion: 284 case Decl::ObjCMethod: 285 case Decl::Block: 286 case Decl::Captured: 287 CounterMap[D->getBody()] = NextCounter++; 288 break; 289 } 290 return true; 291 } 292 293 bool VisitStmt(const Stmt *S) { 294 auto Type = getHashType(S); 295 if (Type == PGOHash::None) 296 return true; 297 298 CounterMap[S] = NextCounter++; 299 Hash.combine(Type); 300 return true; 301 } 302 PGOHash::HashType getHashType(const Stmt *S) { 303 switch (S->getStmtClass()) { 304 default: 305 break; 306 case Stmt::LabelStmtClass: 307 return PGOHash::LabelStmt; 308 case Stmt::WhileStmtClass: 309 return PGOHash::WhileStmt; 310 case Stmt::DoStmtClass: 311 return PGOHash::DoStmt; 312 case Stmt::ForStmtClass: 313 return PGOHash::ForStmt; 314 case Stmt::CXXForRangeStmtClass: 315 return PGOHash::CXXForRangeStmt; 316 case Stmt::ObjCForCollectionStmtClass: 317 return PGOHash::ObjCForCollectionStmt; 318 case Stmt::SwitchStmtClass: 319 return PGOHash::SwitchStmt; 320 case Stmt::CaseStmtClass: 321 return PGOHash::CaseStmt; 322 case Stmt::DefaultStmtClass: 323 return PGOHash::DefaultStmt; 324 case Stmt::IfStmtClass: 325 return PGOHash::IfStmt; 326 case Stmt::CXXTryStmtClass: 327 return PGOHash::CXXTryStmt; 328 case Stmt::CXXCatchStmtClass: 329 return PGOHash::CXXCatchStmt; 330 case Stmt::ConditionalOperatorClass: 331 return PGOHash::ConditionalOperator; 332 case Stmt::BinaryConditionalOperatorClass: 333 return PGOHash::BinaryConditionalOperator; 334 case Stmt::BinaryOperatorClass: { 335 const BinaryOperator *BO = cast<BinaryOperator>(S); 336 if (BO->getOpcode() == BO_LAnd) 337 return PGOHash::BinaryOperatorLAnd; 338 if (BO->getOpcode() == BO_LOr) 339 return PGOHash::BinaryOperatorLOr; 340 break; 341 } 342 } 343 return PGOHash::None; 344 } 345 }; 346 347 /// A StmtVisitor that propagates the raw counts through the AST and 348 /// records the count at statements where the value may change. 349 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { 350 /// PGO state. 351 CodeGenPGO &PGO; 352 353 /// A flag that is set when the current count should be recorded on the 354 /// next statement, such as at the exit of a loop. 355 bool RecordNextStmtCount; 356 357 /// The map of statements to count values. 358 llvm::DenseMap<const Stmt *, uint64_t> &CountMap; 359 360 /// BreakContinueStack - Keep counts of breaks and continues inside loops. 361 struct BreakContinue { 362 uint64_t BreakCount; 363 uint64_t ContinueCount; 364 BreakContinue() : BreakCount(0), ContinueCount(0) {} 365 }; 366 SmallVector<BreakContinue, 8> BreakContinueStack; 367 368 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap, 369 CodeGenPGO &PGO) 370 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {} 371 372 void RecordStmtCount(const Stmt *S) { 373 if (RecordNextStmtCount) { 374 CountMap[S] = PGO.getCurrentRegionCount(); 375 RecordNextStmtCount = false; 376 } 377 } 378 379 void VisitStmt(const Stmt *S) { 380 RecordStmtCount(S); 381 for (Stmt::const_child_range I = S->children(); I; ++I) { 382 if (*I) 383 this->Visit(*I); 384 } 385 } 386 387 void VisitFunctionDecl(const FunctionDecl *D) { 388 // Counter tracks entry to the function body. 389 RegionCounter Cnt(PGO, D->getBody()); 390 Cnt.beginRegion(); 391 CountMap[D->getBody()] = PGO.getCurrentRegionCount(); 392 Visit(D->getBody()); 393 } 394 395 // Skip lambda expressions. We visit these as FunctionDecls when we're 396 // generating them and aren't interested in the body when generating a 397 // parent context. 398 void VisitLambdaExpr(const LambdaExpr *LE) {} 399 400 void VisitCapturedDecl(const CapturedDecl *D) { 401 // Counter tracks entry to the capture body. 402 RegionCounter Cnt(PGO, D->getBody()); 403 Cnt.beginRegion(); 404 CountMap[D->getBody()] = PGO.getCurrentRegionCount(); 405 Visit(D->getBody()); 406 } 407 408 void VisitObjCMethodDecl(const ObjCMethodDecl *D) { 409 // Counter tracks entry to the method body. 410 RegionCounter Cnt(PGO, D->getBody()); 411 Cnt.beginRegion(); 412 CountMap[D->getBody()] = PGO.getCurrentRegionCount(); 413 Visit(D->getBody()); 414 } 415 416 void VisitBlockDecl(const BlockDecl *D) { 417 // Counter tracks entry to the block body. 418 RegionCounter Cnt(PGO, D->getBody()); 419 Cnt.beginRegion(); 420 CountMap[D->getBody()] = PGO.getCurrentRegionCount(); 421 Visit(D->getBody()); 422 } 423 424 void VisitReturnStmt(const ReturnStmt *S) { 425 RecordStmtCount(S); 426 if (S->getRetValue()) 427 Visit(S->getRetValue()); 428 PGO.setCurrentRegionUnreachable(); 429 RecordNextStmtCount = true; 430 } 431 432 void VisitGotoStmt(const GotoStmt *S) { 433 RecordStmtCount(S); 434 PGO.setCurrentRegionUnreachable(); 435 RecordNextStmtCount = true; 436 } 437 438 void VisitLabelStmt(const LabelStmt *S) { 439 RecordNextStmtCount = false; 440 // Counter tracks the block following the label. 441 RegionCounter Cnt(PGO, S); 442 Cnt.beginRegion(); 443 CountMap[S] = PGO.getCurrentRegionCount(); 444 Visit(S->getSubStmt()); 445 } 446 447 void VisitBreakStmt(const BreakStmt *S) { 448 RecordStmtCount(S); 449 assert(!BreakContinueStack.empty() && "break not in a loop or switch!"); 450 BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount(); 451 PGO.setCurrentRegionUnreachable(); 452 RecordNextStmtCount = true; 453 } 454 455 void VisitContinueStmt(const ContinueStmt *S) { 456 RecordStmtCount(S); 457 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!"); 458 BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount(); 459 PGO.setCurrentRegionUnreachable(); 460 RecordNextStmtCount = true; 461 } 462 463 void VisitWhileStmt(const WhileStmt *S) { 464 RecordStmtCount(S); 465 // Counter tracks the body of the loop. 466 RegionCounter Cnt(PGO, S); 467 BreakContinueStack.push_back(BreakContinue()); 468 // Visit the body region first so the break/continue adjustments can be 469 // included when visiting the condition. 470 Cnt.beginRegion(); 471 CountMap[S->getBody()] = PGO.getCurrentRegionCount(); 472 Visit(S->getBody()); 473 Cnt.adjustForControlFlow(); 474 475 // ...then go back and propagate counts through the condition. The count 476 // at the start of the condition is the sum of the incoming edges, 477 // the backedge from the end of the loop body, and the edges from 478 // continue statements. 479 BreakContinue BC = BreakContinueStack.pop_back_val(); 480 Cnt.setCurrentRegionCount(Cnt.getParentCount() + 481 Cnt.getAdjustedCount() + BC.ContinueCount); 482 CountMap[S->getCond()] = PGO.getCurrentRegionCount(); 483 Visit(S->getCond()); 484 Cnt.adjustForControlFlow(); 485 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 486 RecordNextStmtCount = true; 487 } 488 489 void VisitDoStmt(const DoStmt *S) { 490 RecordStmtCount(S); 491 // Counter tracks the body of the loop. 492 RegionCounter Cnt(PGO, S); 493 BreakContinueStack.push_back(BreakContinue()); 494 Cnt.beginRegion(/*AddIncomingFallThrough=*/true); 495 CountMap[S->getBody()] = PGO.getCurrentRegionCount(); 496 Visit(S->getBody()); 497 Cnt.adjustForControlFlow(); 498 499 BreakContinue BC = BreakContinueStack.pop_back_val(); 500 // The count at the start of the condition is equal to the count at the 501 // end of the body. The adjusted count does not include either the 502 // fall-through count coming into the loop or the continue count, so add 503 // both of those separately. This is coincidentally the same equation as 504 // with while loops but for different reasons. 505 Cnt.setCurrentRegionCount(Cnt.getParentCount() + 506 Cnt.getAdjustedCount() + BC.ContinueCount); 507 CountMap[S->getCond()] = PGO.getCurrentRegionCount(); 508 Visit(S->getCond()); 509 Cnt.adjustForControlFlow(); 510 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 511 RecordNextStmtCount = true; 512 } 513 514 void VisitForStmt(const ForStmt *S) { 515 RecordStmtCount(S); 516 if (S->getInit()) 517 Visit(S->getInit()); 518 // Counter tracks the body of the loop. 519 RegionCounter Cnt(PGO, S); 520 BreakContinueStack.push_back(BreakContinue()); 521 // Visit the body region first. (This is basically the same as a while 522 // loop; see further comments in VisitWhileStmt.) 523 Cnt.beginRegion(); 524 CountMap[S->getBody()] = PGO.getCurrentRegionCount(); 525 Visit(S->getBody()); 526 Cnt.adjustForControlFlow(); 527 528 // The increment is essentially part of the body but it needs to include 529 // the count for all the continue statements. 530 if (S->getInc()) { 531 Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() + 532 BreakContinueStack.back().ContinueCount); 533 CountMap[S->getInc()] = PGO.getCurrentRegionCount(); 534 Visit(S->getInc()); 535 Cnt.adjustForControlFlow(); 536 } 537 538 BreakContinue BC = BreakContinueStack.pop_back_val(); 539 540 // ...then go back and propagate counts through the condition. 541 if (S->getCond()) { 542 Cnt.setCurrentRegionCount(Cnt.getParentCount() + 543 Cnt.getAdjustedCount() + 544 BC.ContinueCount); 545 CountMap[S->getCond()] = PGO.getCurrentRegionCount(); 546 Visit(S->getCond()); 547 Cnt.adjustForControlFlow(); 548 } 549 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 550 RecordNextStmtCount = true; 551 } 552 553 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) { 554 RecordStmtCount(S); 555 Visit(S->getRangeStmt()); 556 Visit(S->getBeginEndStmt()); 557 // Counter tracks the body of the loop. 558 RegionCounter Cnt(PGO, S); 559 BreakContinueStack.push_back(BreakContinue()); 560 // Visit the body region first. (This is basically the same as a while 561 // loop; see further comments in VisitWhileStmt.) 562 Cnt.beginRegion(); 563 CountMap[S->getLoopVarStmt()] = PGO.getCurrentRegionCount(); 564 Visit(S->getLoopVarStmt()); 565 Visit(S->getBody()); 566 Cnt.adjustForControlFlow(); 567 568 // The increment is essentially part of the body but it needs to include 569 // the count for all the continue statements. 570 Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() + 571 BreakContinueStack.back().ContinueCount); 572 CountMap[S->getInc()] = PGO.getCurrentRegionCount(); 573 Visit(S->getInc()); 574 Cnt.adjustForControlFlow(); 575 576 BreakContinue BC = BreakContinueStack.pop_back_val(); 577 578 // ...then go back and propagate counts through the condition. 579 Cnt.setCurrentRegionCount(Cnt.getParentCount() + 580 Cnt.getAdjustedCount() + 581 BC.ContinueCount); 582 CountMap[S->getCond()] = PGO.getCurrentRegionCount(); 583 Visit(S->getCond()); 584 Cnt.adjustForControlFlow(); 585 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 586 RecordNextStmtCount = true; 587 } 588 589 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { 590 RecordStmtCount(S); 591 Visit(S->getElement()); 592 // Counter tracks the body of the loop. 593 RegionCounter Cnt(PGO, S); 594 BreakContinueStack.push_back(BreakContinue()); 595 Cnt.beginRegion(); 596 CountMap[S->getBody()] = PGO.getCurrentRegionCount(); 597 Visit(S->getBody()); 598 BreakContinue BC = BreakContinueStack.pop_back_val(); 599 Cnt.adjustForControlFlow(); 600 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 601 RecordNextStmtCount = true; 602 } 603 604 void VisitSwitchStmt(const SwitchStmt *S) { 605 RecordStmtCount(S); 606 Visit(S->getCond()); 607 PGO.setCurrentRegionUnreachable(); 608 BreakContinueStack.push_back(BreakContinue()); 609 Visit(S->getBody()); 610 // If the switch is inside a loop, add the continue counts. 611 BreakContinue BC = BreakContinueStack.pop_back_val(); 612 if (!BreakContinueStack.empty()) 613 BreakContinueStack.back().ContinueCount += BC.ContinueCount; 614 // Counter tracks the exit block of the switch. 615 RegionCounter ExitCnt(PGO, S); 616 ExitCnt.beginRegion(); 617 RecordNextStmtCount = true; 618 } 619 620 void VisitCaseStmt(const CaseStmt *S) { 621 RecordNextStmtCount = false; 622 // Counter for this particular case. This counts only jumps from the 623 // switch header and does not include fallthrough from the case before 624 // this one. 625 RegionCounter Cnt(PGO, S); 626 Cnt.beginRegion(/*AddIncomingFallThrough=*/true); 627 CountMap[S] = Cnt.getCount(); 628 RecordNextStmtCount = true; 629 Visit(S->getSubStmt()); 630 } 631 632 void VisitDefaultStmt(const DefaultStmt *S) { 633 RecordNextStmtCount = false; 634 // Counter for this default case. This does not include fallthrough from 635 // the previous case. 636 RegionCounter Cnt(PGO, S); 637 Cnt.beginRegion(/*AddIncomingFallThrough=*/true); 638 CountMap[S] = Cnt.getCount(); 639 RecordNextStmtCount = true; 640 Visit(S->getSubStmt()); 641 } 642 643 void VisitIfStmt(const IfStmt *S) { 644 RecordStmtCount(S); 645 // Counter tracks the "then" part of an if statement. The count for 646 // the "else" part, if it exists, will be calculated from this counter. 647 RegionCounter Cnt(PGO, S); 648 Visit(S->getCond()); 649 650 Cnt.beginRegion(); 651 CountMap[S->getThen()] = PGO.getCurrentRegionCount(); 652 Visit(S->getThen()); 653 Cnt.adjustForControlFlow(); 654 655 if (S->getElse()) { 656 Cnt.beginElseRegion(); 657 CountMap[S->getElse()] = PGO.getCurrentRegionCount(); 658 Visit(S->getElse()); 659 Cnt.adjustForControlFlow(); 660 } 661 Cnt.applyAdjustmentsToRegion(0); 662 RecordNextStmtCount = true; 663 } 664 665 void VisitCXXTryStmt(const CXXTryStmt *S) { 666 RecordStmtCount(S); 667 Visit(S->getTryBlock()); 668 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I) 669 Visit(S->getHandler(I)); 670 // Counter tracks the continuation block of the try statement. 671 RegionCounter Cnt(PGO, S); 672 Cnt.beginRegion(); 673 RecordNextStmtCount = true; 674 } 675 676 void VisitCXXCatchStmt(const CXXCatchStmt *S) { 677 RecordNextStmtCount = false; 678 // Counter tracks the catch statement's handler block. 679 RegionCounter Cnt(PGO, S); 680 Cnt.beginRegion(); 681 CountMap[S] = PGO.getCurrentRegionCount(); 682 Visit(S->getHandlerBlock()); 683 } 684 685 void VisitAbstractConditionalOperator( 686 const AbstractConditionalOperator *E) { 687 RecordStmtCount(E); 688 // Counter tracks the "true" part of a conditional operator. The 689 // count in the "false" part will be calculated from this counter. 690 RegionCounter Cnt(PGO, E); 691 Visit(E->getCond()); 692 693 Cnt.beginRegion(); 694 CountMap[E->getTrueExpr()] = PGO.getCurrentRegionCount(); 695 Visit(E->getTrueExpr()); 696 Cnt.adjustForControlFlow(); 697 698 Cnt.beginElseRegion(); 699 CountMap[E->getFalseExpr()] = PGO.getCurrentRegionCount(); 700 Visit(E->getFalseExpr()); 701 Cnt.adjustForControlFlow(); 702 703 Cnt.applyAdjustmentsToRegion(0); 704 RecordNextStmtCount = true; 705 } 706 707 void VisitBinLAnd(const BinaryOperator *E) { 708 RecordStmtCount(E); 709 // Counter tracks the right hand side of a logical and operator. 710 RegionCounter Cnt(PGO, E); 711 Visit(E->getLHS()); 712 Cnt.beginRegion(); 713 CountMap[E->getRHS()] = PGO.getCurrentRegionCount(); 714 Visit(E->getRHS()); 715 Cnt.adjustForControlFlow(); 716 Cnt.applyAdjustmentsToRegion(0); 717 RecordNextStmtCount = true; 718 } 719 720 void VisitBinLOr(const BinaryOperator *E) { 721 RecordStmtCount(E); 722 // Counter tracks the right hand side of a logical or operator. 723 RegionCounter Cnt(PGO, E); 724 Visit(E->getLHS()); 725 Cnt.beginRegion(); 726 CountMap[E->getRHS()] = PGO.getCurrentRegionCount(); 727 Visit(E->getRHS()); 728 Cnt.adjustForControlFlow(); 729 Cnt.applyAdjustmentsToRegion(0); 730 RecordNextStmtCount = true; 731 } 732 }; 733 } 734 735 void PGOHash::combine(HashType Type) { 736 // Check that we never combine 0 and only have six bits. 737 assert(Type && "Hash is invalid: unexpected type 0"); 738 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types"); 739 740 // Pass through MD5 if enough work has built up. 741 if (Count && Count % NumTypesPerWord == 0) { 742 using namespace llvm::support; 743 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working); 744 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped))); 745 Working = 0; 746 } 747 748 // Accumulate the current type. 749 ++Count; 750 Working = Working << NumBitsPerType | Type; 751 } 752 753 uint64_t PGOHash::finalize() { 754 // Use Working as the hash directly if we never used MD5. 755 if (Count <= NumTypesPerWord) 756 // No need to byte swap here, since none of the math was endian-dependent. 757 // This number will be byte-swapped as required on endianness transitions, 758 // so we will see the same value on the other side. 759 return Working; 760 761 // Check for remaining work in Working. 762 if (Working) 763 MD5.update(Working); 764 765 // Finalize the MD5 and return the hash. 766 llvm::MD5::MD5Result Result; 767 MD5.final(Result); 768 using namespace llvm::support; 769 return endian::read<uint64_t, little, unaligned>(Result); 770 } 771 772 static void emitRuntimeHook(CodeGenModule &CGM) { 773 const char *const RuntimeVarName = "__llvm_profile_runtime"; 774 const char *const RuntimeUserName = "__llvm_profile_runtime_user"; 775 if (CGM.getModule().getGlobalVariable(RuntimeVarName)) 776 return; 777 778 // Declare the runtime hook. 779 llvm::LLVMContext &Ctx = CGM.getLLVMContext(); 780 auto *Int32Ty = llvm::Type::getInt32Ty(Ctx); 781 auto *Var = new llvm::GlobalVariable(CGM.getModule(), Int32Ty, false, 782 llvm::GlobalValue::ExternalLinkage, 783 nullptr, RuntimeVarName); 784 785 // Make a function that uses it. 786 auto *User = llvm::Function::Create(llvm::FunctionType::get(Int32Ty, false), 787 llvm::GlobalValue::LinkOnceODRLinkage, 788 RuntimeUserName, &CGM.getModule()); 789 User->addFnAttr(llvm::Attribute::NoInline); 790 if (CGM.getCodeGenOpts().DisableRedZone) 791 User->addFnAttr(llvm::Attribute::NoRedZone); 792 CGBuilderTy Builder(llvm::BasicBlock::Create(CGM.getLLVMContext(), "", User)); 793 auto *Load = Builder.CreateLoad(Var); 794 Builder.CreateRet(Load); 795 796 // Create a use of the function. Now the definition of the runtime variable 797 // should get pulled in, along with any static initializears. 798 CGM.addUsedGlobal(User); 799 } 800 801 void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) { 802 bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate; 803 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 804 if (!InstrumentRegions && !PGOReader) 805 return; 806 if (D->isImplicit()) 807 return; 808 setFuncName(Fn); 809 810 // Set the linkage for variables based on the function linkage. Usually, we 811 // want to match it, but available_externally and extern_weak both have the 812 // wrong semantics. 813 VarLinkage = Fn->getLinkage(); 814 switch (VarLinkage) { 815 case llvm::GlobalValue::ExternalWeakLinkage: 816 VarLinkage = llvm::GlobalValue::LinkOnceAnyLinkage; 817 break; 818 case llvm::GlobalValue::AvailableExternallyLinkage: 819 VarLinkage = llvm::GlobalValue::LinkOnceODRLinkage; 820 break; 821 default: 822 break; 823 } 824 825 mapRegionCounters(D); 826 if (InstrumentRegions) { 827 emitRuntimeHook(CGM); 828 emitCounterVariables(); 829 } 830 if (PGOReader) { 831 loadRegionCounts(PGOReader); 832 computeRegionCounts(D); 833 applyFunctionAttributes(PGOReader, Fn); 834 } 835 } 836 837 void CodeGenPGO::mapRegionCounters(const Decl *D) { 838 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>); 839 MapRegionCounters Walker(*RegionCounterMap); 840 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 841 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD)); 842 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D)) 843 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD)); 844 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D)) 845 Walker.TraverseDecl(const_cast<BlockDecl *>(BD)); 846 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D)) 847 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD)); 848 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl"); 849 NumRegionCounters = Walker.NextCounter; 850 FunctionHash = Walker.Hash.finalize(); 851 } 852 853 void CodeGenPGO::computeRegionCounts(const Decl *D) { 854 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>); 855 ComputeRegionCounts Walker(*StmtCountMap, *this); 856 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 857 Walker.VisitFunctionDecl(FD); 858 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D)) 859 Walker.VisitObjCMethodDecl(MD); 860 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D)) 861 Walker.VisitBlockDecl(BD); 862 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D)) 863 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD)); 864 } 865 866 void 867 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader, 868 llvm::Function *Fn) { 869 if (!haveRegionCounts()) 870 return; 871 872 uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount(); 873 uint64_t FunctionCount = getRegionCount(0); 874 if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount)) 875 // Turn on InlineHint attribute for hot functions. 876 // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal. 877 Fn->addFnAttr(llvm::Attribute::InlineHint); 878 else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount)) 879 // Turn on Cold attribute for cold functions. 880 // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal. 881 Fn->addFnAttr(llvm::Attribute::Cold); 882 } 883 884 void CodeGenPGO::emitCounterVariables() { 885 llvm::LLVMContext &Ctx = CGM.getLLVMContext(); 886 llvm::ArrayType *CounterTy = llvm::ArrayType::get(llvm::Type::getInt64Ty(Ctx), 887 NumRegionCounters); 888 RegionCounters = 889 new llvm::GlobalVariable(CGM.getModule(), CounterTy, false, VarLinkage, 890 llvm::Constant::getNullValue(CounterTy), 891 getFuncVarName("counters")); 892 RegionCounters->setAlignment(8); 893 RegionCounters->setSection(getCountersSection(CGM)); 894 } 895 896 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) { 897 if (!RegionCounters) 898 return; 899 llvm::Value *Addr = 900 Builder.CreateConstInBoundsGEP2_64(RegionCounters, 0, Counter); 901 llvm::Value *Count = Builder.CreateLoad(Addr, "pgocount"); 902 Count = Builder.CreateAdd(Count, Builder.getInt64(1)); 903 Builder.CreateStore(Count, Addr); 904 } 905 906 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader) { 907 CGM.getPGOStats().Visited++; 908 RegionCounts.reset(new std::vector<uint64_t>); 909 uint64_t Hash; 910 if (PGOReader->getFunctionCounts(getFuncName(), Hash, *RegionCounts)) { 911 CGM.getPGOStats().Missing++; 912 RegionCounts.reset(); 913 } else if (Hash != FunctionHash || 914 RegionCounts->size() != NumRegionCounters) { 915 CGM.getPGOStats().Mismatched++; 916 RegionCounts.reset(); 917 } 918 } 919 920 void CodeGenPGO::destroyRegionCounters() { 921 RegionCounterMap.reset(); 922 StmtCountMap.reset(); 923 RegionCounts.reset(); 924 RegionCounters = nullptr; 925 } 926 927 /// \brief Calculate what to divide by to scale weights. 928 /// 929 /// Given the maximum weight, calculate a divisor that will scale all the 930 /// weights to strictly less than UINT32_MAX. 931 static uint64_t calculateWeightScale(uint64_t MaxWeight) { 932 return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1; 933 } 934 935 /// \brief Scale an individual branch weight (and add 1). 936 /// 937 /// Scale a 64-bit weight down to 32-bits using \c Scale. 938 /// 939 /// According to Laplace's Rule of Succession, it is better to compute the 940 /// weight based on the count plus 1, so universally add 1 to the value. 941 /// 942 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no 943 /// greater than \c Weight. 944 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) { 945 assert(Scale && "scale by 0?"); 946 uint64_t Scaled = Weight / Scale + 1; 947 assert(Scaled <= UINT32_MAX && "overflow 32-bits"); 948 return Scaled; 949 } 950 951 llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount, 952 uint64_t FalseCount) { 953 // Check for empty weights. 954 if (!TrueCount && !FalseCount) 955 return nullptr; 956 957 // Calculate how to scale down to 32-bits. 958 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount)); 959 960 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 961 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale), 962 scaleBranchWeight(FalseCount, Scale)); 963 } 964 965 llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) { 966 // We need at least two elements to create meaningful weights. 967 if (Weights.size() < 2) 968 return nullptr; 969 970 // Check for empty weights. 971 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end()); 972 if (MaxWeight == 0) 973 return nullptr; 974 975 // Calculate how to scale down to 32-bits. 976 uint64_t Scale = calculateWeightScale(MaxWeight); 977 978 SmallVector<uint32_t, 16> ScaledWeights; 979 ScaledWeights.reserve(Weights.size()); 980 for (uint64_t W : Weights) 981 ScaledWeights.push_back(scaleBranchWeight(W, Scale)); 982 983 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 984 return MDHelper.createBranchWeights(ScaledWeights); 985 } 986 987 llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond, 988 RegionCounter &Cnt) { 989 if (!haveRegionCounts()) 990 return nullptr; 991 uint64_t LoopCount = Cnt.getCount(); 992 uint64_t CondCount = 0; 993 bool Found = getStmtCount(Cond, CondCount); 994 assert(Found && "missing expected loop condition count"); 995 (void)Found; 996 if (CondCount == 0) 997 return nullptr; 998 return createBranchWeights(LoopCount, 999 std::max(CondCount, LoopCount) - LoopCount); 1000 } 1001