1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Instrumentation-based profile-guided optimization
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/Config/config.h" // for strtoull()/strtoll() define
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/Support/FileSystem.h"
21 
22 using namespace clang;
23 using namespace CodeGen;
24 
25 static void ReportBadPGOData(CodeGenModule &CGM, const char *Message) {
26   DiagnosticsEngine &Diags = CGM.getDiags();
27   unsigned diagID = Diags.getCustomDiagID(DiagnosticsEngine::Error, "%0");
28   Diags.Report(diagID) << Message;
29 }
30 
31 PGOProfileData::PGOProfileData(CodeGenModule &CGM, std::string Path)
32   : CGM(CGM) {
33   if (llvm::MemoryBuffer::getFile(Path, DataBuffer)) {
34     ReportBadPGOData(CGM, "failed to open pgo data file");
35     return;
36   }
37 
38   if (DataBuffer->getBufferSize() > std::numeric_limits<unsigned>::max()) {
39     ReportBadPGOData(CGM, "pgo data file too big");
40     return;
41   }
42 
43   // Scan through the data file and map each function to the corresponding
44   // file offset where its counts are stored.
45   const char *BufferStart = DataBuffer->getBufferStart();
46   const char *BufferEnd = DataBuffer->getBufferEnd();
47   const char *CurPtr = BufferStart;
48   uint64_t MaxCount = 0;
49   while (CurPtr < BufferEnd) {
50     // Read the mangled function name.
51     const char *FuncName = CurPtr;
52     // FIXME: Something will need to be added to distinguish static functions.
53     CurPtr = strchr(CurPtr, ' ');
54     if (!CurPtr) {
55       ReportBadPGOData(CGM, "pgo data file has malformed function entry");
56       return;
57     }
58     StringRef MangledName(FuncName, CurPtr - FuncName);
59 
60     // Read the number of counters.
61     char *EndPtr;
62     unsigned NumCounters = strtol(++CurPtr, &EndPtr, 10);
63     if (EndPtr == CurPtr || *EndPtr != '\n' || NumCounters <= 0) {
64       ReportBadPGOData(CGM, "pgo data file has unexpected number of counters");
65       return;
66     }
67     CurPtr = EndPtr;
68 
69     // Read function count.
70     uint64_t Count = strtoll(CurPtr, &EndPtr, 10);
71     if (EndPtr == CurPtr || *EndPtr != '\n') {
72       ReportBadPGOData(CGM, "pgo-data file has bad count value");
73       return;
74     }
75     CurPtr = EndPtr + 1;
76     FunctionCounts[MangledName] = Count;
77     MaxCount = Count > MaxCount ? Count : MaxCount;
78 
79     // There is one line for each counter; skip over those lines.
80     // Since function count is already read, we start the loop from 1.
81     for (unsigned N = 1; N < NumCounters; ++N) {
82       CurPtr = strchr(++CurPtr, '\n');
83       if (!CurPtr) {
84         ReportBadPGOData(CGM, "pgo data file is missing some counter info");
85         return;
86       }
87     }
88 
89     // Skip over the blank line separating functions.
90     CurPtr += 2;
91 
92     DataOffsets[MangledName] = FuncName - BufferStart;
93   }
94   MaxFunctionCount = MaxCount;
95 }
96 
97 /// Return true if a function is hot. If we know nothing about the function,
98 /// return false.
99 bool PGOProfileData::isHotFunction(StringRef MangledName) {
100   llvm::StringMap<uint64_t>::const_iterator CountIter =
101     FunctionCounts.find(MangledName);
102   // If we know nothing about the function, return false.
103   if (CountIter == FunctionCounts.end())
104     return false;
105   // FIXME: functions with >= 30% of the maximal function count are
106   // treated as hot. This number is from preliminary tuning on SPEC.
107   return CountIter->getValue() >= (uint64_t)(0.3 * (double)MaxFunctionCount);
108 }
109 
110 /// Return true if a function is cold. If we know nothing about the function,
111 /// return false.
112 bool PGOProfileData::isColdFunction(StringRef MangledName) {
113   llvm::StringMap<uint64_t>::const_iterator CountIter =
114     FunctionCounts.find(MangledName);
115   // If we know nothing about the function, return false.
116   if (CountIter == FunctionCounts.end())
117     return false;
118   // FIXME: functions with <= 1% of the maximal function count are treated as
119   // cold. This number is from preliminary tuning on SPEC.
120   return CountIter->getValue() <= (uint64_t)(0.01 * (double)MaxFunctionCount);
121 }
122 
123 bool PGOProfileData::getFunctionCounts(StringRef MangledName,
124                                        std::vector<uint64_t> &Counts) {
125   // Find the relevant section of the pgo-data file.
126   llvm::StringMap<unsigned>::const_iterator OffsetIter =
127     DataOffsets.find(MangledName);
128   if (OffsetIter == DataOffsets.end())
129     return true;
130   const char *CurPtr = DataBuffer->getBufferStart() + OffsetIter->getValue();
131 
132   // Skip over the function name.
133   CurPtr = strchr(CurPtr, ' ');
134   assert(CurPtr && "pgo-data has corrupted function entry");
135 
136   // Read the number of counters.
137   char *EndPtr;
138   unsigned NumCounters = strtol(++CurPtr, &EndPtr, 10);
139   assert(EndPtr != CurPtr && *EndPtr == '\n' && NumCounters > 0 &&
140          "pgo-data file has corrupted number of counters");
141   CurPtr = EndPtr;
142 
143   Counts.reserve(NumCounters);
144 
145   for (unsigned N = 0; N < NumCounters; ++N) {
146     // Read the count value.
147     uint64_t Count = strtoll(CurPtr, &EndPtr, 10);
148     if (EndPtr == CurPtr || *EndPtr != '\n') {
149       ReportBadPGOData(CGM, "pgo-data file has bad count value");
150       return true;
151     }
152     Counts.push_back(Count);
153     CurPtr = EndPtr + 1;
154   }
155 
156   // Make sure the number of counters matches up.
157   if (Counts.size() != NumCounters) {
158     ReportBadPGOData(CGM, "pgo-data file has inconsistent counters");
159     return true;
160   }
161 
162   return false;
163 }
164 
165 void CodeGenPGO::emitWriteoutFunction(GlobalDecl &GD) {
166   if (!CGM.getCodeGenOpts().ProfileInstrGenerate)
167     return;
168 
169   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
170 
171   llvm::Type *Int32Ty = llvm::Type::getInt32Ty(Ctx);
172   llvm::Type *Int8PtrTy = llvm::Type::getInt8PtrTy(Ctx);
173 
174   llvm::Function *WriteoutF =
175     CGM.getModule().getFunction("__llvm_pgo_writeout");
176   if (!WriteoutF) {
177     llvm::FunctionType *WriteoutFTy =
178       llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
179     WriteoutF = llvm::Function::Create(WriteoutFTy,
180                                        llvm::GlobalValue::InternalLinkage,
181                                        "__llvm_pgo_writeout", &CGM.getModule());
182   }
183   WriteoutF->setUnnamedAddr(true);
184   WriteoutF->addFnAttr(llvm::Attribute::NoInline);
185   if (CGM.getCodeGenOpts().DisableRedZone)
186     WriteoutF->addFnAttr(llvm::Attribute::NoRedZone);
187 
188   llvm::BasicBlock *BB = WriteoutF->empty() ?
189     llvm::BasicBlock::Create(Ctx, "", WriteoutF) : &WriteoutF->getEntryBlock();
190 
191   CGBuilderTy PGOBuilder(BB);
192 
193   llvm::Instruction *I = BB->getTerminator();
194   if (!I)
195     I = PGOBuilder.CreateRetVoid();
196   PGOBuilder.SetInsertPoint(I);
197 
198   llvm::Type *Int64PtrTy = llvm::Type::getInt64PtrTy(Ctx);
199   llvm::Type *Args[] = {
200     Int8PtrTy,                       // const char *MangledName
201     Int32Ty,                         // uint32_t NumCounters
202     Int64PtrTy                       // uint64_t *Counters
203   };
204   llvm::FunctionType *FTy =
205     llvm::FunctionType::get(PGOBuilder.getVoidTy(), Args, false);
206   llvm::Constant *EmitFunc =
207     CGM.getModule().getOrInsertFunction("llvm_pgo_emit", FTy);
208 
209   llvm::Constant *MangledName =
210     CGM.GetAddrOfConstantCString(CGM.getMangledName(GD), "__llvm_pgo_name");
211   MangledName = llvm::ConstantExpr::getBitCast(MangledName, Int8PtrTy);
212   PGOBuilder.CreateCall3(EmitFunc, MangledName,
213                          PGOBuilder.getInt32(NumRegionCounters),
214                          PGOBuilder.CreateBitCast(RegionCounters, Int64PtrTy));
215 }
216 
217 llvm::Function *CodeGenPGO::emitInitialization(CodeGenModule &CGM) {
218   llvm::Function *WriteoutF =
219     CGM.getModule().getFunction("__llvm_pgo_writeout");
220   if (!WriteoutF)
221     return NULL;
222 
223   // Create a small bit of code that registers the "__llvm_pgo_writeout" to
224   // be executed at exit.
225   llvm::Function *F = CGM.getModule().getFunction("__llvm_pgo_init");
226   if (F)
227     return NULL;
228 
229   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
230   llvm::FunctionType *FTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx),
231                                                     false);
232   F = llvm::Function::Create(FTy, llvm::GlobalValue::InternalLinkage,
233                              "__llvm_pgo_init", &CGM.getModule());
234   F->setUnnamedAddr(true);
235   F->setLinkage(llvm::GlobalValue::InternalLinkage);
236   F->addFnAttr(llvm::Attribute::NoInline);
237   if (CGM.getCodeGenOpts().DisableRedZone)
238     F->addFnAttr(llvm::Attribute::NoRedZone);
239 
240   llvm::BasicBlock *BB = llvm::BasicBlock::Create(CGM.getLLVMContext(), "", F);
241   CGBuilderTy PGOBuilder(BB);
242 
243   FTy = llvm::FunctionType::get(PGOBuilder.getVoidTy(), false);
244   llvm::Type *Params[] = {
245     llvm::PointerType::get(FTy, 0)
246   };
247   FTy = llvm::FunctionType::get(PGOBuilder.getVoidTy(), Params, false);
248 
249   // Inialize the environment and register the local writeout function.
250   llvm::Constant *PGOInit =
251     CGM.getModule().getOrInsertFunction("llvm_pgo_init", FTy);
252   PGOBuilder.CreateCall(PGOInit, WriteoutF);
253   PGOBuilder.CreateRetVoid();
254 
255   return F;
256 }
257 
258 namespace {
259   /// A StmtVisitor that fills a map of statements to PGO counters.
260   struct MapRegionCounters : public ConstStmtVisitor<MapRegionCounters> {
261     /// The next counter value to assign.
262     unsigned NextCounter;
263     /// The map of statements to counters.
264     llvm::DenseMap<const Stmt*, unsigned> *CounterMap;
265 
266     MapRegionCounters(llvm::DenseMap<const Stmt*, unsigned> *CounterMap) :
267       NextCounter(0), CounterMap(CounterMap) {
268     }
269 
270     void VisitChildren(const Stmt *S) {
271       for (Stmt::const_child_range I = S->children(); I; ++I)
272         if (*I)
273          this->Visit(*I);
274     }
275     void VisitStmt(const Stmt *S) { VisitChildren(S); }
276 
277     /// Assign a counter to track entry to the function body.
278     void VisitFunctionDecl(const FunctionDecl *S) {
279       (*CounterMap)[S->getBody()] = NextCounter++;
280       Visit(S->getBody());
281     }
282     /// Assign a counter to track the block following a label.
283     void VisitLabelStmt(const LabelStmt *S) {
284       (*CounterMap)[S] = NextCounter++;
285       Visit(S->getSubStmt());
286     }
287     /// Assign three counters - one for the body of the loop, one for breaks
288     /// from the loop, and one for continues.
289     ///
290     /// The break and continue counters cover all such statements in this loop,
291     /// and are used in calculations to find the number of times the condition
292     /// and exit of the loop occur. They are needed so we can differentiate
293     /// these statements from non-local exits like return and goto.
294     void VisitWhileStmt(const WhileStmt *S) {
295       (*CounterMap)[S] = NextCounter;
296       NextCounter += 3;
297       Visit(S->getCond());
298       Visit(S->getBody());
299     }
300     /// Assign counters for the body of the loop, and for breaks and
301     /// continues. See VisitWhileStmt.
302     void VisitDoStmt(const DoStmt *S) {
303       (*CounterMap)[S] = NextCounter;
304       NextCounter += 3;
305       Visit(S->getBody());
306       Visit(S->getCond());
307     }
308     /// Assign counters for the body of the loop, and for breaks and
309     /// continues. See VisitWhileStmt.
310     void VisitForStmt(const ForStmt *S) {
311       (*CounterMap)[S] = NextCounter;
312       NextCounter += 3;
313       const Expr *E;
314       if ((E = S->getCond()))
315         Visit(E);
316       Visit(S->getBody());
317       if ((E = S->getInc()))
318         Visit(E);
319     }
320     /// Assign counters for the body of the loop, and for breaks and
321     /// continues. See VisitWhileStmt.
322     void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
323       (*CounterMap)[S] = NextCounter;
324       NextCounter += 3;
325       const Expr *E;
326       if ((E = S->getCond()))
327         Visit(E);
328       Visit(S->getBody());
329       if ((E = S->getInc()))
330         Visit(E);
331     }
332     /// Assign counters for the body of the loop, and for breaks and
333     /// continues. See VisitWhileStmt.
334     void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
335       (*CounterMap)[S] = NextCounter;
336       NextCounter += 3;
337       Visit(S->getElement());
338       Visit(S->getBody());
339     }
340     /// Assign a counter for the exit block of the switch statement.
341     void VisitSwitchStmt(const SwitchStmt *S) {
342       (*CounterMap)[S] = NextCounter++;
343       Visit(S->getCond());
344       Visit(S->getBody());
345     }
346     /// Assign a counter for a particular case in a switch. This counts jumps
347     /// from the switch header as well as fallthrough from the case before this
348     /// one.
349     void VisitCaseStmt(const CaseStmt *S) {
350       (*CounterMap)[S] = NextCounter++;
351       Visit(S->getSubStmt());
352     }
353     /// Assign a counter for the default case of a switch statement. The count
354     /// is the number of branches from the loop header to the default, and does
355     /// not include fallthrough from previous cases. If we have multiple
356     /// conditional branch blocks from the switch instruction to the default
357     /// block, as with large GNU case ranges, this is the counter for the last
358     /// edge in that series, rather than the first.
359     void VisitDefaultStmt(const DefaultStmt *S) {
360       (*CounterMap)[S] = NextCounter++;
361       Visit(S->getSubStmt());
362     }
363     /// Assign a counter for the "then" part of an if statement. The count for
364     /// the "else" part, if it exists, will be calculated from this counter.
365     void VisitIfStmt(const IfStmt *S) {
366       (*CounterMap)[S] = NextCounter++;
367       Visit(S->getCond());
368       Visit(S->getThen());
369       if (S->getElse())
370         Visit(S->getElse());
371     }
372     /// Assign a counter for the continuation block of a C++ try statement.
373     void VisitCXXTryStmt(const CXXTryStmt *S) {
374       (*CounterMap)[S] = NextCounter++;
375       Visit(S->getTryBlock());
376       for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
377         Visit(S->getHandler(I));
378     }
379     /// Assign a counter for a catch statement's handler block.
380     void VisitCXXCatchStmt(const CXXCatchStmt *S) {
381       (*CounterMap)[S] = NextCounter++;
382       Visit(S->getHandlerBlock());
383     }
384     /// Assign a counter for the "true" part of a conditional operator. The
385     /// count in the "false" part will be calculated from this counter.
386     void VisitConditionalOperator(const ConditionalOperator *E) {
387       (*CounterMap)[E] = NextCounter++;
388       Visit(E->getCond());
389       Visit(E->getTrueExpr());
390       Visit(E->getFalseExpr());
391     }
392     /// Assign a counter for the right hand side of a logical and operator.
393     void VisitBinLAnd(const BinaryOperator *E) {
394       (*CounterMap)[E] = NextCounter++;
395       Visit(E->getLHS());
396       Visit(E->getRHS());
397     }
398     /// Assign a counter for the right hand side of a logical or operator.
399     void VisitBinLOr(const BinaryOperator *E) {
400       (*CounterMap)[E] = NextCounter++;
401       Visit(E->getLHS());
402       Visit(E->getRHS());
403     }
404   };
405 }
406 
407 void CodeGenPGO::assignRegionCounters(GlobalDecl &GD) {
408   bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
409   PGOProfileData *PGOData = CGM.getPGOData();
410   if (!InstrumentRegions && !PGOData)
411     return;
412   const Decl *D = GD.getDecl();
413   if (!D)
414     return;
415   mapRegionCounters(D);
416   if (InstrumentRegions)
417     emitCounterVariables();
418   if (PGOData)
419     loadRegionCounts(GD, PGOData);
420 }
421 
422 void CodeGenPGO::mapRegionCounters(const Decl *D) {
423   RegionCounterMap = new llvm::DenseMap<const Stmt*, unsigned>();
424   MapRegionCounters Walker(RegionCounterMap);
425   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
426     Walker.VisitFunctionDecl(FD);
427   NumRegionCounters = Walker.NextCounter;
428 }
429 
430 void CodeGenPGO::emitCounterVariables() {
431   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
432   llvm::ArrayType *CounterTy = llvm::ArrayType::get(llvm::Type::getInt64Ty(Ctx),
433                                                     NumRegionCounters);
434   RegionCounters =
435     new llvm::GlobalVariable(CGM.getModule(), CounterTy, false,
436                              llvm::GlobalVariable::PrivateLinkage,
437                              llvm::Constant::getNullValue(CounterTy),
438                              "__llvm_pgo_ctr");
439 }
440 
441 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) {
442   if (!CGM.getCodeGenOpts().ProfileInstrGenerate)
443     return;
444   llvm::Value *Addr =
445     Builder.CreateConstInBoundsGEP2_64(RegionCounters, 0, Counter);
446   llvm::Value *Count = Builder.CreateLoad(Addr, "pgocount");
447   Count = Builder.CreateAdd(Count, Builder.getInt64(1));
448   Builder.CreateStore(Count, Addr);
449 }
450 
451 void CodeGenPGO::loadRegionCounts(GlobalDecl &GD, PGOProfileData *PGOData) {
452   // For now, ignore the counts from the PGO data file only if the number of
453   // counters does not match. This could be tightened down in the future to
454   // ignore counts when the input changes in various ways, e.g., by comparing a
455   // hash value based on some characteristics of the input.
456   RegionCounts = new std::vector<uint64_t>();
457   if (PGOData->getFunctionCounts(CGM.getMangledName(GD), *RegionCounts) ||
458       RegionCounts->size() != NumRegionCounters) {
459     delete RegionCounts;
460     RegionCounts = 0;
461   }
462 }
463 
464 void CodeGenPGO::destroyRegionCounters() {
465   if (RegionCounterMap != 0)
466     delete RegionCounterMap;
467   if (RegionCounts != 0)
468     delete RegionCounts;
469 }
470 
471 llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount,
472                                               uint64_t FalseCount) {
473   if (!TrueCount && !FalseCount)
474     return 0;
475 
476   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
477   // TODO: need to scale down to 32-bits
478   // According to Laplace's Rule of Succession, it is better to compute the
479   // weight based on the count plus 1.
480   return MDHelper.createBranchWeights(TrueCount + 1, FalseCount + 1);
481 }
482 
483 llvm::MDNode *
484 CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) {
485   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
486   // TODO: need to scale down to 32-bits, instead of just truncating.
487   // According to Laplace's Rule of Succession, it is better to compute the
488   // weight based on the count plus 1.
489   SmallVector<uint32_t, 16> ScaledWeights;
490   ScaledWeights.reserve(Weights.size());
491   for (ArrayRef<uint64_t>::iterator WI = Weights.begin(), WE = Weights.end();
492        WI != WE; ++WI) {
493     ScaledWeights.push_back(*WI + 1);
494   }
495   return MDHelper.createBranchWeights(ScaledWeights);
496 }
497