1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // Instrumentation-based profile-guided optimization 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "CodeGenPGO.h" 15 #include "CodeGenFunction.h" 16 #include "clang/AST/RecursiveASTVisitor.h" 17 #include "clang/AST/StmtVisitor.h" 18 #include "llvm/Config/config.h" // for strtoull()/strtoll() define 19 #include "llvm/IR/MDBuilder.h" 20 #include "llvm/Support/FileSystem.h" 21 22 using namespace clang; 23 using namespace CodeGen; 24 25 static void ReportBadPGOData(CodeGenModule &CGM, const char *Message) { 26 DiagnosticsEngine &Diags = CGM.getDiags(); 27 unsigned diagID = Diags.getCustomDiagID(DiagnosticsEngine::Error, "%0"); 28 Diags.Report(diagID) << Message; 29 } 30 31 PGOProfileData::PGOProfileData(CodeGenModule &CGM, std::string Path) 32 : CGM(CGM) { 33 if (llvm::MemoryBuffer::getFile(Path, DataBuffer)) { 34 ReportBadPGOData(CGM, "failed to open pgo data file"); 35 return; 36 } 37 38 if (DataBuffer->getBufferSize() > std::numeric_limits<unsigned>::max()) { 39 ReportBadPGOData(CGM, "pgo data file too big"); 40 return; 41 } 42 43 // Scan through the data file and map each function to the corresponding 44 // file offset where its counts are stored. 45 const char *BufferStart = DataBuffer->getBufferStart(); 46 const char *BufferEnd = DataBuffer->getBufferEnd(); 47 const char *CurPtr = BufferStart; 48 uint64_t MaxCount = 0; 49 while (CurPtr < BufferEnd) { 50 // Read the function name. 51 const char *FuncStart = CurPtr; 52 // For Objective-C methods, the name may include whitespace, so search 53 // backward from the end of the line to find the space that separates the 54 // name from the number of counters. (This is a temporary hack since we are 55 // going to completely replace this file format in the near future.) 56 CurPtr = strchr(CurPtr, '\n'); 57 if (!CurPtr) { 58 ReportBadPGOData(CGM, "pgo data file has malformed function entry"); 59 return; 60 } 61 StringRef FuncName(FuncStart, CurPtr - FuncStart); 62 63 // Skip over the function hash. 64 CurPtr = strchr(++CurPtr, '\n'); 65 if (!CurPtr) { 66 ReportBadPGOData(CGM, "pgo data file is missing the function hash"); 67 return; 68 } 69 70 // Read the number of counters. 71 char *EndPtr; 72 unsigned NumCounters = strtol(++CurPtr, &EndPtr, 10); 73 if (EndPtr == CurPtr || *EndPtr != '\n' || NumCounters <= 0) { 74 ReportBadPGOData(CGM, "pgo data file has unexpected number of counters"); 75 return; 76 } 77 CurPtr = EndPtr; 78 79 // Read function count. 80 uint64_t Count = strtoll(CurPtr, &EndPtr, 10); 81 if (EndPtr == CurPtr || *EndPtr != '\n') { 82 ReportBadPGOData(CGM, "pgo-data file has bad count value"); 83 return; 84 } 85 CurPtr = EndPtr; // Point to '\n'. 86 FunctionCounts[FuncName] = Count; 87 MaxCount = Count > MaxCount ? Count : MaxCount; 88 89 // There is one line for each counter; skip over those lines. 90 // Since function count is already read, we start the loop from 1. 91 for (unsigned N = 1; N < NumCounters; ++N) { 92 CurPtr = strchr(++CurPtr, '\n'); 93 if (!CurPtr) { 94 ReportBadPGOData(CGM, "pgo data file is missing some counter info"); 95 return; 96 } 97 } 98 99 // Skip over the blank line separating functions. 100 CurPtr += 2; 101 102 DataOffsets[FuncName] = FuncStart - BufferStart; 103 } 104 MaxFunctionCount = MaxCount; 105 } 106 107 bool PGOProfileData::getFunctionCounts(StringRef FuncName, uint64_t &FuncHash, 108 std::vector<uint64_t> &Counts) { 109 // Find the relevant section of the pgo-data file. 110 llvm::StringMap<unsigned>::const_iterator OffsetIter = 111 DataOffsets.find(FuncName); 112 if (OffsetIter == DataOffsets.end()) 113 return true; 114 const char *CurPtr = DataBuffer->getBufferStart() + OffsetIter->getValue(); 115 116 // Skip over the function name. 117 CurPtr = strchr(CurPtr, '\n'); 118 assert(CurPtr && "pgo-data has corrupted function entry"); 119 120 char *EndPtr; 121 // Read the function hash. 122 FuncHash = strtoll(++CurPtr, &EndPtr, 10); 123 assert(EndPtr != CurPtr && *EndPtr == '\n' && 124 "pgo-data file has corrupted function hash"); 125 CurPtr = EndPtr; 126 127 // Read the number of counters. 128 unsigned NumCounters = strtol(++CurPtr, &EndPtr, 10); 129 assert(EndPtr != CurPtr && *EndPtr == '\n' && NumCounters > 0 && 130 "pgo-data file has corrupted number of counters"); 131 CurPtr = EndPtr; 132 133 Counts.reserve(NumCounters); 134 135 for (unsigned N = 0; N < NumCounters; ++N) { 136 // Read the count value. 137 uint64_t Count = strtoll(CurPtr, &EndPtr, 10); 138 if (EndPtr == CurPtr || *EndPtr != '\n') { 139 ReportBadPGOData(CGM, "pgo-data file has bad count value"); 140 return true; 141 } 142 Counts.push_back(Count); 143 CurPtr = EndPtr + 1; 144 } 145 146 // Make sure the number of counters matches up. 147 if (Counts.size() != NumCounters) { 148 ReportBadPGOData(CGM, "pgo-data file has inconsistent counters"); 149 return true; 150 } 151 152 return false; 153 } 154 155 void CodeGenPGO::setFuncName(llvm::Function *Fn) { 156 RawFuncName = Fn->getName(); 157 158 // Function names may be prefixed with a binary '1' to indicate 159 // that the backend should not modify the symbols due to any platform 160 // naming convention. Do not include that '1' in the PGO profile name. 161 if (RawFuncName[0] == '\1') 162 RawFuncName = RawFuncName.substr(1); 163 164 if (!Fn->hasLocalLinkage()) { 165 PrefixedFuncName = new std::string(RawFuncName); 166 return; 167 } 168 169 // For local symbols, prepend the main file name to distinguish them. 170 // Do not include the full path in the file name since there's no guarantee 171 // that it will stay the same, e.g., if the files are checked out from 172 // version control in different locations. 173 PrefixedFuncName = new std::string(CGM.getCodeGenOpts().MainFileName); 174 if (PrefixedFuncName->empty()) 175 PrefixedFuncName->assign("<unknown>"); 176 PrefixedFuncName->append(":"); 177 PrefixedFuncName->append(RawFuncName); 178 } 179 180 static llvm::Function *getRegisterFunc(CodeGenModule &CGM) { 181 return CGM.getModule().getFunction("__llvm_pgo_register_functions"); 182 } 183 184 static llvm::BasicBlock *getOrInsertRegisterBB(CodeGenModule &CGM) { 185 // Only need to insert this once per module. 186 if (llvm::Function *RegisterF = getRegisterFunc(CGM)) 187 return &RegisterF->getEntryBlock(); 188 189 // Construct the function. 190 auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext()); 191 auto *RegisterFTy = llvm::FunctionType::get(VoidTy, false); 192 auto *RegisterF = llvm::Function::Create(RegisterFTy, 193 llvm::GlobalValue::InternalLinkage, 194 "__llvm_pgo_register_functions", 195 &CGM.getModule()); 196 RegisterF->setUnnamedAddr(true); 197 RegisterF->addFnAttr(llvm::Attribute::NoInline); 198 if (CGM.getCodeGenOpts().DisableRedZone) 199 RegisterF->addFnAttr(llvm::Attribute::NoRedZone); 200 201 // Construct and return the entry block. 202 auto *BB = llvm::BasicBlock::Create(CGM.getLLVMContext(), "", RegisterF); 203 CGBuilderTy Builder(BB); 204 Builder.CreateRetVoid(); 205 return BB; 206 } 207 208 static llvm::Constant *getOrInsertRuntimeRegister(CodeGenModule &CGM) { 209 auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext()); 210 auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext()); 211 auto *RuntimeRegisterTy = llvm::FunctionType::get(VoidTy, VoidPtrTy, false); 212 return CGM.getModule().getOrInsertFunction("__llvm_pgo_register_function", 213 RuntimeRegisterTy); 214 } 215 216 static llvm::Constant *getOrInsertRuntimeWriteAtExit(CodeGenModule &CGM) { 217 // TODO: make this depend on a command-line option. 218 auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext()); 219 auto *WriteAtExitTy = llvm::FunctionType::get(VoidTy, false); 220 return CGM.getModule().getOrInsertFunction("__llvm_pgo_register_write_atexit", 221 WriteAtExitTy); 222 } 223 224 static StringRef getCountersSection(const CodeGenModule &CGM) { 225 if (CGM.getTarget().getTriple().isOSBinFormatMachO()) 226 return "__DATA,__llvm_pgo_cnts"; 227 else 228 return "__llvm_pgo_cnts"; 229 } 230 231 static StringRef getNameSection(const CodeGenModule &CGM) { 232 if (CGM.getTarget().getTriple().isOSBinFormatMachO()) 233 return "__DATA,__llvm_pgo_names"; 234 else 235 return "__llvm_pgo_names"; 236 } 237 238 static StringRef getDataSection(const CodeGenModule &CGM) { 239 if (CGM.getTarget().getTriple().isOSBinFormatMachO()) 240 return "__DATA,__llvm_pgo_data"; 241 else 242 return "__llvm_pgo_data"; 243 } 244 245 llvm::GlobalVariable *CodeGenPGO::buildDataVar() { 246 // Create name variable. 247 llvm::LLVMContext &Ctx = CGM.getLLVMContext(); 248 auto *VarName = llvm::ConstantDataArray::getString(Ctx, getFuncName(), 249 false); 250 auto *Name = new llvm::GlobalVariable(CGM.getModule(), VarName->getType(), 251 true, FuncLinkage, VarName, 252 getFuncVarName("name")); 253 Name->setSection(getNameSection(CGM)); 254 Name->setAlignment(1); 255 256 // Create data variable. 257 auto *Int32Ty = llvm::Type::getInt32Ty(Ctx); 258 auto *Int64Ty = llvm::Type::getInt64Ty(Ctx); 259 auto *Int8PtrTy = llvm::Type::getInt8PtrTy(Ctx); 260 auto *Int64PtrTy = llvm::Type::getInt64PtrTy(Ctx); 261 llvm::Type *DataTypes[] = { 262 Int32Ty, Int32Ty, Int64Ty, Int8PtrTy, Int64PtrTy 263 }; 264 auto *DataTy = llvm::StructType::get(Ctx, makeArrayRef(DataTypes)); 265 llvm::Constant *DataVals[] = { 266 llvm::ConstantInt::get(Int32Ty, getFuncName().size()), 267 llvm::ConstantInt::get(Int32Ty, NumRegionCounters), 268 llvm::ConstantInt::get(Int64Ty, FunctionHash), 269 llvm::ConstantExpr::getBitCast(Name, Int8PtrTy), 270 llvm::ConstantExpr::getBitCast(RegionCounters, Int64PtrTy) 271 }; 272 auto *Data = 273 new llvm::GlobalVariable(CGM.getModule(), DataTy, true, FuncLinkage, 274 llvm::ConstantStruct::get(DataTy, DataVals), 275 getFuncVarName("data")); 276 277 // All the data should be packed into an array in its own section. 278 Data->setSection(getDataSection(CGM)); 279 Data->setAlignment(8); 280 281 // Make sure the data doesn't get deleted. 282 CGM.addUsedGlobal(Data); 283 return Data; 284 } 285 286 void CodeGenPGO::emitInstrumentationData() { 287 if (!CGM.getCodeGenOpts().ProfileInstrGenerate) 288 return; 289 290 // Build the data. 291 auto *Data = buildDataVar(); 292 293 // Register the data. 294 // 295 // TODO: only register when static initialization is required. 296 CGBuilderTy Builder(getOrInsertRegisterBB(CGM)->getTerminator()); 297 auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext()); 298 Builder.CreateCall(getOrInsertRuntimeRegister(CGM), 299 Builder.CreateBitCast(Data, VoidPtrTy)); 300 } 301 302 llvm::Function *CodeGenPGO::emitInitialization(CodeGenModule &CGM) { 303 if (!CGM.getCodeGenOpts().ProfileInstrGenerate) 304 return 0; 305 306 // Only need to create this once per module. 307 if (CGM.getModule().getFunction("__llvm_pgo_init")) 308 return 0; 309 310 // Get the functions to call at initialization. 311 llvm::Constant *RegisterF = getRegisterFunc(CGM); 312 llvm::Constant *WriteAtExitF = getOrInsertRuntimeWriteAtExit(CGM); 313 if (!RegisterF && !WriteAtExitF) 314 return 0; 315 316 // Create the initialization function. 317 auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext()); 318 auto *F = llvm::Function::Create(llvm::FunctionType::get(VoidTy, false), 319 llvm::GlobalValue::InternalLinkage, 320 "__llvm_pgo_init", &CGM.getModule()); 321 F->setUnnamedAddr(true); 322 F->addFnAttr(llvm::Attribute::NoInline); 323 if (CGM.getCodeGenOpts().DisableRedZone) 324 F->addFnAttr(llvm::Attribute::NoRedZone); 325 326 // Add the basic block and the necessary calls. 327 CGBuilderTy Builder(llvm::BasicBlock::Create(CGM.getLLVMContext(), "", F)); 328 if (RegisterF) 329 Builder.CreateCall(RegisterF); 330 if (WriteAtExitF) 331 Builder.CreateCall(WriteAtExitF); 332 Builder.CreateRetVoid(); 333 334 return F; 335 } 336 337 namespace { 338 /// A StmtVisitor that fills a map of statements to PGO counters. 339 struct MapRegionCounters : public ConstStmtVisitor<MapRegionCounters> { 340 /// The next counter value to assign. 341 unsigned NextCounter; 342 /// The map of statements to counters. 343 llvm::DenseMap<const Stmt*, unsigned> *CounterMap; 344 345 MapRegionCounters(llvm::DenseMap<const Stmt*, unsigned> *CounterMap) : 346 NextCounter(0), CounterMap(CounterMap) { 347 } 348 349 void VisitChildren(const Stmt *S) { 350 for (Stmt::const_child_range I = S->children(); I; ++I) 351 if (*I) 352 this->Visit(*I); 353 } 354 void VisitStmt(const Stmt *S) { VisitChildren(S); } 355 356 /// Assign a counter to track entry to the function body. 357 void VisitFunctionDecl(const FunctionDecl *S) { 358 (*CounterMap)[S->getBody()] = NextCounter++; 359 Visit(S->getBody()); 360 } 361 void VisitObjCMethodDecl(const ObjCMethodDecl *S) { 362 (*CounterMap)[S->getBody()] = NextCounter++; 363 Visit(S->getBody()); 364 } 365 void VisitBlockDecl(const BlockDecl *S) { 366 (*CounterMap)[S->getBody()] = NextCounter++; 367 Visit(S->getBody()); 368 } 369 /// Assign a counter to track the block following a label. 370 void VisitLabelStmt(const LabelStmt *S) { 371 (*CounterMap)[S] = NextCounter++; 372 Visit(S->getSubStmt()); 373 } 374 /// Assign a counter for the body of a while loop. 375 void VisitWhileStmt(const WhileStmt *S) { 376 (*CounterMap)[S] = NextCounter++; 377 Visit(S->getCond()); 378 Visit(S->getBody()); 379 } 380 /// Assign a counter for the body of a do-while loop. 381 void VisitDoStmt(const DoStmt *S) { 382 (*CounterMap)[S] = NextCounter++; 383 Visit(S->getBody()); 384 Visit(S->getCond()); 385 } 386 /// Assign a counter for the body of a for loop. 387 void VisitForStmt(const ForStmt *S) { 388 (*CounterMap)[S] = NextCounter++; 389 if (S->getInit()) 390 Visit(S->getInit()); 391 const Expr *E; 392 if ((E = S->getCond())) 393 Visit(E); 394 if ((E = S->getInc())) 395 Visit(E); 396 Visit(S->getBody()); 397 } 398 /// Assign a counter for the body of a for-range loop. 399 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) { 400 (*CounterMap)[S] = NextCounter++; 401 Visit(S->getRangeStmt()); 402 Visit(S->getBeginEndStmt()); 403 Visit(S->getCond()); 404 Visit(S->getLoopVarStmt()); 405 Visit(S->getBody()); 406 Visit(S->getInc()); 407 } 408 /// Assign a counter for the body of a for-collection loop. 409 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { 410 (*CounterMap)[S] = NextCounter++; 411 Visit(S->getElement()); 412 Visit(S->getBody()); 413 } 414 /// Assign a counter for the exit block of the switch statement. 415 void VisitSwitchStmt(const SwitchStmt *S) { 416 (*CounterMap)[S] = NextCounter++; 417 Visit(S->getCond()); 418 Visit(S->getBody()); 419 } 420 /// Assign a counter for a particular case in a switch. This counts jumps 421 /// from the switch header as well as fallthrough from the case before this 422 /// one. 423 void VisitCaseStmt(const CaseStmt *S) { 424 (*CounterMap)[S] = NextCounter++; 425 Visit(S->getSubStmt()); 426 } 427 /// Assign a counter for the default case of a switch statement. The count 428 /// is the number of branches from the loop header to the default, and does 429 /// not include fallthrough from previous cases. If we have multiple 430 /// conditional branch blocks from the switch instruction to the default 431 /// block, as with large GNU case ranges, this is the counter for the last 432 /// edge in that series, rather than the first. 433 void VisitDefaultStmt(const DefaultStmt *S) { 434 (*CounterMap)[S] = NextCounter++; 435 Visit(S->getSubStmt()); 436 } 437 /// Assign a counter for the "then" part of an if statement. The count for 438 /// the "else" part, if it exists, will be calculated from this counter. 439 void VisitIfStmt(const IfStmt *S) { 440 (*CounterMap)[S] = NextCounter++; 441 Visit(S->getCond()); 442 Visit(S->getThen()); 443 if (S->getElse()) 444 Visit(S->getElse()); 445 } 446 /// Assign a counter for the continuation block of a C++ try statement. 447 void VisitCXXTryStmt(const CXXTryStmt *S) { 448 (*CounterMap)[S] = NextCounter++; 449 Visit(S->getTryBlock()); 450 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I) 451 Visit(S->getHandler(I)); 452 } 453 /// Assign a counter for a catch statement's handler block. 454 void VisitCXXCatchStmt(const CXXCatchStmt *S) { 455 (*CounterMap)[S] = NextCounter++; 456 Visit(S->getHandlerBlock()); 457 } 458 /// Assign a counter for the "true" part of a conditional operator. The 459 /// count in the "false" part will be calculated from this counter. 460 void VisitConditionalOperator(const ConditionalOperator *E) { 461 (*CounterMap)[E] = NextCounter++; 462 Visit(E->getCond()); 463 Visit(E->getTrueExpr()); 464 Visit(E->getFalseExpr()); 465 } 466 /// Assign a counter for the right hand side of a logical and operator. 467 void VisitBinLAnd(const BinaryOperator *E) { 468 (*CounterMap)[E] = NextCounter++; 469 Visit(E->getLHS()); 470 Visit(E->getRHS()); 471 } 472 /// Assign a counter for the right hand side of a logical or operator. 473 void VisitBinLOr(const BinaryOperator *E) { 474 (*CounterMap)[E] = NextCounter++; 475 Visit(E->getLHS()); 476 Visit(E->getRHS()); 477 } 478 }; 479 480 /// A StmtVisitor that propagates the raw counts through the AST and 481 /// records the count at statements where the value may change. 482 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { 483 /// PGO state. 484 CodeGenPGO &PGO; 485 486 /// A flag that is set when the current count should be recorded on the 487 /// next statement, such as at the exit of a loop. 488 bool RecordNextStmtCount; 489 490 /// The map of statements to count values. 491 llvm::DenseMap<const Stmt*, uint64_t> *CountMap; 492 493 /// BreakContinueStack - Keep counts of breaks and continues inside loops. 494 struct BreakContinue { 495 uint64_t BreakCount; 496 uint64_t ContinueCount; 497 BreakContinue() : BreakCount(0), ContinueCount(0) {} 498 }; 499 SmallVector<BreakContinue, 8> BreakContinueStack; 500 501 ComputeRegionCounts(llvm::DenseMap<const Stmt*, uint64_t> *CountMap, 502 CodeGenPGO &PGO) : 503 PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) { 504 } 505 506 void RecordStmtCount(const Stmt *S) { 507 if (RecordNextStmtCount) { 508 (*CountMap)[S] = PGO.getCurrentRegionCount(); 509 RecordNextStmtCount = false; 510 } 511 } 512 513 void VisitStmt(const Stmt *S) { 514 RecordStmtCount(S); 515 for (Stmt::const_child_range I = S->children(); I; ++I) { 516 if (*I) 517 this->Visit(*I); 518 } 519 } 520 521 void VisitFunctionDecl(const FunctionDecl *S) { 522 RegionCounter Cnt(PGO, S->getBody()); 523 Cnt.beginRegion(); 524 (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount(); 525 Visit(S->getBody()); 526 } 527 528 void VisitObjCMethodDecl(const ObjCMethodDecl *S) { 529 RegionCounter Cnt(PGO, S->getBody()); 530 Cnt.beginRegion(); 531 (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount(); 532 Visit(S->getBody()); 533 } 534 535 void VisitBlockDecl(const BlockDecl *S) { 536 RegionCounter Cnt(PGO, S->getBody()); 537 Cnt.beginRegion(); 538 (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount(); 539 Visit(S->getBody()); 540 } 541 542 void VisitReturnStmt(const ReturnStmt *S) { 543 RecordStmtCount(S); 544 if (S->getRetValue()) 545 Visit(S->getRetValue()); 546 PGO.setCurrentRegionUnreachable(); 547 RecordNextStmtCount = true; 548 } 549 550 void VisitGotoStmt(const GotoStmt *S) { 551 RecordStmtCount(S); 552 PGO.setCurrentRegionUnreachable(); 553 RecordNextStmtCount = true; 554 } 555 556 void VisitLabelStmt(const LabelStmt *S) { 557 RecordNextStmtCount = false; 558 RegionCounter Cnt(PGO, S); 559 Cnt.beginRegion(); 560 (*CountMap)[S] = PGO.getCurrentRegionCount(); 561 Visit(S->getSubStmt()); 562 } 563 564 void VisitBreakStmt(const BreakStmt *S) { 565 RecordStmtCount(S); 566 assert(!BreakContinueStack.empty() && "break not in a loop or switch!"); 567 BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount(); 568 PGO.setCurrentRegionUnreachable(); 569 RecordNextStmtCount = true; 570 } 571 572 void VisitContinueStmt(const ContinueStmt *S) { 573 RecordStmtCount(S); 574 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!"); 575 BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount(); 576 PGO.setCurrentRegionUnreachable(); 577 RecordNextStmtCount = true; 578 } 579 580 void VisitWhileStmt(const WhileStmt *S) { 581 RecordStmtCount(S); 582 RegionCounter Cnt(PGO, S); 583 BreakContinueStack.push_back(BreakContinue()); 584 // Visit the body region first so the break/continue adjustments can be 585 // included when visiting the condition. 586 Cnt.beginRegion(); 587 (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount(); 588 Visit(S->getBody()); 589 Cnt.adjustForControlFlow(); 590 591 // ...then go back and propagate counts through the condition. The count 592 // at the start of the condition is the sum of the incoming edges, 593 // the backedge from the end of the loop body, and the edges from 594 // continue statements. 595 BreakContinue BC = BreakContinueStack.pop_back_val(); 596 Cnt.setCurrentRegionCount(Cnt.getParentCount() + 597 Cnt.getAdjustedCount() + BC.ContinueCount); 598 (*CountMap)[S->getCond()] = PGO.getCurrentRegionCount(); 599 Visit(S->getCond()); 600 Cnt.adjustForControlFlow(); 601 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 602 RecordNextStmtCount = true; 603 } 604 605 void VisitDoStmt(const DoStmt *S) { 606 RecordStmtCount(S); 607 RegionCounter Cnt(PGO, S); 608 BreakContinueStack.push_back(BreakContinue()); 609 Cnt.beginRegion(/*AddIncomingFallThrough=*/true); 610 (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount(); 611 Visit(S->getBody()); 612 Cnt.adjustForControlFlow(); 613 614 BreakContinue BC = BreakContinueStack.pop_back_val(); 615 // The count at the start of the condition is equal to the count at the 616 // end of the body. The adjusted count does not include either the 617 // fall-through count coming into the loop or the continue count, so add 618 // both of those separately. This is coincidentally the same equation as 619 // with while loops but for different reasons. 620 Cnt.setCurrentRegionCount(Cnt.getParentCount() + 621 Cnt.getAdjustedCount() + BC.ContinueCount); 622 (*CountMap)[S->getCond()] = PGO.getCurrentRegionCount(); 623 Visit(S->getCond()); 624 Cnt.adjustForControlFlow(); 625 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 626 RecordNextStmtCount = true; 627 } 628 629 void VisitForStmt(const ForStmt *S) { 630 RecordStmtCount(S); 631 if (S->getInit()) 632 Visit(S->getInit()); 633 RegionCounter Cnt(PGO, S); 634 BreakContinueStack.push_back(BreakContinue()); 635 // Visit the body region first. (This is basically the same as a while 636 // loop; see further comments in VisitWhileStmt.) 637 Cnt.beginRegion(); 638 (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount(); 639 Visit(S->getBody()); 640 Cnt.adjustForControlFlow(); 641 642 // The increment is essentially part of the body but it needs to include 643 // the count for all the continue statements. 644 if (S->getInc()) { 645 Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() + 646 BreakContinueStack.back().ContinueCount); 647 (*CountMap)[S->getInc()] = PGO.getCurrentRegionCount(); 648 Visit(S->getInc()); 649 Cnt.adjustForControlFlow(); 650 } 651 652 BreakContinue BC = BreakContinueStack.pop_back_val(); 653 654 // ...then go back and propagate counts through the condition. 655 if (S->getCond()) { 656 Cnt.setCurrentRegionCount(Cnt.getParentCount() + 657 Cnt.getAdjustedCount() + 658 BC.ContinueCount); 659 (*CountMap)[S->getCond()] = PGO.getCurrentRegionCount(); 660 Visit(S->getCond()); 661 Cnt.adjustForControlFlow(); 662 } 663 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 664 RecordNextStmtCount = true; 665 } 666 667 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) { 668 RecordStmtCount(S); 669 Visit(S->getRangeStmt()); 670 Visit(S->getBeginEndStmt()); 671 RegionCounter Cnt(PGO, S); 672 BreakContinueStack.push_back(BreakContinue()); 673 // Visit the body region first. (This is basically the same as a while 674 // loop; see further comments in VisitWhileStmt.) 675 Cnt.beginRegion(); 676 (*CountMap)[S->getLoopVarStmt()] = PGO.getCurrentRegionCount(); 677 Visit(S->getLoopVarStmt()); 678 Visit(S->getBody()); 679 Cnt.adjustForControlFlow(); 680 681 // The increment is essentially part of the body but it needs to include 682 // the count for all the continue statements. 683 Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() + 684 BreakContinueStack.back().ContinueCount); 685 (*CountMap)[S->getInc()] = PGO.getCurrentRegionCount(); 686 Visit(S->getInc()); 687 Cnt.adjustForControlFlow(); 688 689 BreakContinue BC = BreakContinueStack.pop_back_val(); 690 691 // ...then go back and propagate counts through the condition. 692 Cnt.setCurrentRegionCount(Cnt.getParentCount() + 693 Cnt.getAdjustedCount() + 694 BC.ContinueCount); 695 (*CountMap)[S->getCond()] = PGO.getCurrentRegionCount(); 696 Visit(S->getCond()); 697 Cnt.adjustForControlFlow(); 698 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 699 RecordNextStmtCount = true; 700 } 701 702 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { 703 RecordStmtCount(S); 704 Visit(S->getElement()); 705 RegionCounter Cnt(PGO, S); 706 BreakContinueStack.push_back(BreakContinue()); 707 Cnt.beginRegion(); 708 (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount(); 709 Visit(S->getBody()); 710 BreakContinue BC = BreakContinueStack.pop_back_val(); 711 Cnt.adjustForControlFlow(); 712 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 713 RecordNextStmtCount = true; 714 } 715 716 void VisitSwitchStmt(const SwitchStmt *S) { 717 RecordStmtCount(S); 718 Visit(S->getCond()); 719 PGO.setCurrentRegionUnreachable(); 720 BreakContinueStack.push_back(BreakContinue()); 721 Visit(S->getBody()); 722 // If the switch is inside a loop, add the continue counts. 723 BreakContinue BC = BreakContinueStack.pop_back_val(); 724 if (!BreakContinueStack.empty()) 725 BreakContinueStack.back().ContinueCount += BC.ContinueCount; 726 RegionCounter ExitCnt(PGO, S); 727 ExitCnt.beginRegion(); 728 RecordNextStmtCount = true; 729 } 730 731 void VisitCaseStmt(const CaseStmt *S) { 732 RecordNextStmtCount = false; 733 RegionCounter Cnt(PGO, S); 734 Cnt.beginRegion(/*AddIncomingFallThrough=*/true); 735 (*CountMap)[S] = Cnt.getCount(); 736 RecordNextStmtCount = true; 737 Visit(S->getSubStmt()); 738 } 739 740 void VisitDefaultStmt(const DefaultStmt *S) { 741 RecordNextStmtCount = false; 742 RegionCounter Cnt(PGO, S); 743 Cnt.beginRegion(/*AddIncomingFallThrough=*/true); 744 (*CountMap)[S] = Cnt.getCount(); 745 RecordNextStmtCount = true; 746 Visit(S->getSubStmt()); 747 } 748 749 void VisitIfStmt(const IfStmt *S) { 750 RecordStmtCount(S); 751 RegionCounter Cnt(PGO, S); 752 Visit(S->getCond()); 753 754 Cnt.beginRegion(); 755 (*CountMap)[S->getThen()] = PGO.getCurrentRegionCount(); 756 Visit(S->getThen()); 757 Cnt.adjustForControlFlow(); 758 759 if (S->getElse()) { 760 Cnt.beginElseRegion(); 761 (*CountMap)[S->getElse()] = PGO.getCurrentRegionCount(); 762 Visit(S->getElse()); 763 Cnt.adjustForControlFlow(); 764 } 765 Cnt.applyAdjustmentsToRegion(0); 766 RecordNextStmtCount = true; 767 } 768 769 void VisitCXXTryStmt(const CXXTryStmt *S) { 770 RecordStmtCount(S); 771 Visit(S->getTryBlock()); 772 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I) 773 Visit(S->getHandler(I)); 774 RegionCounter Cnt(PGO, S); 775 Cnt.beginRegion(); 776 RecordNextStmtCount = true; 777 } 778 779 void VisitCXXCatchStmt(const CXXCatchStmt *S) { 780 RecordNextStmtCount = false; 781 RegionCounter Cnt(PGO, S); 782 Cnt.beginRegion(); 783 (*CountMap)[S] = PGO.getCurrentRegionCount(); 784 Visit(S->getHandlerBlock()); 785 } 786 787 void VisitConditionalOperator(const ConditionalOperator *E) { 788 RecordStmtCount(E); 789 RegionCounter Cnt(PGO, E); 790 Visit(E->getCond()); 791 792 Cnt.beginRegion(); 793 (*CountMap)[E->getTrueExpr()] = PGO.getCurrentRegionCount(); 794 Visit(E->getTrueExpr()); 795 Cnt.adjustForControlFlow(); 796 797 Cnt.beginElseRegion(); 798 (*CountMap)[E->getFalseExpr()] = PGO.getCurrentRegionCount(); 799 Visit(E->getFalseExpr()); 800 Cnt.adjustForControlFlow(); 801 802 Cnt.applyAdjustmentsToRegion(0); 803 RecordNextStmtCount = true; 804 } 805 806 void VisitBinLAnd(const BinaryOperator *E) { 807 RecordStmtCount(E); 808 RegionCounter Cnt(PGO, E); 809 Visit(E->getLHS()); 810 Cnt.beginRegion(); 811 (*CountMap)[E->getRHS()] = PGO.getCurrentRegionCount(); 812 Visit(E->getRHS()); 813 Cnt.adjustForControlFlow(); 814 Cnt.applyAdjustmentsToRegion(0); 815 RecordNextStmtCount = true; 816 } 817 818 void VisitBinLOr(const BinaryOperator *E) { 819 RecordStmtCount(E); 820 RegionCounter Cnt(PGO, E); 821 Visit(E->getLHS()); 822 Cnt.beginRegion(); 823 (*CountMap)[E->getRHS()] = PGO.getCurrentRegionCount(); 824 Visit(E->getRHS()); 825 Cnt.adjustForControlFlow(); 826 Cnt.applyAdjustmentsToRegion(0); 827 RecordNextStmtCount = true; 828 } 829 }; 830 } 831 832 void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) { 833 bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate; 834 PGOProfileData *PGOData = CGM.getPGOData(); 835 if (!InstrumentRegions && !PGOData) 836 return; 837 if (!D) 838 return; 839 setFuncName(Fn); 840 FuncLinkage = Fn->getLinkage(); 841 mapRegionCounters(D); 842 if (InstrumentRegions) 843 emitCounterVariables(); 844 if (PGOData) { 845 loadRegionCounts(PGOData); 846 computeRegionCounts(D); 847 applyFunctionAttributes(PGOData, Fn); 848 } 849 } 850 851 void CodeGenPGO::mapRegionCounters(const Decl *D) { 852 RegionCounterMap = new llvm::DenseMap<const Stmt*, unsigned>(); 853 MapRegionCounters Walker(RegionCounterMap); 854 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 855 Walker.VisitFunctionDecl(FD); 856 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D)) 857 Walker.VisitObjCMethodDecl(MD); 858 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D)) 859 Walker.VisitBlockDecl(BD); 860 NumRegionCounters = Walker.NextCounter; 861 // FIXME: The number of counters isn't sufficient for the hash 862 FunctionHash = NumRegionCounters; 863 } 864 865 void CodeGenPGO::computeRegionCounts(const Decl *D) { 866 StmtCountMap = new llvm::DenseMap<const Stmt*, uint64_t>(); 867 ComputeRegionCounts Walker(StmtCountMap, *this); 868 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 869 Walker.VisitFunctionDecl(FD); 870 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D)) 871 Walker.VisitObjCMethodDecl(MD); 872 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D)) 873 Walker.VisitBlockDecl(BD); 874 } 875 876 void CodeGenPGO::applyFunctionAttributes(PGOProfileData *PGOData, 877 llvm::Function *Fn) { 878 if (!haveRegionCounts()) 879 return; 880 881 uint64_t MaxFunctionCount = PGOData->getMaximumFunctionCount(); 882 uint64_t FunctionCount = getRegionCount(0); 883 if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount)) 884 // Turn on InlineHint attribute for hot functions. 885 // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal. 886 Fn->addFnAttr(llvm::Attribute::InlineHint); 887 else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount)) 888 // Turn on Cold attribute for cold functions. 889 // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal. 890 Fn->addFnAttr(llvm::Attribute::Cold); 891 } 892 893 void CodeGenPGO::emitCounterVariables() { 894 llvm::LLVMContext &Ctx = CGM.getLLVMContext(); 895 llvm::ArrayType *CounterTy = llvm::ArrayType::get(llvm::Type::getInt64Ty(Ctx), 896 NumRegionCounters); 897 RegionCounters = 898 new llvm::GlobalVariable(CGM.getModule(), CounterTy, false, FuncLinkage, 899 llvm::Constant::getNullValue(CounterTy), 900 getFuncVarName("counters")); 901 RegionCounters->setAlignment(8); 902 RegionCounters->setSection(getCountersSection(CGM)); 903 } 904 905 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) { 906 if (!RegionCounters) 907 return; 908 llvm::Value *Addr = 909 Builder.CreateConstInBoundsGEP2_64(RegionCounters, 0, Counter); 910 llvm::Value *Count = Builder.CreateLoad(Addr, "pgocount"); 911 Count = Builder.CreateAdd(Count, Builder.getInt64(1)); 912 Builder.CreateStore(Count, Addr); 913 } 914 915 void CodeGenPGO::loadRegionCounts(PGOProfileData *PGOData) { 916 // For now, ignore the counts from the PGO data file only if the number of 917 // counters does not match. This could be tightened down in the future to 918 // ignore counts when the input changes in various ways, e.g., by comparing a 919 // hash value based on some characteristics of the input. 920 RegionCounts = new std::vector<uint64_t>(); 921 uint64_t Hash; 922 if (PGOData->getFunctionCounts(getFuncName(), Hash, *RegionCounts) || 923 Hash != FunctionHash || RegionCounts->size() != NumRegionCounters) { 924 delete RegionCounts; 925 RegionCounts = 0; 926 } 927 } 928 929 void CodeGenPGO::destroyRegionCounters() { 930 if (RegionCounterMap != 0) 931 delete RegionCounterMap; 932 if (StmtCountMap != 0) 933 delete StmtCountMap; 934 if (RegionCounts != 0) 935 delete RegionCounts; 936 } 937 938 /// \brief Calculate what to divide by to scale weights. 939 /// 940 /// Given the maximum weight, calculate a divisor that will scale all the 941 /// weights to strictly less than UINT32_MAX. 942 static uint64_t calculateWeightScale(uint64_t MaxWeight) { 943 return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1; 944 } 945 946 /// \brief Scale an individual branch weight (and add 1). 947 /// 948 /// Scale a 64-bit weight down to 32-bits using \c Scale. 949 /// 950 /// According to Laplace's Rule of Succession, it is better to compute the 951 /// weight based on the count plus 1, so universally add 1 to the value. 952 /// 953 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no 954 /// greater than \c Weight. 955 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) { 956 assert(Scale && "scale by 0?"); 957 uint64_t Scaled = Weight / Scale + 1; 958 assert(Scaled <= UINT32_MAX && "overflow 32-bits"); 959 return Scaled; 960 } 961 962 llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount, 963 uint64_t FalseCount) { 964 // Check for empty weights. 965 if (!TrueCount && !FalseCount) 966 return 0; 967 968 // Calculate how to scale down to 32-bits. 969 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount)); 970 971 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 972 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale), 973 scaleBranchWeight(FalseCount, Scale)); 974 } 975 976 llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) { 977 // We need at least two elements to create meaningful weights. 978 if (Weights.size() < 2) 979 return 0; 980 981 // Calculate how to scale down to 32-bits. 982 uint64_t Scale = calculateWeightScale(*std::max_element(Weights.begin(), 983 Weights.end())); 984 985 SmallVector<uint32_t, 16> ScaledWeights; 986 ScaledWeights.reserve(Weights.size()); 987 for (uint64_t W : Weights) 988 ScaledWeights.push_back(scaleBranchWeight(W, Scale)); 989 990 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 991 return MDHelper.createBranchWeights(ScaledWeights); 992 } 993 994 llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond, 995 RegionCounter &Cnt) { 996 if (!haveRegionCounts()) 997 return 0; 998 uint64_t LoopCount = Cnt.getCount(); 999 uint64_t CondCount = 0; 1000 bool Found = getStmtCount(Cond, CondCount); 1001 assert(Found && "missing expected loop condition count"); 1002 (void)Found; 1003 if (CondCount == 0) 1004 return 0; 1005 return createBranchWeights(LoopCount, 1006 std::max(CondCount, LoopCount) - LoopCount); 1007 } 1008