1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // Instrumentation-based profile-guided optimization 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "CodeGenPGO.h" 15 #include "CodeGenFunction.h" 16 #include "clang/AST/RecursiveASTVisitor.h" 17 #include "clang/AST/StmtVisitor.h" 18 #include "llvm/Config/config.h" // for strtoull()/strtoll() define 19 #include "llvm/IR/MDBuilder.h" 20 #include "llvm/Support/FileSystem.h" 21 22 using namespace clang; 23 using namespace CodeGen; 24 25 static void ReportBadPGOData(CodeGenModule &CGM, const char *Message) { 26 DiagnosticsEngine &Diags = CGM.getDiags(); 27 unsigned diagID = Diags.getCustomDiagID(DiagnosticsEngine::Error, "%0"); 28 Diags.Report(diagID) << Message; 29 } 30 31 PGOProfileData::PGOProfileData(CodeGenModule &CGM, std::string Path) 32 : CGM(CGM) { 33 if (llvm::MemoryBuffer::getFile(Path, DataBuffer)) { 34 ReportBadPGOData(CGM, "failed to open pgo data file"); 35 return; 36 } 37 38 if (DataBuffer->getBufferSize() > std::numeric_limits<unsigned>::max()) { 39 ReportBadPGOData(CGM, "pgo data file too big"); 40 return; 41 } 42 43 // Scan through the data file and map each function to the corresponding 44 // file offset where its counts are stored. 45 const char *BufferStart = DataBuffer->getBufferStart(); 46 const char *BufferEnd = DataBuffer->getBufferEnd(); 47 const char *CurPtr = BufferStart; 48 uint64_t MaxCount = 0; 49 while (CurPtr < BufferEnd) { 50 // Read the mangled function name. 51 const char *FuncName = CurPtr; 52 // FIXME: Something will need to be added to distinguish static functions. 53 CurPtr = strchr(CurPtr, ' '); 54 if (!CurPtr) { 55 ReportBadPGOData(CGM, "pgo data file has malformed function entry"); 56 return; 57 } 58 StringRef MangledName(FuncName, CurPtr - FuncName); 59 60 // Read the number of counters. 61 char *EndPtr; 62 unsigned NumCounters = strtol(++CurPtr, &EndPtr, 10); 63 if (EndPtr == CurPtr || *EndPtr != '\n' || NumCounters <= 0) { 64 ReportBadPGOData(CGM, "pgo data file has unexpected number of counters"); 65 return; 66 } 67 CurPtr = EndPtr; 68 69 // Read function count. 70 uint64_t Count = strtoll(CurPtr, &EndPtr, 10); 71 if (EndPtr == CurPtr || *EndPtr != '\n') { 72 ReportBadPGOData(CGM, "pgo-data file has bad count value"); 73 return; 74 } 75 CurPtr = EndPtr + 1; 76 FunctionCounts[MangledName] = Count; 77 MaxCount = Count > MaxCount ? Count : MaxCount; 78 79 // There is one line for each counter; skip over those lines. 80 // Since function count is already read, we start the loop from 1. 81 for (unsigned N = 1; N < NumCounters; ++N) { 82 CurPtr = strchr(++CurPtr, '\n'); 83 if (!CurPtr) { 84 ReportBadPGOData(CGM, "pgo data file is missing some counter info"); 85 return; 86 } 87 } 88 89 // Skip over the blank line separating functions. 90 CurPtr += 2; 91 92 DataOffsets[MangledName] = FuncName - BufferStart; 93 } 94 MaxFunctionCount = MaxCount; 95 } 96 97 /// Return true if a function is hot. If we know nothing about the function, 98 /// return false. 99 bool PGOProfileData::isHotFunction(StringRef MangledName) { 100 llvm::StringMap<uint64_t>::const_iterator CountIter = 101 FunctionCounts.find(MangledName); 102 // If we know nothing about the function, return false. 103 if (CountIter == FunctionCounts.end()) 104 return false; 105 // FIXME: functions with >= 30% of the maximal function count are 106 // treated as hot. This number is from preliminary tuning on SPEC. 107 return CountIter->getValue() >= (uint64_t)(0.3 * (double)MaxFunctionCount); 108 } 109 110 /// Return true if a function is cold. If we know nothing about the function, 111 /// return false. 112 bool PGOProfileData::isColdFunction(StringRef MangledName) { 113 llvm::StringMap<uint64_t>::const_iterator CountIter = 114 FunctionCounts.find(MangledName); 115 // If we know nothing about the function, return false. 116 if (CountIter == FunctionCounts.end()) 117 return false; 118 // FIXME: functions with <= 1% of the maximal function count are treated as 119 // cold. This number is from preliminary tuning on SPEC. 120 return CountIter->getValue() <= (uint64_t)(0.01 * (double)MaxFunctionCount); 121 } 122 123 bool PGOProfileData::getFunctionCounts(StringRef MangledName, 124 std::vector<uint64_t> &Counts) { 125 // Find the relevant section of the pgo-data file. 126 llvm::StringMap<unsigned>::const_iterator OffsetIter = 127 DataOffsets.find(MangledName); 128 if (OffsetIter == DataOffsets.end()) 129 return true; 130 const char *CurPtr = DataBuffer->getBufferStart() + OffsetIter->getValue(); 131 132 // Skip over the function name. 133 CurPtr = strchr(CurPtr, ' '); 134 assert(CurPtr && "pgo-data has corrupted function entry"); 135 136 // Read the number of counters. 137 char *EndPtr; 138 unsigned NumCounters = strtol(++CurPtr, &EndPtr, 10); 139 assert(EndPtr != CurPtr && *EndPtr == '\n' && NumCounters > 0 && 140 "pgo-data file has corrupted number of counters"); 141 CurPtr = EndPtr; 142 143 Counts.reserve(NumCounters); 144 145 for (unsigned N = 0; N < NumCounters; ++N) { 146 // Read the count value. 147 uint64_t Count = strtoll(CurPtr, &EndPtr, 10); 148 if (EndPtr == CurPtr || *EndPtr != '\n') { 149 ReportBadPGOData(CGM, "pgo-data file has bad count value"); 150 return true; 151 } 152 Counts.push_back(Count); 153 CurPtr = EndPtr + 1; 154 } 155 156 // Make sure the number of counters matches up. 157 if (Counts.size() != NumCounters) { 158 ReportBadPGOData(CGM, "pgo-data file has inconsistent counters"); 159 return true; 160 } 161 162 return false; 163 } 164 165 void CodeGenPGO::emitWriteoutFunction(GlobalDecl &GD) { 166 if (!CGM.getCodeGenOpts().ProfileInstrGenerate) 167 return; 168 169 llvm::LLVMContext &Ctx = CGM.getLLVMContext(); 170 171 llvm::Type *Int32Ty = llvm::Type::getInt32Ty(Ctx); 172 llvm::Type *Int8PtrTy = llvm::Type::getInt8PtrTy(Ctx); 173 174 llvm::Function *WriteoutF = 175 CGM.getModule().getFunction("__llvm_pgo_writeout"); 176 if (!WriteoutF) { 177 llvm::FunctionType *WriteoutFTy = 178 llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false); 179 WriteoutF = llvm::Function::Create(WriteoutFTy, 180 llvm::GlobalValue::InternalLinkage, 181 "__llvm_pgo_writeout", &CGM.getModule()); 182 } 183 WriteoutF->setUnnamedAddr(true); 184 WriteoutF->addFnAttr(llvm::Attribute::NoInline); 185 if (CGM.getCodeGenOpts().DisableRedZone) 186 WriteoutF->addFnAttr(llvm::Attribute::NoRedZone); 187 188 llvm::BasicBlock *BB = WriteoutF->empty() ? 189 llvm::BasicBlock::Create(Ctx, "", WriteoutF) : &WriteoutF->getEntryBlock(); 190 191 CGBuilderTy PGOBuilder(BB); 192 193 llvm::Instruction *I = BB->getTerminator(); 194 if (!I) 195 I = PGOBuilder.CreateRetVoid(); 196 PGOBuilder.SetInsertPoint(I); 197 198 llvm::Type *Int64PtrTy = llvm::Type::getInt64PtrTy(Ctx); 199 llvm::Type *Args[] = { 200 Int8PtrTy, // const char *MangledName 201 Int32Ty, // uint32_t NumCounters 202 Int64PtrTy // uint64_t *Counters 203 }; 204 llvm::FunctionType *FTy = 205 llvm::FunctionType::get(PGOBuilder.getVoidTy(), Args, false); 206 llvm::Constant *EmitFunc = 207 CGM.getModule().getOrInsertFunction("llvm_pgo_emit", FTy); 208 209 llvm::Constant *MangledName = 210 CGM.GetAddrOfConstantCString(CGM.getMangledName(GD), "__llvm_pgo_name"); 211 MangledName = llvm::ConstantExpr::getBitCast(MangledName, Int8PtrTy); 212 PGOBuilder.CreateCall3(EmitFunc, MangledName, 213 PGOBuilder.getInt32(NumRegionCounters), 214 PGOBuilder.CreateBitCast(RegionCounters, Int64PtrTy)); 215 } 216 217 llvm::Function *CodeGenPGO::emitInitialization(CodeGenModule &CGM) { 218 llvm::Function *WriteoutF = 219 CGM.getModule().getFunction("__llvm_pgo_writeout"); 220 if (!WriteoutF) 221 return NULL; 222 223 // Create a small bit of code that registers the "__llvm_pgo_writeout" to 224 // be executed at exit. 225 llvm::Function *F = CGM.getModule().getFunction("__llvm_pgo_init"); 226 if (F) 227 return NULL; 228 229 llvm::LLVMContext &Ctx = CGM.getLLVMContext(); 230 llvm::FunctionType *FTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), 231 false); 232 F = llvm::Function::Create(FTy, llvm::GlobalValue::InternalLinkage, 233 "__llvm_pgo_init", &CGM.getModule()); 234 F->setUnnamedAddr(true); 235 F->setLinkage(llvm::GlobalValue::InternalLinkage); 236 F->addFnAttr(llvm::Attribute::NoInline); 237 if (CGM.getCodeGenOpts().DisableRedZone) 238 F->addFnAttr(llvm::Attribute::NoRedZone); 239 240 llvm::BasicBlock *BB = llvm::BasicBlock::Create(CGM.getLLVMContext(), "", F); 241 CGBuilderTy PGOBuilder(BB); 242 243 FTy = llvm::FunctionType::get(PGOBuilder.getVoidTy(), false); 244 llvm::Type *Params[] = { 245 llvm::PointerType::get(FTy, 0) 246 }; 247 FTy = llvm::FunctionType::get(PGOBuilder.getVoidTy(), Params, false); 248 249 // Inialize the environment and register the local writeout function. 250 llvm::Constant *PGOInit = 251 CGM.getModule().getOrInsertFunction("llvm_pgo_init", FTy); 252 PGOBuilder.CreateCall(PGOInit, WriteoutF); 253 PGOBuilder.CreateRetVoid(); 254 255 return F; 256 } 257 258 namespace { 259 /// A StmtVisitor that fills a map of statements to PGO counters. 260 struct MapRegionCounters : public ConstStmtVisitor<MapRegionCounters> { 261 /// The next counter value to assign. 262 unsigned NextCounter; 263 /// The map of statements to counters. 264 llvm::DenseMap<const Stmt*, unsigned> *CounterMap; 265 266 MapRegionCounters(llvm::DenseMap<const Stmt*, unsigned> *CounterMap) : 267 NextCounter(0), CounterMap(CounterMap) { 268 } 269 270 void VisitChildren(const Stmt *S) { 271 for (Stmt::const_child_range I = S->children(); I; ++I) 272 if (*I) 273 this->Visit(*I); 274 } 275 void VisitStmt(const Stmt *S) { VisitChildren(S); } 276 277 /// Assign a counter to track entry to the function body. 278 void VisitFunctionDecl(const FunctionDecl *S) { 279 (*CounterMap)[S->getBody()] = NextCounter++; 280 Visit(S->getBody()); 281 } 282 /// Assign a counter to track the block following a label. 283 void VisitLabelStmt(const LabelStmt *S) { 284 (*CounterMap)[S] = NextCounter++; 285 Visit(S->getSubStmt()); 286 } 287 /// Assign three counters - one for the body of the loop, one for breaks 288 /// from the loop, and one for continues. 289 /// 290 /// The break and continue counters cover all such statements in this loop, 291 /// and are used in calculations to find the number of times the condition 292 /// and exit of the loop occur. They are needed so we can differentiate 293 /// these statements from non-local exits like return and goto. 294 void VisitWhileStmt(const WhileStmt *S) { 295 (*CounterMap)[S] = NextCounter; 296 NextCounter += 3; 297 Visit(S->getCond()); 298 Visit(S->getBody()); 299 } 300 /// Assign counters for the body of the loop, and for breaks and 301 /// continues. See VisitWhileStmt. 302 void VisitDoStmt(const DoStmt *S) { 303 (*CounterMap)[S] = NextCounter; 304 NextCounter += 3; 305 Visit(S->getBody()); 306 Visit(S->getCond()); 307 } 308 /// Assign counters for the body of the loop, and for breaks and 309 /// continues. See VisitWhileStmt. 310 void VisitForStmt(const ForStmt *S) { 311 (*CounterMap)[S] = NextCounter; 312 NextCounter += 3; 313 const Expr *E; 314 if ((E = S->getCond())) 315 Visit(E); 316 Visit(S->getBody()); 317 if ((E = S->getInc())) 318 Visit(E); 319 } 320 /// Assign counters for the body of the loop, and for breaks and 321 /// continues. See VisitWhileStmt. 322 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) { 323 (*CounterMap)[S] = NextCounter; 324 NextCounter += 3; 325 const Expr *E; 326 if ((E = S->getCond())) 327 Visit(E); 328 Visit(S->getBody()); 329 if ((E = S->getInc())) 330 Visit(E); 331 } 332 /// Assign counters for the body of the loop, and for breaks and 333 /// continues. See VisitWhileStmt. 334 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { 335 (*CounterMap)[S] = NextCounter; 336 NextCounter += 3; 337 Visit(S->getElement()); 338 Visit(S->getBody()); 339 } 340 /// Assign a counter for the exit block of the switch statement. 341 void VisitSwitchStmt(const SwitchStmt *S) { 342 (*CounterMap)[S] = NextCounter++; 343 Visit(S->getCond()); 344 Visit(S->getBody()); 345 } 346 /// Assign a counter for a particular case in a switch. This counts jumps 347 /// from the switch header as well as fallthrough from the case before this 348 /// one. 349 void VisitCaseStmt(const CaseStmt *S) { 350 (*CounterMap)[S] = NextCounter++; 351 Visit(S->getSubStmt()); 352 } 353 /// Assign a counter for the default case of a switch statement. The count 354 /// is the number of branches from the loop header to the default, and does 355 /// not include fallthrough from previous cases. If we have multiple 356 /// conditional branch blocks from the switch instruction to the default 357 /// block, as with large GNU case ranges, this is the counter for the last 358 /// edge in that series, rather than the first. 359 void VisitDefaultStmt(const DefaultStmt *S) { 360 (*CounterMap)[S] = NextCounter++; 361 Visit(S->getSubStmt()); 362 } 363 /// Assign a counter for the "then" part of an if statement. The count for 364 /// the "else" part, if it exists, will be calculated from this counter. 365 void VisitIfStmt(const IfStmt *S) { 366 (*CounterMap)[S] = NextCounter++; 367 Visit(S->getCond()); 368 Visit(S->getThen()); 369 if (S->getElse()) 370 Visit(S->getElse()); 371 } 372 /// Assign a counter for the continuation block of a C++ try statement. 373 void VisitCXXTryStmt(const CXXTryStmt *S) { 374 (*CounterMap)[S] = NextCounter++; 375 Visit(S->getTryBlock()); 376 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I) 377 Visit(S->getHandler(I)); 378 } 379 /// Assign a counter for a catch statement's handler block. 380 void VisitCXXCatchStmt(const CXXCatchStmt *S) { 381 (*CounterMap)[S] = NextCounter++; 382 Visit(S->getHandlerBlock()); 383 } 384 /// Assign a counter for the "true" part of a conditional operator. The 385 /// count in the "false" part will be calculated from this counter. 386 void VisitConditionalOperator(const ConditionalOperator *E) { 387 (*CounterMap)[E] = NextCounter++; 388 Visit(E->getCond()); 389 Visit(E->getTrueExpr()); 390 Visit(E->getFalseExpr()); 391 } 392 /// Assign a counter for the right hand side of a logical and operator. 393 void VisitBinLAnd(const BinaryOperator *E) { 394 (*CounterMap)[E] = NextCounter++; 395 Visit(E->getLHS()); 396 Visit(E->getRHS()); 397 } 398 /// Assign a counter for the right hand side of a logical or operator. 399 void VisitBinLOr(const BinaryOperator *E) { 400 (*CounterMap)[E] = NextCounter++; 401 Visit(E->getLHS()); 402 Visit(E->getRHS()); 403 } 404 }; 405 } 406 407 void CodeGenPGO::assignRegionCounters(GlobalDecl &GD) { 408 bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate; 409 PGOProfileData *PGOData = CGM.getPGOData(); 410 if (!InstrumentRegions && !PGOData) 411 return; 412 const Decl *D = GD.getDecl(); 413 if (!D) 414 return; 415 mapRegionCounters(D); 416 if (InstrumentRegions) 417 emitCounterVariables(); 418 if (PGOData) 419 loadRegionCounts(GD, PGOData); 420 } 421 422 void CodeGenPGO::mapRegionCounters(const Decl *D) { 423 RegionCounterMap = new llvm::DenseMap<const Stmt*, unsigned>(); 424 MapRegionCounters Walker(RegionCounterMap); 425 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 426 Walker.VisitFunctionDecl(FD); 427 NumRegionCounters = Walker.NextCounter; 428 } 429 430 void CodeGenPGO::emitCounterVariables() { 431 llvm::LLVMContext &Ctx = CGM.getLLVMContext(); 432 llvm::ArrayType *CounterTy = llvm::ArrayType::get(llvm::Type::getInt64Ty(Ctx), 433 NumRegionCounters); 434 RegionCounters = 435 new llvm::GlobalVariable(CGM.getModule(), CounterTy, false, 436 llvm::GlobalVariable::PrivateLinkage, 437 llvm::Constant::getNullValue(CounterTy), 438 "__llvm_pgo_ctr"); 439 } 440 441 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) { 442 if (!CGM.getCodeGenOpts().ProfileInstrGenerate) 443 return; 444 llvm::Value *Addr = 445 Builder.CreateConstInBoundsGEP2_64(RegionCounters, 0, Counter); 446 llvm::Value *Count = Builder.CreateLoad(Addr, "pgocount"); 447 Count = Builder.CreateAdd(Count, Builder.getInt64(1)); 448 Builder.CreateStore(Count, Addr); 449 } 450 451 void CodeGenPGO::loadRegionCounts(GlobalDecl &GD, PGOProfileData *PGOData) { 452 // For now, ignore the counts from the PGO data file only if the number of 453 // counters does not match. This could be tightened down in the future to 454 // ignore counts when the input changes in various ways, e.g., by comparing a 455 // hash value based on some characteristics of the input. 456 RegionCounts = new std::vector<uint64_t>(); 457 if (PGOData->getFunctionCounts(CGM.getMangledName(GD), *RegionCounts) || 458 RegionCounts->size() != NumRegionCounters) { 459 delete RegionCounts; 460 RegionCounts = 0; 461 } 462 } 463 464 void CodeGenPGO::destroyRegionCounters() { 465 if (RegionCounterMap != 0) 466 delete RegionCounterMap; 467 if (RegionCounts != 0) 468 delete RegionCounts; 469 } 470 471 llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount, 472 uint64_t FalseCount) { 473 if (!TrueCount && !FalseCount) 474 return 0; 475 476 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 477 // TODO: need to scale down to 32-bits 478 // According to Laplace's Rule of Succession, it is better to compute the 479 // weight based on the count plus 1. 480 return MDHelper.createBranchWeights(TrueCount + 1, FalseCount + 1); 481 } 482 483 llvm::MDNode * 484 CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) { 485 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 486 // TODO: need to scale down to 32-bits, instead of just truncating. 487 // According to Laplace's Rule of Succession, it is better to compute the 488 // weight based on the count plus 1. 489 SmallVector<uint32_t, 16> ScaledWeights; 490 ScaledWeights.reserve(Weights.size()); 491 for (ArrayRef<uint64_t>::iterator WI = Weights.begin(), WE = Weights.end(); 492 WI != WE; ++WI) { 493 ScaledWeights.push_back(*WI + 1); 494 } 495 return MDHelper.createBranchWeights(ScaledWeights); 496 } 497