1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // Instrumentation-based profile-guided optimization 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "CodeGenPGO.h" 15 #include "CodeGenFunction.h" 16 #include "clang/AST/RecursiveASTVisitor.h" 17 #include "clang/AST/StmtVisitor.h" 18 #include "llvm/Config/config.h" // for strtoull()/strtoll() define 19 #include "llvm/IR/MDBuilder.h" 20 #include "llvm/Support/FileSystem.h" 21 22 using namespace clang; 23 using namespace CodeGen; 24 25 static void ReportBadPGOData(CodeGenModule &CGM, const char *Message) { 26 DiagnosticsEngine &Diags = CGM.getDiags(); 27 unsigned diagID = Diags.getCustomDiagID(DiagnosticsEngine::Error, "%0"); 28 Diags.Report(diagID) << Message; 29 } 30 31 PGOProfileData::PGOProfileData(CodeGenModule &CGM, std::string Path) 32 : CGM(CGM) { 33 if (llvm::MemoryBuffer::getFile(Path, DataBuffer)) { 34 ReportBadPGOData(CGM, "failed to open pgo data file"); 35 return; 36 } 37 38 if (DataBuffer->getBufferSize() > std::numeric_limits<unsigned>::max()) { 39 ReportBadPGOData(CGM, "pgo data file too big"); 40 return; 41 } 42 43 // Scan through the data file and map each function to the corresponding 44 // file offset where its counts are stored. 45 const char *BufferStart = DataBuffer->getBufferStart(); 46 const char *BufferEnd = DataBuffer->getBufferEnd(); 47 const char *CurPtr = BufferStart; 48 uint64_t MaxCount = 0; 49 while (CurPtr < BufferEnd) { 50 // Read the mangled function name. 51 const char *FuncName = CurPtr; 52 // FIXME: Something will need to be added to distinguish static functions. 53 CurPtr = strchr(CurPtr, ' '); 54 if (!CurPtr) { 55 ReportBadPGOData(CGM, "pgo data file has malformed function entry"); 56 return; 57 } 58 StringRef MangledName(FuncName, CurPtr - FuncName); 59 60 // Read the number of counters. 61 char *EndPtr; 62 unsigned NumCounters = strtol(++CurPtr, &EndPtr, 10); 63 if (EndPtr == CurPtr || *EndPtr != '\n' || NumCounters <= 0) { 64 ReportBadPGOData(CGM, "pgo data file has unexpected number of counters"); 65 return; 66 } 67 CurPtr = EndPtr; 68 69 // Read function count. 70 uint64_t Count = strtoll(CurPtr, &EndPtr, 10); 71 if (EndPtr == CurPtr || *EndPtr != '\n') { 72 ReportBadPGOData(CGM, "pgo-data file has bad count value"); 73 return; 74 } 75 CurPtr = EndPtr; // Point to '\n'. 76 FunctionCounts[MangledName] = Count; 77 MaxCount = Count > MaxCount ? Count : MaxCount; 78 79 // There is one line for each counter; skip over those lines. 80 // Since function count is already read, we start the loop from 1. 81 for (unsigned N = 1; N < NumCounters; ++N) { 82 CurPtr = strchr(++CurPtr, '\n'); 83 if (!CurPtr) { 84 ReportBadPGOData(CGM, "pgo data file is missing some counter info"); 85 return; 86 } 87 } 88 89 // Skip over the blank line separating functions. 90 CurPtr += 2; 91 92 DataOffsets[MangledName] = FuncName - BufferStart; 93 } 94 MaxFunctionCount = MaxCount; 95 } 96 97 /// Return true if a function is hot. If we know nothing about the function, 98 /// return false. 99 bool PGOProfileData::isHotFunction(StringRef MangledName) { 100 llvm::StringMap<uint64_t>::const_iterator CountIter = 101 FunctionCounts.find(MangledName); 102 // If we know nothing about the function, return false. 103 if (CountIter == FunctionCounts.end()) 104 return false; 105 // FIXME: functions with >= 30% of the maximal function count are 106 // treated as hot. This number is from preliminary tuning on SPEC. 107 return CountIter->getValue() >= (uint64_t)(0.3 * (double)MaxFunctionCount); 108 } 109 110 /// Return true if a function is cold. If we know nothing about the function, 111 /// return false. 112 bool PGOProfileData::isColdFunction(StringRef MangledName) { 113 llvm::StringMap<uint64_t>::const_iterator CountIter = 114 FunctionCounts.find(MangledName); 115 // If we know nothing about the function, return false. 116 if (CountIter == FunctionCounts.end()) 117 return false; 118 // FIXME: functions with <= 1% of the maximal function count are treated as 119 // cold. This number is from preliminary tuning on SPEC. 120 return CountIter->getValue() <= (uint64_t)(0.01 * (double)MaxFunctionCount); 121 } 122 123 bool PGOProfileData::getFunctionCounts(StringRef MangledName, 124 std::vector<uint64_t> &Counts) { 125 // Find the relevant section of the pgo-data file. 126 llvm::StringMap<unsigned>::const_iterator OffsetIter = 127 DataOffsets.find(MangledName); 128 if (OffsetIter == DataOffsets.end()) 129 return true; 130 const char *CurPtr = DataBuffer->getBufferStart() + OffsetIter->getValue(); 131 132 // Skip over the function name. 133 CurPtr = strchr(CurPtr, ' '); 134 assert(CurPtr && "pgo-data has corrupted function entry"); 135 136 // Read the number of counters. 137 char *EndPtr; 138 unsigned NumCounters = strtol(++CurPtr, &EndPtr, 10); 139 assert(EndPtr != CurPtr && *EndPtr == '\n' && NumCounters > 0 && 140 "pgo-data file has corrupted number of counters"); 141 CurPtr = EndPtr; 142 143 Counts.reserve(NumCounters); 144 145 for (unsigned N = 0; N < NumCounters; ++N) { 146 // Read the count value. 147 uint64_t Count = strtoll(CurPtr, &EndPtr, 10); 148 if (EndPtr == CurPtr || *EndPtr != '\n') { 149 ReportBadPGOData(CGM, "pgo-data file has bad count value"); 150 return true; 151 } 152 Counts.push_back(Count); 153 CurPtr = EndPtr + 1; 154 } 155 156 // Make sure the number of counters matches up. 157 if (Counts.size() != NumCounters) { 158 ReportBadPGOData(CGM, "pgo-data file has inconsistent counters"); 159 return true; 160 } 161 162 return false; 163 } 164 165 void CodeGenPGO::emitWriteoutFunction(GlobalDecl &GD) { 166 if (!CGM.getCodeGenOpts().ProfileInstrGenerate) 167 return; 168 169 llvm::LLVMContext &Ctx = CGM.getLLVMContext(); 170 171 llvm::Type *Int32Ty = llvm::Type::getInt32Ty(Ctx); 172 llvm::Type *Int8PtrTy = llvm::Type::getInt8PtrTy(Ctx); 173 174 llvm::Function *WriteoutF = 175 CGM.getModule().getFunction("__llvm_pgo_writeout"); 176 if (!WriteoutF) { 177 llvm::FunctionType *WriteoutFTy = 178 llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false); 179 WriteoutF = llvm::Function::Create(WriteoutFTy, 180 llvm::GlobalValue::InternalLinkage, 181 "__llvm_pgo_writeout", &CGM.getModule()); 182 } 183 WriteoutF->setUnnamedAddr(true); 184 WriteoutF->addFnAttr(llvm::Attribute::NoInline); 185 if (CGM.getCodeGenOpts().DisableRedZone) 186 WriteoutF->addFnAttr(llvm::Attribute::NoRedZone); 187 188 llvm::BasicBlock *BB = WriteoutF->empty() ? 189 llvm::BasicBlock::Create(Ctx, "", WriteoutF) : &WriteoutF->getEntryBlock(); 190 191 CGBuilderTy PGOBuilder(BB); 192 193 llvm::Instruction *I = BB->getTerminator(); 194 if (!I) 195 I = PGOBuilder.CreateRetVoid(); 196 PGOBuilder.SetInsertPoint(I); 197 198 llvm::Type *Int64PtrTy = llvm::Type::getInt64PtrTy(Ctx); 199 llvm::Type *Args[] = { 200 Int8PtrTy, // const char *MangledName 201 Int32Ty, // uint32_t NumCounters 202 Int64PtrTy // uint64_t *Counters 203 }; 204 llvm::FunctionType *FTy = 205 llvm::FunctionType::get(PGOBuilder.getVoidTy(), Args, false); 206 llvm::Constant *EmitFunc = 207 CGM.getModule().getOrInsertFunction("llvm_pgo_emit", FTy); 208 209 llvm::Constant *MangledName = 210 CGM.GetAddrOfConstantCString(CGM.getMangledName(GD), "__llvm_pgo_name"); 211 MangledName = llvm::ConstantExpr::getBitCast(MangledName, Int8PtrTy); 212 PGOBuilder.CreateCall3(EmitFunc, MangledName, 213 PGOBuilder.getInt32(NumRegionCounters), 214 PGOBuilder.CreateBitCast(RegionCounters, Int64PtrTy)); 215 } 216 217 llvm::Function *CodeGenPGO::emitInitialization(CodeGenModule &CGM) { 218 llvm::Function *WriteoutF = 219 CGM.getModule().getFunction("__llvm_pgo_writeout"); 220 if (!WriteoutF) 221 return NULL; 222 223 // Create a small bit of code that registers the "__llvm_pgo_writeout" to 224 // be executed at exit. 225 llvm::Function *F = CGM.getModule().getFunction("__llvm_pgo_init"); 226 if (F) 227 return NULL; 228 229 llvm::LLVMContext &Ctx = CGM.getLLVMContext(); 230 llvm::FunctionType *FTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), 231 false); 232 F = llvm::Function::Create(FTy, llvm::GlobalValue::InternalLinkage, 233 "__llvm_pgo_init", &CGM.getModule()); 234 F->setUnnamedAddr(true); 235 F->setLinkage(llvm::GlobalValue::InternalLinkage); 236 F->addFnAttr(llvm::Attribute::NoInline); 237 if (CGM.getCodeGenOpts().DisableRedZone) 238 F->addFnAttr(llvm::Attribute::NoRedZone); 239 240 llvm::BasicBlock *BB = llvm::BasicBlock::Create(CGM.getLLVMContext(), "", F); 241 CGBuilderTy PGOBuilder(BB); 242 243 FTy = llvm::FunctionType::get(PGOBuilder.getVoidTy(), false); 244 llvm::Type *Params[] = { 245 llvm::PointerType::get(FTy, 0) 246 }; 247 FTy = llvm::FunctionType::get(PGOBuilder.getVoidTy(), Params, false); 248 249 // Inialize the environment and register the local writeout function. 250 llvm::Constant *PGOInit = 251 CGM.getModule().getOrInsertFunction("llvm_pgo_init", FTy); 252 PGOBuilder.CreateCall(PGOInit, WriteoutF); 253 PGOBuilder.CreateRetVoid(); 254 255 return F; 256 } 257 258 namespace { 259 /// A StmtVisitor that fills a map of statements to PGO counters. 260 struct MapRegionCounters : public ConstStmtVisitor<MapRegionCounters> { 261 /// The next counter value to assign. 262 unsigned NextCounter; 263 /// The map of statements to counters. 264 llvm::DenseMap<const Stmt*, unsigned> *CounterMap; 265 266 MapRegionCounters(llvm::DenseMap<const Stmt*, unsigned> *CounterMap) : 267 NextCounter(0), CounterMap(CounterMap) { 268 } 269 270 void VisitChildren(const Stmt *S) { 271 for (Stmt::const_child_range I = S->children(); I; ++I) 272 if (*I) 273 this->Visit(*I); 274 } 275 void VisitStmt(const Stmt *S) { VisitChildren(S); } 276 277 /// Assign a counter to track entry to the function body. 278 void VisitFunctionDecl(const FunctionDecl *S) { 279 (*CounterMap)[S->getBody()] = NextCounter++; 280 Visit(S->getBody()); 281 } 282 /// Assign a counter to track the block following a label. 283 void VisitLabelStmt(const LabelStmt *S) { 284 (*CounterMap)[S] = NextCounter++; 285 Visit(S->getSubStmt()); 286 } 287 /// Assign a counter for the body of a while loop. 288 void VisitWhileStmt(const WhileStmt *S) { 289 (*CounterMap)[S] = NextCounter++; 290 Visit(S->getCond()); 291 Visit(S->getBody()); 292 } 293 /// Assign a counter for the body of a do-while loop. 294 void VisitDoStmt(const DoStmt *S) { 295 (*CounterMap)[S] = NextCounter++; 296 Visit(S->getBody()); 297 Visit(S->getCond()); 298 } 299 /// Assign a counter for the body of a for loop. 300 void VisitForStmt(const ForStmt *S) { 301 (*CounterMap)[S] = NextCounter++; 302 if (S->getInit()) 303 Visit(S->getInit()); 304 const Expr *E; 305 if ((E = S->getCond())) 306 Visit(E); 307 if ((E = S->getInc())) 308 Visit(E); 309 Visit(S->getBody()); 310 } 311 /// Assign a counter for the body of a for-range loop. 312 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) { 313 (*CounterMap)[S] = NextCounter++; 314 Visit(S->getRangeStmt()); 315 Visit(S->getBeginEndStmt()); 316 Visit(S->getCond()); 317 Visit(S->getLoopVarStmt()); 318 Visit(S->getBody()); 319 Visit(S->getInc()); 320 } 321 /// Assign a counter for the body of a for-collection loop. 322 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { 323 (*CounterMap)[S] = NextCounter++; 324 Visit(S->getElement()); 325 Visit(S->getBody()); 326 } 327 /// Assign a counter for the exit block of the switch statement. 328 void VisitSwitchStmt(const SwitchStmt *S) { 329 (*CounterMap)[S] = NextCounter++; 330 Visit(S->getCond()); 331 Visit(S->getBody()); 332 } 333 /// Assign a counter for a particular case in a switch. This counts jumps 334 /// from the switch header as well as fallthrough from the case before this 335 /// one. 336 void VisitCaseStmt(const CaseStmt *S) { 337 (*CounterMap)[S] = NextCounter++; 338 Visit(S->getSubStmt()); 339 } 340 /// Assign a counter for the default case of a switch statement. The count 341 /// is the number of branches from the loop header to the default, and does 342 /// not include fallthrough from previous cases. If we have multiple 343 /// conditional branch blocks from the switch instruction to the default 344 /// block, as with large GNU case ranges, this is the counter for the last 345 /// edge in that series, rather than the first. 346 void VisitDefaultStmt(const DefaultStmt *S) { 347 (*CounterMap)[S] = NextCounter++; 348 Visit(S->getSubStmt()); 349 } 350 /// Assign a counter for the "then" part of an if statement. The count for 351 /// the "else" part, if it exists, will be calculated from this counter. 352 void VisitIfStmt(const IfStmt *S) { 353 (*CounterMap)[S] = NextCounter++; 354 Visit(S->getCond()); 355 Visit(S->getThen()); 356 if (S->getElse()) 357 Visit(S->getElse()); 358 } 359 /// Assign a counter for the continuation block of a C++ try statement. 360 void VisitCXXTryStmt(const CXXTryStmt *S) { 361 (*CounterMap)[S] = NextCounter++; 362 Visit(S->getTryBlock()); 363 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I) 364 Visit(S->getHandler(I)); 365 } 366 /// Assign a counter for a catch statement's handler block. 367 void VisitCXXCatchStmt(const CXXCatchStmt *S) { 368 (*CounterMap)[S] = NextCounter++; 369 Visit(S->getHandlerBlock()); 370 } 371 /// Assign a counter for the "true" part of a conditional operator. The 372 /// count in the "false" part will be calculated from this counter. 373 void VisitConditionalOperator(const ConditionalOperator *E) { 374 (*CounterMap)[E] = NextCounter++; 375 Visit(E->getCond()); 376 Visit(E->getTrueExpr()); 377 Visit(E->getFalseExpr()); 378 } 379 /// Assign a counter for the right hand side of a logical and operator. 380 void VisitBinLAnd(const BinaryOperator *E) { 381 (*CounterMap)[E] = NextCounter++; 382 Visit(E->getLHS()); 383 Visit(E->getRHS()); 384 } 385 /// Assign a counter for the right hand side of a logical or operator. 386 void VisitBinLOr(const BinaryOperator *E) { 387 (*CounterMap)[E] = NextCounter++; 388 Visit(E->getLHS()); 389 Visit(E->getRHS()); 390 } 391 }; 392 393 /// A StmtVisitor that propagates the raw counts through the AST and 394 /// records the count at statements where the value may change. 395 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { 396 /// PGO state. 397 CodeGenPGO &PGO; 398 399 /// A flag that is set when the current count should be recorded on the 400 /// next statement, such as at the exit of a loop. 401 bool RecordNextStmtCount; 402 403 /// The map of statements to count values. 404 llvm::DenseMap<const Stmt*, uint64_t> *CountMap; 405 406 /// BreakContinueStack - Keep counts of breaks and continues inside loops. 407 struct BreakContinue { 408 uint64_t BreakCount; 409 uint64_t ContinueCount; 410 BreakContinue() : BreakCount(0), ContinueCount(0) {} 411 }; 412 SmallVector<BreakContinue, 8> BreakContinueStack; 413 414 ComputeRegionCounts(llvm::DenseMap<const Stmt*, uint64_t> *CountMap, 415 CodeGenPGO &PGO) : 416 PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) { 417 } 418 419 void RecordStmtCount(const Stmt *S) { 420 if (RecordNextStmtCount) { 421 (*CountMap)[S] = PGO.getCurrentRegionCount(); 422 RecordNextStmtCount = false; 423 } 424 } 425 426 void VisitStmt(const Stmt *S) { 427 RecordStmtCount(S); 428 for (Stmt::const_child_range I = S->children(); I; ++I) { 429 if (*I) 430 this->Visit(*I); 431 } 432 } 433 434 void VisitFunctionDecl(const FunctionDecl *S) { 435 RegionCounter Cnt(PGO, S->getBody()); 436 Cnt.beginRegion(); 437 (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount(); 438 Visit(S->getBody()); 439 } 440 441 void VisitReturnStmt(const ReturnStmt *S) { 442 RecordStmtCount(S); 443 if (S->getRetValue()) 444 Visit(S->getRetValue()); 445 PGO.setCurrentRegionUnreachable(); 446 RecordNextStmtCount = true; 447 } 448 449 void VisitGotoStmt(const GotoStmt *S) { 450 RecordStmtCount(S); 451 PGO.setCurrentRegionUnreachable(); 452 RecordNextStmtCount = true; 453 } 454 455 void VisitLabelStmt(const LabelStmt *S) { 456 RecordNextStmtCount = false; 457 RegionCounter Cnt(PGO, S); 458 Cnt.beginRegion(); 459 (*CountMap)[S] = PGO.getCurrentRegionCount(); 460 Visit(S->getSubStmt()); 461 } 462 463 void VisitBreakStmt(const BreakStmt *S) { 464 RecordStmtCount(S); 465 assert(!BreakContinueStack.empty() && "break not in a loop or switch!"); 466 BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount(); 467 PGO.setCurrentRegionUnreachable(); 468 RecordNextStmtCount = true; 469 } 470 471 void VisitContinueStmt(const ContinueStmt *S) { 472 RecordStmtCount(S); 473 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!"); 474 BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount(); 475 PGO.setCurrentRegionUnreachable(); 476 RecordNextStmtCount = true; 477 } 478 479 void VisitWhileStmt(const WhileStmt *S) { 480 RecordStmtCount(S); 481 RegionCounter Cnt(PGO, S); 482 BreakContinueStack.push_back(BreakContinue()); 483 // Visit the body region first so the break/continue adjustments can be 484 // included when visiting the condition. 485 Cnt.beginRegion(); 486 (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount(); 487 Visit(S->getBody()); 488 Cnt.adjustForControlFlow(); 489 490 // ...then go back and propagate counts through the condition. The count 491 // at the start of the condition is the sum of the incoming edges, 492 // the backedge from the end of the loop body, and the edges from 493 // continue statements. 494 BreakContinue BC = BreakContinueStack.pop_back_val(); 495 Cnt.setCurrentRegionCount(Cnt.getParentCount() + 496 Cnt.getAdjustedCount() + BC.ContinueCount); 497 (*CountMap)[S->getCond()] = PGO.getCurrentRegionCount(); 498 Visit(S->getCond()); 499 Cnt.adjustForControlFlow(); 500 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 501 RecordNextStmtCount = true; 502 } 503 504 void VisitDoStmt(const DoStmt *S) { 505 RecordStmtCount(S); 506 RegionCounter Cnt(PGO, S); 507 BreakContinueStack.push_back(BreakContinue()); 508 Cnt.beginRegion(/*AddIncomingFallThrough=*/true); 509 (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount(); 510 Visit(S->getBody()); 511 Cnt.adjustForControlFlow(); 512 513 BreakContinue BC = BreakContinueStack.pop_back_val(); 514 // The count at the start of the condition is equal to the count at the 515 // end of the body. The adjusted count does not include either the 516 // fall-through count coming into the loop or the continue count, so add 517 // both of those separately. This is coincidentally the same equation as 518 // with while loops but for different reasons. 519 Cnt.setCurrentRegionCount(Cnt.getParentCount() + 520 Cnt.getAdjustedCount() + BC.ContinueCount); 521 (*CountMap)[S->getCond()] = PGO.getCurrentRegionCount(); 522 Visit(S->getCond()); 523 Cnt.adjustForControlFlow(); 524 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 525 RecordNextStmtCount = true; 526 } 527 528 void VisitForStmt(const ForStmt *S) { 529 RecordStmtCount(S); 530 if (S->getInit()) 531 Visit(S->getInit()); 532 RegionCounter Cnt(PGO, S); 533 BreakContinueStack.push_back(BreakContinue()); 534 // Visit the body region first. (This is basically the same as a while 535 // loop; see further comments in VisitWhileStmt.) 536 Cnt.beginRegion(); 537 (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount(); 538 Visit(S->getBody()); 539 Cnt.adjustForControlFlow(); 540 541 // The increment is essentially part of the body but it needs to include 542 // the count for all the continue statements. 543 if (S->getInc()) { 544 Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() + 545 BreakContinueStack.back().ContinueCount); 546 (*CountMap)[S->getInc()] = PGO.getCurrentRegionCount(); 547 Visit(S->getInc()); 548 Cnt.adjustForControlFlow(); 549 } 550 551 BreakContinue BC = BreakContinueStack.pop_back_val(); 552 553 // ...then go back and propagate counts through the condition. 554 if (S->getCond()) { 555 Cnt.setCurrentRegionCount(Cnt.getParentCount() + 556 Cnt.getAdjustedCount() + 557 BC.ContinueCount); 558 (*CountMap)[S->getCond()] = PGO.getCurrentRegionCount(); 559 Visit(S->getCond()); 560 Cnt.adjustForControlFlow(); 561 } 562 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 563 RecordNextStmtCount = true; 564 } 565 566 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) { 567 RecordStmtCount(S); 568 Visit(S->getRangeStmt()); 569 Visit(S->getBeginEndStmt()); 570 RegionCounter Cnt(PGO, S); 571 BreakContinueStack.push_back(BreakContinue()); 572 // Visit the body region first. (This is basically the same as a while 573 // loop; see further comments in VisitWhileStmt.) 574 Cnt.beginRegion(); 575 (*CountMap)[S->getLoopVarStmt()] = PGO.getCurrentRegionCount(); 576 Visit(S->getLoopVarStmt()); 577 Visit(S->getBody()); 578 Cnt.adjustForControlFlow(); 579 580 // The increment is essentially part of the body but it needs to include 581 // the count for all the continue statements. 582 Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() + 583 BreakContinueStack.back().ContinueCount); 584 (*CountMap)[S->getInc()] = PGO.getCurrentRegionCount(); 585 Visit(S->getInc()); 586 Cnt.adjustForControlFlow(); 587 588 BreakContinue BC = BreakContinueStack.pop_back_val(); 589 590 // ...then go back and propagate counts through the condition. 591 Cnt.setCurrentRegionCount(Cnt.getParentCount() + 592 Cnt.getAdjustedCount() + 593 BC.ContinueCount); 594 (*CountMap)[S->getCond()] = PGO.getCurrentRegionCount(); 595 Visit(S->getCond()); 596 Cnt.adjustForControlFlow(); 597 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 598 RecordNextStmtCount = true; 599 } 600 601 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { 602 RecordStmtCount(S); 603 Visit(S->getElement()); 604 RegionCounter Cnt(PGO, S); 605 BreakContinueStack.push_back(BreakContinue()); 606 Cnt.beginRegion(); 607 (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount(); 608 Visit(S->getBody()); 609 BreakContinue BC = BreakContinueStack.pop_back_val(); 610 Cnt.adjustForControlFlow(); 611 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); 612 RecordNextStmtCount = true; 613 } 614 615 void VisitSwitchStmt(const SwitchStmt *S) { 616 RecordStmtCount(S); 617 Visit(S->getCond()); 618 PGO.setCurrentRegionUnreachable(); 619 BreakContinueStack.push_back(BreakContinue()); 620 Visit(S->getBody()); 621 // If the switch is inside a loop, add the continue counts. 622 BreakContinue BC = BreakContinueStack.pop_back_val(); 623 if (!BreakContinueStack.empty()) 624 BreakContinueStack.back().ContinueCount += BC.ContinueCount; 625 RegionCounter ExitCnt(PGO, S); 626 ExitCnt.beginRegion(); 627 RecordNextStmtCount = true; 628 } 629 630 void VisitCaseStmt(const CaseStmt *S) { 631 RecordNextStmtCount = false; 632 RegionCounter Cnt(PGO, S); 633 Cnt.beginRegion(/*AddIncomingFallThrough=*/true); 634 (*CountMap)[S] = Cnt.getCount(); 635 RecordNextStmtCount = true; 636 Visit(S->getSubStmt()); 637 } 638 639 void VisitDefaultStmt(const DefaultStmt *S) { 640 RecordNextStmtCount = false; 641 RegionCounter Cnt(PGO, S); 642 Cnt.beginRegion(/*AddIncomingFallThrough=*/true); 643 (*CountMap)[S] = Cnt.getCount(); 644 RecordNextStmtCount = true; 645 Visit(S->getSubStmt()); 646 } 647 648 void VisitIfStmt(const IfStmt *S) { 649 RecordStmtCount(S); 650 RegionCounter Cnt(PGO, S); 651 Visit(S->getCond()); 652 653 Cnt.beginRegion(); 654 (*CountMap)[S->getThen()] = PGO.getCurrentRegionCount(); 655 Visit(S->getThen()); 656 Cnt.adjustForControlFlow(); 657 658 if (S->getElse()) { 659 Cnt.beginElseRegion(); 660 (*CountMap)[S->getElse()] = PGO.getCurrentRegionCount(); 661 Visit(S->getElse()); 662 Cnt.adjustForControlFlow(); 663 } 664 Cnt.applyAdjustmentsToRegion(0); 665 RecordNextStmtCount = true; 666 } 667 668 void VisitCXXTryStmt(const CXXTryStmt *S) { 669 RecordStmtCount(S); 670 Visit(S->getTryBlock()); 671 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I) 672 Visit(S->getHandler(I)); 673 RegionCounter Cnt(PGO, S); 674 Cnt.beginRegion(); 675 RecordNextStmtCount = true; 676 } 677 678 void VisitCXXCatchStmt(const CXXCatchStmt *S) { 679 RecordNextStmtCount = false; 680 RegionCounter Cnt(PGO, S); 681 Cnt.beginRegion(); 682 (*CountMap)[S] = PGO.getCurrentRegionCount(); 683 Visit(S->getHandlerBlock()); 684 } 685 686 void VisitConditionalOperator(const ConditionalOperator *E) { 687 RecordStmtCount(E); 688 RegionCounter Cnt(PGO, E); 689 Visit(E->getCond()); 690 691 Cnt.beginRegion(); 692 (*CountMap)[E->getTrueExpr()] = PGO.getCurrentRegionCount(); 693 Visit(E->getTrueExpr()); 694 Cnt.adjustForControlFlow(); 695 696 Cnt.beginElseRegion(); 697 (*CountMap)[E->getFalseExpr()] = PGO.getCurrentRegionCount(); 698 Visit(E->getFalseExpr()); 699 Cnt.adjustForControlFlow(); 700 701 Cnt.applyAdjustmentsToRegion(0); 702 RecordNextStmtCount = true; 703 } 704 705 void VisitBinLAnd(const BinaryOperator *E) { 706 RecordStmtCount(E); 707 RegionCounter Cnt(PGO, E); 708 Visit(E->getLHS()); 709 Cnt.beginRegion(); 710 (*CountMap)[E->getRHS()] = PGO.getCurrentRegionCount(); 711 Visit(E->getRHS()); 712 Cnt.adjustForControlFlow(); 713 Cnt.applyAdjustmentsToRegion(0); 714 RecordNextStmtCount = true; 715 } 716 717 void VisitBinLOr(const BinaryOperator *E) { 718 RecordStmtCount(E); 719 RegionCounter Cnt(PGO, E); 720 Visit(E->getLHS()); 721 Cnt.beginRegion(); 722 (*CountMap)[E->getRHS()] = PGO.getCurrentRegionCount(); 723 Visit(E->getRHS()); 724 Cnt.adjustForControlFlow(); 725 Cnt.applyAdjustmentsToRegion(0); 726 RecordNextStmtCount = true; 727 } 728 }; 729 } 730 731 void CodeGenPGO::assignRegionCounters(GlobalDecl &GD) { 732 bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate; 733 PGOProfileData *PGOData = CGM.getPGOData(); 734 if (!InstrumentRegions && !PGOData) 735 return; 736 const Decl *D = GD.getDecl(); 737 if (!D) 738 return; 739 mapRegionCounters(D); 740 if (InstrumentRegions) 741 emitCounterVariables(); 742 if (PGOData) { 743 loadRegionCounts(GD, PGOData); 744 computeRegionCounts(D); 745 } 746 } 747 748 void CodeGenPGO::mapRegionCounters(const Decl *D) { 749 RegionCounterMap = new llvm::DenseMap<const Stmt*, unsigned>(); 750 MapRegionCounters Walker(RegionCounterMap); 751 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 752 Walker.VisitFunctionDecl(FD); 753 NumRegionCounters = Walker.NextCounter; 754 } 755 756 void CodeGenPGO::computeRegionCounts(const Decl *D) { 757 StmtCountMap = new llvm::DenseMap<const Stmt*, uint64_t>(); 758 ComputeRegionCounts Walker(StmtCountMap, *this); 759 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 760 Walker.VisitFunctionDecl(FD); 761 } 762 763 void CodeGenPGO::emitCounterVariables() { 764 llvm::LLVMContext &Ctx = CGM.getLLVMContext(); 765 llvm::ArrayType *CounterTy = llvm::ArrayType::get(llvm::Type::getInt64Ty(Ctx), 766 NumRegionCounters); 767 RegionCounters = 768 new llvm::GlobalVariable(CGM.getModule(), CounterTy, false, 769 llvm::GlobalVariable::PrivateLinkage, 770 llvm::Constant::getNullValue(CounterTy), 771 "__llvm_pgo_ctr"); 772 } 773 774 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) { 775 if (!CGM.getCodeGenOpts().ProfileInstrGenerate) 776 return; 777 llvm::Value *Addr = 778 Builder.CreateConstInBoundsGEP2_64(RegionCounters, 0, Counter); 779 llvm::Value *Count = Builder.CreateLoad(Addr, "pgocount"); 780 Count = Builder.CreateAdd(Count, Builder.getInt64(1)); 781 Builder.CreateStore(Count, Addr); 782 } 783 784 void CodeGenPGO::loadRegionCounts(GlobalDecl &GD, PGOProfileData *PGOData) { 785 // For now, ignore the counts from the PGO data file only if the number of 786 // counters does not match. This could be tightened down in the future to 787 // ignore counts when the input changes in various ways, e.g., by comparing a 788 // hash value based on some characteristics of the input. 789 RegionCounts = new std::vector<uint64_t>(); 790 if (PGOData->getFunctionCounts(CGM.getMangledName(GD), *RegionCounts) || 791 RegionCounts->size() != NumRegionCounters) { 792 delete RegionCounts; 793 RegionCounts = 0; 794 } 795 } 796 797 void CodeGenPGO::destroyRegionCounters() { 798 if (RegionCounterMap != 0) 799 delete RegionCounterMap; 800 if (StmtCountMap != 0) 801 delete StmtCountMap; 802 if (RegionCounts != 0) 803 delete RegionCounts; 804 } 805 806 llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount, 807 uint64_t FalseCount) { 808 if (!TrueCount && !FalseCount) 809 return 0; 810 811 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 812 // TODO: need to scale down to 32-bits 813 // According to Laplace's Rule of Succession, it is better to compute the 814 // weight based on the count plus 1. 815 return MDHelper.createBranchWeights(TrueCount + 1, FalseCount + 1); 816 } 817 818 llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) { 819 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 820 // TODO: need to scale down to 32-bits, instead of just truncating. 821 // According to Laplace's Rule of Succession, it is better to compute the 822 // weight based on the count plus 1. 823 SmallVector<uint32_t, 16> ScaledWeights; 824 ScaledWeights.reserve(Weights.size()); 825 for (ArrayRef<uint64_t>::iterator WI = Weights.begin(), WE = Weights.end(); 826 WI != WE; ++WI) { 827 ScaledWeights.push_back(*WI + 1); 828 } 829 return MDHelper.createBranchWeights(ScaledWeights); 830 } 831 832 llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond, 833 RegionCounter &Cnt) { 834 if (!haveRegionCounts()) 835 return 0; 836 uint64_t LoopCount = Cnt.getCount(); 837 uint64_t CondCount = 0; 838 bool Found = getStmtCount(Cond, CondCount); 839 assert(Found && "missing expected loop condition count"); 840 (void)Found; 841 if (CondCount == 0) 842 return 0; 843 return createBranchWeights(LoopCount, 844 std::max(CondCount, LoopCount) - LoopCount); 845 } 846