1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // Instrumentation-based profile-guided optimization 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "CodeGenPGO.h" 15 #include "CodeGenFunction.h" 16 #include "clang/AST/RecursiveASTVisitor.h" 17 #include "clang/AST/StmtVisitor.h" 18 #include "llvm/Config/config.h" // for strtoull()/strtoll() define 19 #include "llvm/IR/MDBuilder.h" 20 #include "llvm/Support/FileSystem.h" 21 22 using namespace clang; 23 using namespace CodeGen; 24 25 static void ReportBadPGOData(CodeGenModule &CGM, const char *Message) { 26 DiagnosticsEngine &Diags = CGM.getDiags(); 27 unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error, Message); 28 Diags.Report(DiagID); 29 } 30 31 PGOProfileData::PGOProfileData(CodeGenModule &CGM, std::string Path) 32 : CGM(CGM) { 33 if (llvm::MemoryBuffer::getFile(Path, DataBuffer)) { 34 ReportBadPGOData(CGM, "failed to open pgo data file"); 35 return; 36 } 37 38 if (DataBuffer->getBufferSize() > std::numeric_limits<unsigned>::max()) { 39 ReportBadPGOData(CGM, "pgo data file too big"); 40 return; 41 } 42 43 // Scan through the data file and map each function to the corresponding 44 // file offset where its counts are stored. 45 const char *BufferStart = DataBuffer->getBufferStart(); 46 const char *BufferEnd = DataBuffer->getBufferEnd(); 47 const char *CurPtr = BufferStart; 48 while (CurPtr < BufferEnd) { 49 // Read the mangled function name. 50 const char *FuncName = CurPtr; 51 // FIXME: Something will need to be added to distinguish static functions. 52 CurPtr = strchr(CurPtr, ' '); 53 if (!CurPtr) { 54 ReportBadPGOData(CGM, "pgo data file has malformed function entry"); 55 return; 56 } 57 StringRef MangledName(FuncName, CurPtr - FuncName); 58 59 // Read the number of counters. 60 char *EndPtr; 61 unsigned NumCounters = strtol(++CurPtr, &EndPtr, 10); 62 if (EndPtr == CurPtr || *EndPtr != '\n' || NumCounters <= 0) { 63 ReportBadPGOData(CGM, "pgo data file has unexpected number of counters"); 64 return; 65 } 66 CurPtr = EndPtr; 67 68 // There is one line for each counter; skip over those lines. 69 for (unsigned N = 0; N < NumCounters; ++N) { 70 CurPtr = strchr(++CurPtr, '\n'); 71 if (!CurPtr) { 72 ReportBadPGOData(CGM, "pgo data file is missing some counter info"); 73 return; 74 } 75 } 76 77 // Skip over the blank line separating functions. 78 CurPtr += 2; 79 80 DataOffsets[MangledName] = FuncName - BufferStart; 81 } 82 } 83 84 bool PGOProfileData::getFunctionCounts(StringRef MangledName, 85 std::vector<uint64_t> &Counts) { 86 // Find the relevant section of the pgo-data file. 87 llvm::StringMap<unsigned>::const_iterator OffsetIter = 88 DataOffsets.find(MangledName); 89 if (OffsetIter == DataOffsets.end()) 90 return true; 91 const char *CurPtr = DataBuffer->getBufferStart() + OffsetIter->getValue(); 92 93 // Skip over the function name. 94 CurPtr = strchr(CurPtr, ' '); 95 assert(CurPtr && "pgo-data has corrupted function entry"); 96 97 // Read the number of counters. 98 char *EndPtr; 99 unsigned NumCounters = strtol(++CurPtr, &EndPtr, 10); 100 assert(EndPtr != CurPtr && *EndPtr == '\n' && NumCounters > 0 && 101 "pgo-data file has corrupted number of counters"); 102 CurPtr = EndPtr; 103 104 Counts.reserve(NumCounters); 105 106 for (unsigned N = 0; N < NumCounters; ++N) { 107 // Read the count value. 108 uint64_t Count = strtoll(CurPtr, &EndPtr, 10); 109 if (EndPtr == CurPtr || *EndPtr != '\n') { 110 ReportBadPGOData(CGM, "pgo-data file has bad count value"); 111 return true; 112 } 113 Counts.push_back(Count); 114 CurPtr = EndPtr + 1; 115 } 116 117 // Make sure the number of counters matches up. 118 if (Counts.size() != NumCounters) { 119 ReportBadPGOData(CGM, "pgo-data file has inconsistent counters"); 120 return true; 121 } 122 123 return false; 124 } 125 126 void CodeGenPGO::emitWriteoutFunction(GlobalDecl &GD) { 127 if (!CGM.getCodeGenOpts().ProfileInstrGenerate) 128 return; 129 130 llvm::LLVMContext &Ctx = CGM.getLLVMContext(); 131 132 llvm::Type *Int32Ty = llvm::Type::getInt32Ty(Ctx); 133 llvm::Type *Int8PtrTy = llvm::Type::getInt8PtrTy(Ctx); 134 135 llvm::Function *WriteoutF = 136 CGM.getModule().getFunction("__llvm_pgo_writeout"); 137 if (!WriteoutF) { 138 llvm::FunctionType *WriteoutFTy = 139 llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false); 140 WriteoutF = llvm::Function::Create(WriteoutFTy, 141 llvm::GlobalValue::InternalLinkage, 142 "__llvm_pgo_writeout", &CGM.getModule()); 143 } 144 WriteoutF->setUnnamedAddr(true); 145 WriteoutF->addFnAttr(llvm::Attribute::NoInline); 146 if (CGM.getCodeGenOpts().DisableRedZone) 147 WriteoutF->addFnAttr(llvm::Attribute::NoRedZone); 148 149 llvm::BasicBlock *BB = WriteoutF->empty() ? 150 llvm::BasicBlock::Create(Ctx, "", WriteoutF) : &WriteoutF->getEntryBlock(); 151 152 CGBuilderTy PGOBuilder(BB); 153 154 llvm::Instruction *I = BB->getTerminator(); 155 if (!I) 156 I = PGOBuilder.CreateRetVoid(); 157 PGOBuilder.SetInsertPoint(I); 158 159 llvm::Type *Int64PtrTy = llvm::Type::getInt64PtrTy(Ctx); 160 llvm::Type *Args[] = { 161 Int8PtrTy, // const char *MangledName 162 Int32Ty, // uint32_t NumCounters 163 Int64PtrTy // uint64_t *Counters 164 }; 165 llvm::FunctionType *FTy = 166 llvm::FunctionType::get(PGOBuilder.getVoidTy(), Args, false); 167 llvm::Constant *EmitFunc = 168 CGM.getModule().getOrInsertFunction("llvm_pgo_emit", FTy); 169 170 llvm::Constant *MangledName = 171 CGM.GetAddrOfConstantCString(CGM.getMangledName(GD), "__llvm_pgo_name"); 172 MangledName = llvm::ConstantExpr::getBitCast(MangledName, Int8PtrTy); 173 PGOBuilder.CreateCall3(EmitFunc, MangledName, 174 PGOBuilder.getInt32(NumRegionCounters), 175 PGOBuilder.CreateBitCast(RegionCounters, Int64PtrTy)); 176 } 177 178 llvm::Function *CodeGenPGO::emitInitialization(CodeGenModule &CGM) { 179 llvm::Function *WriteoutF = 180 CGM.getModule().getFunction("__llvm_pgo_writeout"); 181 if (!WriteoutF) 182 return NULL; 183 184 // Create a small bit of code that registers the "__llvm_pgo_writeout" to 185 // be executed at exit. 186 llvm::Function *F = CGM.getModule().getFunction("__llvm_pgo_init"); 187 if (F) 188 return NULL; 189 190 llvm::LLVMContext &Ctx = CGM.getLLVMContext(); 191 llvm::FunctionType *FTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), 192 false); 193 F = llvm::Function::Create(FTy, llvm::GlobalValue::InternalLinkage, 194 "__llvm_pgo_init", &CGM.getModule()); 195 F->setUnnamedAddr(true); 196 F->setLinkage(llvm::GlobalValue::InternalLinkage); 197 F->addFnAttr(llvm::Attribute::NoInline); 198 if (CGM.getCodeGenOpts().DisableRedZone) 199 F->addFnAttr(llvm::Attribute::NoRedZone); 200 201 llvm::BasicBlock *BB = llvm::BasicBlock::Create(CGM.getLLVMContext(), "", F); 202 CGBuilderTy PGOBuilder(BB); 203 204 FTy = llvm::FunctionType::get(PGOBuilder.getVoidTy(), false); 205 llvm::Type *Params[] = { 206 llvm::PointerType::get(FTy, 0) 207 }; 208 FTy = llvm::FunctionType::get(PGOBuilder.getVoidTy(), Params, false); 209 210 // Inialize the environment and register the local writeout function. 211 llvm::Constant *PGOInit = 212 CGM.getModule().getOrInsertFunction("llvm_pgo_init", FTy); 213 PGOBuilder.CreateCall(PGOInit, WriteoutF); 214 PGOBuilder.CreateRetVoid(); 215 216 return F; 217 } 218 219 namespace { 220 /// A StmtVisitor that fills a map of statements to PGO counters. 221 struct MapRegionCounters : public ConstStmtVisitor<MapRegionCounters> { 222 /// The next counter value to assign. 223 unsigned NextCounter; 224 /// The map of statements to counters. 225 llvm::DenseMap<const Stmt*, unsigned> *CounterMap; 226 227 MapRegionCounters(llvm::DenseMap<const Stmt*, unsigned> *CounterMap) : 228 NextCounter(0), CounterMap(CounterMap) { 229 } 230 231 void VisitChildren(const Stmt *S) { 232 for (Stmt::const_child_range I = S->children(); I; ++I) 233 if (*I) 234 this->Visit(*I); 235 } 236 void VisitStmt(const Stmt *S) { VisitChildren(S); } 237 238 /// Assign a counter to track entry to the function body. 239 void VisitFunctionDecl(const FunctionDecl *S) { 240 (*CounterMap)[S->getBody()] = NextCounter++; 241 Visit(S->getBody()); 242 } 243 /// Assign a counter to track the block following a label. 244 void VisitLabelStmt(const LabelStmt *S) { 245 (*CounterMap)[S] = NextCounter++; 246 Visit(S->getSubStmt()); 247 } 248 /// Assign three counters - one for the body of the loop, one for breaks 249 /// from the loop, and one for continues. 250 /// 251 /// The break and continue counters cover all such statements in this loop, 252 /// and are used in calculations to find the number of times the condition 253 /// and exit of the loop occur. They are needed so we can differentiate 254 /// these statements from non-local exits like return and goto. 255 void VisitWhileStmt(const WhileStmt *S) { 256 (*CounterMap)[S] = NextCounter; 257 NextCounter += 3; 258 Visit(S->getCond()); 259 Visit(S->getBody()); 260 } 261 /// Assign counters for the body of the loop, and for breaks and 262 /// continues. See VisitWhileStmt. 263 void VisitDoStmt(const DoStmt *S) { 264 (*CounterMap)[S] = NextCounter; 265 NextCounter += 3; 266 Visit(S->getBody()); 267 Visit(S->getCond()); 268 } 269 /// Assign counters for the body of the loop, and for breaks and 270 /// continues. See VisitWhileStmt. 271 void VisitForStmt(const ForStmt *S) { 272 (*CounterMap)[S] = NextCounter; 273 NextCounter += 3; 274 const Expr *E; 275 if ((E = S->getCond())) 276 Visit(E); 277 Visit(S->getBody()); 278 if ((E = S->getInc())) 279 Visit(E); 280 } 281 /// Assign counters for the body of the loop, and for breaks and 282 /// continues. See VisitWhileStmt. 283 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) { 284 (*CounterMap)[S] = NextCounter; 285 NextCounter += 3; 286 const Expr *E; 287 if ((E = S->getCond())) 288 Visit(E); 289 Visit(S->getBody()); 290 if ((E = S->getInc())) 291 Visit(E); 292 } 293 /// Assign counters for the body of the loop, and for breaks and 294 /// continues. See VisitWhileStmt. 295 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { 296 (*CounterMap)[S] = NextCounter; 297 NextCounter += 3; 298 Visit(S->getElement()); 299 Visit(S->getBody()); 300 } 301 /// Assign a counter for the exit block of the switch statement. 302 void VisitSwitchStmt(const SwitchStmt *S) { 303 (*CounterMap)[S] = NextCounter++; 304 Visit(S->getCond()); 305 Visit(S->getBody()); 306 } 307 /// Assign a counter for a particular case in a switch. This counts jumps 308 /// from the switch header as well as fallthrough from the case before this 309 /// one. 310 void VisitCaseStmt(const CaseStmt *S) { 311 (*CounterMap)[S] = NextCounter++; 312 Visit(S->getSubStmt()); 313 } 314 /// Assign a counter for the default case of a switch statement. The count 315 /// is the number of branches from the loop header to the default, and does 316 /// not include fallthrough from previous cases. If we have multiple 317 /// conditional branch blocks from the switch instruction to the default 318 /// block, as with large GNU case ranges, this is the counter for the last 319 /// edge in that series, rather than the first. 320 void VisitDefaultStmt(const DefaultStmt *S) { 321 (*CounterMap)[S] = NextCounter++; 322 Visit(S->getSubStmt()); 323 } 324 /// Assign a counter for the "then" part of an if statement. The count for 325 /// the "else" part, if it exists, will be calculated from this counter. 326 void VisitIfStmt(const IfStmt *S) { 327 (*CounterMap)[S] = NextCounter++; 328 Visit(S->getCond()); 329 Visit(S->getThen()); 330 if (S->getElse()) 331 Visit(S->getElse()); 332 } 333 /// Assign a counter for the continuation block of a C++ try statement. 334 void VisitCXXTryStmt(const CXXTryStmt *S) { 335 (*CounterMap)[S] = NextCounter++; 336 Visit(S->getTryBlock()); 337 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I) 338 Visit(S->getHandler(I)); 339 } 340 /// Assign a counter for a catch statement's handler block. 341 void VisitCXXCatchStmt(const CXXCatchStmt *S) { 342 (*CounterMap)[S] = NextCounter++; 343 Visit(S->getHandlerBlock()); 344 } 345 /// Assign a counter for the "true" part of a conditional operator. The 346 /// count in the "false" part will be calculated from this counter. 347 void VisitConditionalOperator(const ConditionalOperator *E) { 348 (*CounterMap)[E] = NextCounter++; 349 Visit(E->getCond()); 350 Visit(E->getTrueExpr()); 351 Visit(E->getFalseExpr()); 352 } 353 /// Assign a counter for the right hand side of a logical and operator. 354 void VisitBinLAnd(const BinaryOperator *E) { 355 (*CounterMap)[E] = NextCounter++; 356 Visit(E->getLHS()); 357 Visit(E->getRHS()); 358 } 359 /// Assign a counter for the right hand side of a logical or operator. 360 void VisitBinLOr(const BinaryOperator *E) { 361 (*CounterMap)[E] = NextCounter++; 362 Visit(E->getLHS()); 363 Visit(E->getRHS()); 364 } 365 }; 366 } 367 368 void CodeGenPGO::assignRegionCounters(GlobalDecl &GD) { 369 bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate; 370 PGOProfileData *PGOData = CGM.getPGOData(); 371 if (!InstrumentRegions && !PGOData) 372 return; 373 const Decl *D = GD.getDecl(); 374 if (!D) 375 return; 376 mapRegionCounters(D); 377 if (InstrumentRegions) 378 emitCounterVariables(); 379 if (PGOData) 380 loadRegionCounts(GD, PGOData); 381 } 382 383 void CodeGenPGO::mapRegionCounters(const Decl *D) { 384 RegionCounterMap = new llvm::DenseMap<const Stmt*, unsigned>(); 385 MapRegionCounters Walker(RegionCounterMap); 386 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 387 Walker.VisitFunctionDecl(FD); 388 NumRegionCounters = Walker.NextCounter; 389 } 390 391 void CodeGenPGO::emitCounterVariables() { 392 llvm::LLVMContext &Ctx = CGM.getLLVMContext(); 393 llvm::ArrayType *CounterTy = llvm::ArrayType::get(llvm::Type::getInt64Ty(Ctx), 394 NumRegionCounters); 395 RegionCounters = 396 new llvm::GlobalVariable(CGM.getModule(), CounterTy, false, 397 llvm::GlobalVariable::PrivateLinkage, 398 llvm::Constant::getNullValue(CounterTy), 399 "__llvm_pgo_ctr"); 400 } 401 402 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) { 403 if (!CGM.getCodeGenOpts().ProfileInstrGenerate) 404 return; 405 llvm::Value *Addr = 406 Builder.CreateConstInBoundsGEP2_64(RegionCounters, 0, Counter); 407 llvm::Value *Count = Builder.CreateLoad(Addr, "pgocount"); 408 Count = Builder.CreateAdd(Count, Builder.getInt64(1)); 409 Builder.CreateStore(Count, Addr); 410 } 411 412 void CodeGenPGO::loadRegionCounts(GlobalDecl &GD, PGOProfileData *PGOData) { 413 // For now, ignore the counts from the PGO data file only if the number of 414 // counters does not match. This could be tightened down in the future to 415 // ignore counts when the input changes in various ways, e.g., by comparing a 416 // hash value based on some characteristics of the input. 417 RegionCounts = new std::vector<uint64_t>(); 418 if (PGOData->getFunctionCounts(CGM.getMangledName(GD), *RegionCounts) || 419 RegionCounts->size() != NumRegionCounters) { 420 delete RegionCounts; 421 RegionCounts = 0; 422 } 423 } 424 425 void CodeGenPGO::destroyRegionCounters() { 426 if (RegionCounterMap != 0) 427 delete RegionCounterMap; 428 if (RegionCounts != 0) 429 delete RegionCounts; 430 } 431 432 llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount, 433 uint64_t FalseCount) { 434 if (!TrueCount && !FalseCount) 435 return 0; 436 437 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 438 // TODO: need to scale down to 32-bits 439 // According to Laplace's Rule of Succession, it is better to compute the 440 // weight based on the count plus 1. 441 return MDHelper.createBranchWeights(TrueCount + 1, FalseCount + 1); 442 } 443 444 llvm::MDNode * 445 CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) { 446 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 447 // TODO: need to scale down to 32-bits, instead of just truncating. 448 // According to Laplace's Rule of Succession, it is better to compute the 449 // weight based on the count plus 1. 450 SmallVector<uint32_t, 16> ScaledWeights; 451 ScaledWeights.reserve(Weights.size()); 452 for (ArrayRef<uint64_t>::iterator WI = Weights.begin(), WE = Weights.end(); 453 WI != WE; ++WI) { 454 ScaledWeights.push_back(*WI + 1); 455 } 456 return MDHelper.createBranchWeights(ScaledWeights); 457 } 458