1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Instrumentation-based profile-guided optimization
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/Config/config.h" // for strtoull()/strtoll() define
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/Support/FileSystem.h"
21 
22 using namespace clang;
23 using namespace CodeGen;
24 
25 static void ReportBadPGOData(CodeGenModule &CGM, const char *Message) {
26   DiagnosticsEngine &Diags = CGM.getDiags();
27   unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error, Message);
28   Diags.Report(DiagID);
29 }
30 
31 PGOProfileData::PGOProfileData(CodeGenModule &CGM, std::string Path)
32   : CGM(CGM) {
33   if (llvm::MemoryBuffer::getFile(Path, DataBuffer)) {
34     ReportBadPGOData(CGM, "failed to open pgo data file");
35     return;
36   }
37 
38   if (DataBuffer->getBufferSize() > std::numeric_limits<unsigned>::max()) {
39     ReportBadPGOData(CGM, "pgo data file too big");
40     return;
41   }
42 
43   // Scan through the data file and map each function to the corresponding
44   // file offset where its counts are stored.
45   const char *BufferStart = DataBuffer->getBufferStart();
46   const char *BufferEnd = DataBuffer->getBufferEnd();
47   const char *CurPtr = BufferStart;
48   while (CurPtr < BufferEnd) {
49     // Read the mangled function name.
50     const char *FuncName = CurPtr;
51     // FIXME: Something will need to be added to distinguish static functions.
52     CurPtr = strchr(CurPtr, ' ');
53     if (!CurPtr) {
54       ReportBadPGOData(CGM, "pgo data file has malformed function entry");
55       return;
56     }
57     StringRef MangledName(FuncName, CurPtr - FuncName);
58 
59     // Read the number of counters.
60     char *EndPtr;
61     unsigned NumCounters = strtol(++CurPtr, &EndPtr, 10);
62     if (EndPtr == CurPtr || *EndPtr != '\n' || NumCounters <= 0) {
63       ReportBadPGOData(CGM, "pgo data file has unexpected number of counters");
64       return;
65     }
66     CurPtr = EndPtr;
67 
68     // There is one line for each counter; skip over those lines.
69     for (unsigned N = 0; N < NumCounters; ++N) {
70       CurPtr = strchr(++CurPtr, '\n');
71       if (!CurPtr) {
72         ReportBadPGOData(CGM, "pgo data file is missing some counter info");
73         return;
74       }
75     }
76 
77     // Skip over the blank line separating functions.
78     CurPtr += 2;
79 
80     DataOffsets[MangledName] = FuncName - BufferStart;
81   }
82 }
83 
84 bool PGOProfileData::getFunctionCounts(StringRef MangledName,
85                                        std::vector<uint64_t> &Counts) {
86   // Find the relevant section of the pgo-data file.
87   llvm::StringMap<unsigned>::const_iterator OffsetIter =
88     DataOffsets.find(MangledName);
89   if (OffsetIter == DataOffsets.end())
90     return true;
91   const char *CurPtr = DataBuffer->getBufferStart() + OffsetIter->getValue();
92 
93   // Skip over the function name.
94   CurPtr = strchr(CurPtr, ' ');
95   assert(CurPtr && "pgo-data has corrupted function entry");
96 
97   // Read the number of counters.
98   char *EndPtr;
99   unsigned NumCounters = strtol(++CurPtr, &EndPtr, 10);
100   assert(EndPtr != CurPtr && *EndPtr == '\n' && NumCounters > 0 &&
101          "pgo-data file has corrupted number of counters");
102   CurPtr = EndPtr;
103 
104   Counts.reserve(NumCounters);
105 
106   for (unsigned N = 0; N < NumCounters; ++N) {
107     // Read the count value.
108     uint64_t Count = strtoll(CurPtr, &EndPtr, 10);
109     if (EndPtr == CurPtr || *EndPtr != '\n') {
110       ReportBadPGOData(CGM, "pgo-data file has bad count value");
111       return true;
112     }
113     Counts.push_back(Count);
114     CurPtr = EndPtr + 1;
115   }
116 
117   // Make sure the number of counters matches up.
118   if (Counts.size() != NumCounters) {
119     ReportBadPGOData(CGM, "pgo-data file has inconsistent counters");
120     return true;
121   }
122 
123   return false;
124 }
125 
126 void CodeGenPGO::emitWriteoutFunction(GlobalDecl &GD) {
127   if (!CGM.getCodeGenOpts().ProfileInstrGenerate)
128     return;
129 
130   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
131 
132   llvm::Type *Int32Ty = llvm::Type::getInt32Ty(Ctx);
133   llvm::Type *Int8PtrTy = llvm::Type::getInt8PtrTy(Ctx);
134 
135   llvm::Function *WriteoutF =
136     CGM.getModule().getFunction("__llvm_pgo_writeout");
137   if (!WriteoutF) {
138     llvm::FunctionType *WriteoutFTy =
139       llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
140     WriteoutF = llvm::Function::Create(WriteoutFTy,
141                                        llvm::GlobalValue::InternalLinkage,
142                                        "__llvm_pgo_writeout", &CGM.getModule());
143   }
144   WriteoutF->setUnnamedAddr(true);
145   WriteoutF->addFnAttr(llvm::Attribute::NoInline);
146   if (CGM.getCodeGenOpts().DisableRedZone)
147     WriteoutF->addFnAttr(llvm::Attribute::NoRedZone);
148 
149   llvm::BasicBlock *BB = WriteoutF->empty() ?
150     llvm::BasicBlock::Create(Ctx, "", WriteoutF) : &WriteoutF->getEntryBlock();
151 
152   CGBuilderTy PGOBuilder(BB);
153 
154   llvm::Instruction *I = BB->getTerminator();
155   if (!I)
156     I = PGOBuilder.CreateRetVoid();
157   PGOBuilder.SetInsertPoint(I);
158 
159   llvm::Type *Int64PtrTy = llvm::Type::getInt64PtrTy(Ctx);
160   llvm::Type *Args[] = {
161     Int8PtrTy,                       // const char *MangledName
162     Int32Ty,                         // uint32_t NumCounters
163     Int64PtrTy                       // uint64_t *Counters
164   };
165   llvm::FunctionType *FTy =
166     llvm::FunctionType::get(PGOBuilder.getVoidTy(), Args, false);
167   llvm::Constant *EmitFunc =
168     CGM.getModule().getOrInsertFunction("llvm_pgo_emit", FTy);
169 
170   llvm::Constant *MangledName =
171     CGM.GetAddrOfConstantCString(CGM.getMangledName(GD), "__llvm_pgo_name");
172   MangledName = llvm::ConstantExpr::getBitCast(MangledName, Int8PtrTy);
173   PGOBuilder.CreateCall3(EmitFunc, MangledName,
174                          PGOBuilder.getInt32(NumRegionCounters),
175                          PGOBuilder.CreateBitCast(RegionCounters, Int64PtrTy));
176 }
177 
178 llvm::Function *CodeGenPGO::emitInitialization(CodeGenModule &CGM) {
179   llvm::Function *WriteoutF =
180     CGM.getModule().getFunction("__llvm_pgo_writeout");
181   if (!WriteoutF)
182     return NULL;
183 
184   // Create a small bit of code that registers the "__llvm_pgo_writeout" to
185   // be executed at exit.
186   llvm::Function *F = CGM.getModule().getFunction("__llvm_pgo_init");
187   if (F)
188     return NULL;
189 
190   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
191   llvm::FunctionType *FTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx),
192                                                     false);
193   F = llvm::Function::Create(FTy, llvm::GlobalValue::InternalLinkage,
194                              "__llvm_pgo_init", &CGM.getModule());
195   F->setUnnamedAddr(true);
196   F->setLinkage(llvm::GlobalValue::InternalLinkage);
197   F->addFnAttr(llvm::Attribute::NoInline);
198   if (CGM.getCodeGenOpts().DisableRedZone)
199     F->addFnAttr(llvm::Attribute::NoRedZone);
200 
201   llvm::BasicBlock *BB = llvm::BasicBlock::Create(CGM.getLLVMContext(), "", F);
202   CGBuilderTy PGOBuilder(BB);
203 
204   FTy = llvm::FunctionType::get(PGOBuilder.getVoidTy(), false);
205   llvm::Type *Params[] = {
206     llvm::PointerType::get(FTy, 0)
207   };
208   FTy = llvm::FunctionType::get(PGOBuilder.getVoidTy(), Params, false);
209 
210   // Inialize the environment and register the local writeout function.
211   llvm::Constant *PGOInit =
212     CGM.getModule().getOrInsertFunction("llvm_pgo_init", FTy);
213   PGOBuilder.CreateCall(PGOInit, WriteoutF);
214   PGOBuilder.CreateRetVoid();
215 
216   return F;
217 }
218 
219 namespace {
220   /// A StmtVisitor that fills a map of statements to PGO counters.
221   struct MapRegionCounters : public ConstStmtVisitor<MapRegionCounters> {
222     /// The next counter value to assign.
223     unsigned NextCounter;
224     /// The map of statements to counters.
225     llvm::DenseMap<const Stmt*, unsigned> *CounterMap;
226 
227     MapRegionCounters(llvm::DenseMap<const Stmt*, unsigned> *CounterMap) :
228       NextCounter(0), CounterMap(CounterMap) {
229     }
230 
231     void VisitChildren(const Stmt *S) {
232       for (Stmt::const_child_range I = S->children(); I; ++I)
233         if (*I)
234          this->Visit(*I);
235     }
236     void VisitStmt(const Stmt *S) { VisitChildren(S); }
237 
238     /// Assign a counter to track entry to the function body.
239     void VisitFunctionDecl(const FunctionDecl *S) {
240       (*CounterMap)[S->getBody()] = NextCounter++;
241       Visit(S->getBody());
242     }
243     /// Assign a counter to track the block following a label.
244     void VisitLabelStmt(const LabelStmt *S) {
245       (*CounterMap)[S] = NextCounter++;
246       Visit(S->getSubStmt());
247     }
248     /// Assign three counters - one for the body of the loop, one for breaks
249     /// from the loop, and one for continues.
250     ///
251     /// The break and continue counters cover all such statements in this loop,
252     /// and are used in calculations to find the number of times the condition
253     /// and exit of the loop occur. They are needed so we can differentiate
254     /// these statements from non-local exits like return and goto.
255     void VisitWhileStmt(const WhileStmt *S) {
256       (*CounterMap)[S] = NextCounter;
257       NextCounter += 3;
258       Visit(S->getCond());
259       Visit(S->getBody());
260     }
261     /// Assign counters for the body of the loop, and for breaks and
262     /// continues. See VisitWhileStmt.
263     void VisitDoStmt(const DoStmt *S) {
264       (*CounterMap)[S] = NextCounter;
265       NextCounter += 3;
266       Visit(S->getBody());
267       Visit(S->getCond());
268     }
269     /// Assign counters for the body of the loop, and for breaks and
270     /// continues. See VisitWhileStmt.
271     void VisitForStmt(const ForStmt *S) {
272       (*CounterMap)[S] = NextCounter;
273       NextCounter += 3;
274       const Expr *E;
275       if ((E = S->getCond()))
276         Visit(E);
277       Visit(S->getBody());
278       if ((E = S->getInc()))
279         Visit(E);
280     }
281     /// Assign counters for the body of the loop, and for breaks and
282     /// continues. See VisitWhileStmt.
283     void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
284       (*CounterMap)[S] = NextCounter;
285       NextCounter += 3;
286       const Expr *E;
287       if ((E = S->getCond()))
288         Visit(E);
289       Visit(S->getBody());
290       if ((E = S->getInc()))
291         Visit(E);
292     }
293     /// Assign counters for the body of the loop, and for breaks and
294     /// continues. See VisitWhileStmt.
295     void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
296       (*CounterMap)[S] = NextCounter;
297       NextCounter += 3;
298       Visit(S->getElement());
299       Visit(S->getBody());
300     }
301     /// Assign a counter for the exit block of the switch statement.
302     void VisitSwitchStmt(const SwitchStmt *S) {
303       (*CounterMap)[S] = NextCounter++;
304       Visit(S->getCond());
305       Visit(S->getBody());
306     }
307     /// Assign a counter for a particular case in a switch. This counts jumps
308     /// from the switch header as well as fallthrough from the case before this
309     /// one.
310     void VisitCaseStmt(const CaseStmt *S) {
311       (*CounterMap)[S] = NextCounter++;
312       Visit(S->getSubStmt());
313     }
314     /// Assign a counter for the default case of a switch statement. The count
315     /// is the number of branches from the loop header to the default, and does
316     /// not include fallthrough from previous cases. If we have multiple
317     /// conditional branch blocks from the switch instruction to the default
318     /// block, as with large GNU case ranges, this is the counter for the last
319     /// edge in that series, rather than the first.
320     void VisitDefaultStmt(const DefaultStmt *S) {
321       (*CounterMap)[S] = NextCounter++;
322       Visit(S->getSubStmt());
323     }
324     /// Assign a counter for the "then" part of an if statement. The count for
325     /// the "else" part, if it exists, will be calculated from this counter.
326     void VisitIfStmt(const IfStmt *S) {
327       (*CounterMap)[S] = NextCounter++;
328       Visit(S->getCond());
329       Visit(S->getThen());
330       if (S->getElse())
331         Visit(S->getElse());
332     }
333     /// Assign a counter for the continuation block of a C++ try statement.
334     void VisitCXXTryStmt(const CXXTryStmt *S) {
335       (*CounterMap)[S] = NextCounter++;
336       Visit(S->getTryBlock());
337       for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
338         Visit(S->getHandler(I));
339     }
340     /// Assign a counter for a catch statement's handler block.
341     void VisitCXXCatchStmt(const CXXCatchStmt *S) {
342       (*CounterMap)[S] = NextCounter++;
343       Visit(S->getHandlerBlock());
344     }
345     /// Assign a counter for the "true" part of a conditional operator. The
346     /// count in the "false" part will be calculated from this counter.
347     void VisitConditionalOperator(const ConditionalOperator *E) {
348       (*CounterMap)[E] = NextCounter++;
349       Visit(E->getCond());
350       Visit(E->getTrueExpr());
351       Visit(E->getFalseExpr());
352     }
353     /// Assign a counter for the right hand side of a logical and operator.
354     void VisitBinLAnd(const BinaryOperator *E) {
355       (*CounterMap)[E] = NextCounter++;
356       Visit(E->getLHS());
357       Visit(E->getRHS());
358     }
359     /// Assign a counter for the right hand side of a logical or operator.
360     void VisitBinLOr(const BinaryOperator *E) {
361       (*CounterMap)[E] = NextCounter++;
362       Visit(E->getLHS());
363       Visit(E->getRHS());
364     }
365   };
366 }
367 
368 void CodeGenPGO::assignRegionCounters(GlobalDecl &GD) {
369   bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
370   PGOProfileData *PGOData = CGM.getPGOData();
371   if (!InstrumentRegions && !PGOData)
372     return;
373   const Decl *D = GD.getDecl();
374   if (!D)
375     return;
376   mapRegionCounters(D);
377   if (InstrumentRegions)
378     emitCounterVariables();
379   if (PGOData)
380     loadRegionCounts(GD, PGOData);
381 }
382 
383 void CodeGenPGO::mapRegionCounters(const Decl *D) {
384   RegionCounterMap = new llvm::DenseMap<const Stmt*, unsigned>();
385   MapRegionCounters Walker(RegionCounterMap);
386   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
387     Walker.VisitFunctionDecl(FD);
388   NumRegionCounters = Walker.NextCounter;
389 }
390 
391 void CodeGenPGO::emitCounterVariables() {
392   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
393   llvm::ArrayType *CounterTy = llvm::ArrayType::get(llvm::Type::getInt64Ty(Ctx),
394                                                     NumRegionCounters);
395   RegionCounters =
396     new llvm::GlobalVariable(CGM.getModule(), CounterTy, false,
397                              llvm::GlobalVariable::PrivateLinkage,
398                              llvm::Constant::getNullValue(CounterTy),
399                              "__llvm_pgo_ctr");
400 }
401 
402 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) {
403   if (!CGM.getCodeGenOpts().ProfileInstrGenerate)
404     return;
405   llvm::Value *Addr =
406     Builder.CreateConstInBoundsGEP2_64(RegionCounters, 0, Counter);
407   llvm::Value *Count = Builder.CreateLoad(Addr, "pgocount");
408   Count = Builder.CreateAdd(Count, Builder.getInt64(1));
409   Builder.CreateStore(Count, Addr);
410 }
411 
412 void CodeGenPGO::loadRegionCounts(GlobalDecl &GD, PGOProfileData *PGOData) {
413   // For now, ignore the counts from the PGO data file only if the number of
414   // counters does not match. This could be tightened down in the future to
415   // ignore counts when the input changes in various ways, e.g., by comparing a
416   // hash value based on some characteristics of the input.
417   RegionCounts = new std::vector<uint64_t>();
418   if (PGOData->getFunctionCounts(CGM.getMangledName(GD), *RegionCounts) ||
419       RegionCounts->size() != NumRegionCounters) {
420     delete RegionCounts;
421     RegionCounts = 0;
422   }
423 }
424 
425 void CodeGenPGO::destroyRegionCounters() {
426   if (RegionCounterMap != 0)
427     delete RegionCounterMap;
428   if (RegionCounts != 0)
429     delete RegionCounts;
430 }
431 
432 llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount,
433                                               uint64_t FalseCount) {
434   if (!TrueCount && !FalseCount)
435     return 0;
436 
437   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
438   // TODO: need to scale down to 32-bits
439   // According to Laplace's Rule of Succession, it is better to compute the
440   // weight based on the count plus 1.
441   return MDHelper.createBranchWeights(TrueCount + 1, FalseCount + 1);
442 }
443 
444 llvm::MDNode *
445 CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) {
446   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
447   // TODO: need to scale down to 32-bits, instead of just truncating.
448   // According to Laplace's Rule of Succession, it is better to compute the
449   // weight based on the count plus 1.
450   SmallVector<uint32_t, 16> ScaledWeights;
451   ScaledWeights.reserve(Weights.size());
452   for (ArrayRef<uint64_t>::iterator WI = Weights.begin(), WE = Weights.end();
453        WI != WE; ++WI) {
454     ScaledWeights.push_back(*WI + 1);
455   }
456   return MDHelper.createBranchWeights(ScaledWeights);
457 }
458