1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Instrumentation-based profile-guided optimization
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/Config/config.h" // for strtoull()/strtoll() define
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/Support/FileSystem.h"
21 
22 using namespace clang;
23 using namespace CodeGen;
24 
25 static void ReportBadPGOData(CodeGenModule &CGM, const char *Message) {
26   DiagnosticsEngine &Diags = CGM.getDiags();
27   unsigned diagID = Diags.getCustomDiagID(DiagnosticsEngine::Error, "%0");
28   Diags.Report(diagID) << Message;
29 }
30 
31 PGOProfileData::PGOProfileData(CodeGenModule &CGM, std::string Path)
32   : CGM(CGM) {
33   if (llvm::MemoryBuffer::getFile(Path, DataBuffer)) {
34     ReportBadPGOData(CGM, "failed to open pgo data file");
35     return;
36   }
37 
38   if (DataBuffer->getBufferSize() > std::numeric_limits<unsigned>::max()) {
39     ReportBadPGOData(CGM, "pgo data file too big");
40     return;
41   }
42 
43   // Scan through the data file and map each function to the corresponding
44   // file offset where its counts are stored.
45   const char *BufferStart = DataBuffer->getBufferStart();
46   const char *BufferEnd = DataBuffer->getBufferEnd();
47   const char *CurPtr = BufferStart;
48   uint64_t MaxCount = 0;
49   while (CurPtr < BufferEnd) {
50     // Read the mangled function name.
51     const char *FuncName = CurPtr;
52     // FIXME: Something will need to be added to distinguish static functions.
53     CurPtr = strchr(CurPtr, ' ');
54     if (!CurPtr) {
55       ReportBadPGOData(CGM, "pgo data file has malformed function entry");
56       return;
57     }
58     StringRef MangledName(FuncName, CurPtr - FuncName);
59 
60     // Read the number of counters.
61     char *EndPtr;
62     unsigned NumCounters = strtol(++CurPtr, &EndPtr, 10);
63     if (EndPtr == CurPtr || *EndPtr != '\n' || NumCounters <= 0) {
64       ReportBadPGOData(CGM, "pgo data file has unexpected number of counters");
65       return;
66     }
67     CurPtr = EndPtr;
68 
69     // Read function count.
70     uint64_t Count = strtoll(CurPtr, &EndPtr, 10);
71     if (EndPtr == CurPtr || *EndPtr != '\n') {
72       ReportBadPGOData(CGM, "pgo-data file has bad count value");
73       return;
74     }
75     CurPtr = EndPtr; // Point to '\n'.
76     FunctionCounts[MangledName] = Count;
77     MaxCount = Count > MaxCount ? Count : MaxCount;
78 
79     // There is one line for each counter; skip over those lines.
80     // Since function count is already read, we start the loop from 1.
81     for (unsigned N = 1; N < NumCounters; ++N) {
82       CurPtr = strchr(++CurPtr, '\n');
83       if (!CurPtr) {
84         ReportBadPGOData(CGM, "pgo data file is missing some counter info");
85         return;
86       }
87     }
88 
89     // Skip over the blank line separating functions.
90     CurPtr += 2;
91 
92     DataOffsets[MangledName] = FuncName - BufferStart;
93   }
94   MaxFunctionCount = MaxCount;
95 }
96 
97 /// Return true if a function is hot. If we know nothing about the function,
98 /// return false.
99 bool PGOProfileData::isHotFunction(StringRef MangledName) {
100   llvm::StringMap<uint64_t>::const_iterator CountIter =
101     FunctionCounts.find(MangledName);
102   // If we know nothing about the function, return false.
103   if (CountIter == FunctionCounts.end())
104     return false;
105   // FIXME: functions with >= 30% of the maximal function count are
106   // treated as hot. This number is from preliminary tuning on SPEC.
107   return CountIter->getValue() >= (uint64_t)(0.3 * (double)MaxFunctionCount);
108 }
109 
110 /// Return true if a function is cold. If we know nothing about the function,
111 /// return false.
112 bool PGOProfileData::isColdFunction(StringRef MangledName) {
113   llvm::StringMap<uint64_t>::const_iterator CountIter =
114     FunctionCounts.find(MangledName);
115   // If we know nothing about the function, return false.
116   if (CountIter == FunctionCounts.end())
117     return false;
118   // FIXME: functions with <= 1% of the maximal function count are treated as
119   // cold. This number is from preliminary tuning on SPEC.
120   return CountIter->getValue() <= (uint64_t)(0.01 * (double)MaxFunctionCount);
121 }
122 
123 bool PGOProfileData::getFunctionCounts(StringRef MangledName,
124                                        std::vector<uint64_t> &Counts) {
125   // Find the relevant section of the pgo-data file.
126   llvm::StringMap<unsigned>::const_iterator OffsetIter =
127     DataOffsets.find(MangledName);
128   if (OffsetIter == DataOffsets.end())
129     return true;
130   const char *CurPtr = DataBuffer->getBufferStart() + OffsetIter->getValue();
131 
132   // Skip over the function name.
133   CurPtr = strchr(CurPtr, ' ');
134   assert(CurPtr && "pgo-data has corrupted function entry");
135 
136   // Read the number of counters.
137   char *EndPtr;
138   unsigned NumCounters = strtol(++CurPtr, &EndPtr, 10);
139   assert(EndPtr != CurPtr && *EndPtr == '\n' && NumCounters > 0 &&
140          "pgo-data file has corrupted number of counters");
141   CurPtr = EndPtr;
142 
143   Counts.reserve(NumCounters);
144 
145   for (unsigned N = 0; N < NumCounters; ++N) {
146     // Read the count value.
147     uint64_t Count = strtoll(CurPtr, &EndPtr, 10);
148     if (EndPtr == CurPtr || *EndPtr != '\n') {
149       ReportBadPGOData(CGM, "pgo-data file has bad count value");
150       return true;
151     }
152     Counts.push_back(Count);
153     CurPtr = EndPtr + 1;
154   }
155 
156   // Make sure the number of counters matches up.
157   if (Counts.size() != NumCounters) {
158     ReportBadPGOData(CGM, "pgo-data file has inconsistent counters");
159     return true;
160   }
161 
162   return false;
163 }
164 
/// Register this function's counters with the module's profile writeout
/// function.
///
/// Looks up, or creates, the internal no-argument function
/// "__llvm_pgo_writeout". It then inserts this call just before that
/// function's terminating "ret void":
///   llvm_pgo_emit(<mangled name of GD>, NumRegionCounters, RegionCounters)
/// Each instrumented function in the module therefore contributes one call.
/// Does nothing unless the ProfileInstrGenerate codegen option is set. Uses
/// RegionCounters, which emitCounterVariables() creates, so it must run after
/// the counters have been emitted.
void CodeGenPGO::emitWriteoutFunction(GlobalDecl &GD) {
  if (!CGM.getCodeGenOpts().ProfileInstrGenerate)
    return;

  llvm::LLVMContext &Ctx = CGM.getLLVMContext();

  llvm::Type *Int32Ty = llvm::Type::getInt32Ty(Ctx);
  llvm::Type *Int8PtrTy = llvm::Type::getInt8PtrTy(Ctx);

  // Reuse the writeout function if an earlier function already created it.
  llvm::Function *WriteoutF =
    CGM.getModule().getFunction("__llvm_pgo_writeout");
  if (!WriteoutF) {
    llvm::FunctionType *WriteoutFTy =
      llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
    WriteoutF = llvm::Function::Create(WriteoutFTy,
                                       llvm::GlobalValue::InternalLinkage,
                                       "__llvm_pgo_writeout", &CGM.getModule());
  }
  // These attributes are re-applied every time a function is added.
  WriteoutF->setUnnamedAddr(true);
  WriteoutF->addFnAttr(llvm::Attribute::NoInline);
  if (CGM.getCodeGenOpts().DisableRedZone)
    WriteoutF->addFnAttr(llvm::Attribute::NoRedZone);

  // All emit calls go in the entry block, which is created on first use.
  llvm::BasicBlock *BB = WriteoutF->empty() ?
    llvm::BasicBlock::Create(Ctx, "", WriteoutF) : &WriteoutF->getEntryBlock();

  CGBuilderTy PGOBuilder(BB);

  // Insert in front of the block's terminator, emitting the "ret void" first
  // if the block is new, so calls accumulate ahead of the return.
  llvm::Instruction *I = BB->getTerminator();
  if (!I)
    I = PGOBuilder.CreateRetVoid();
  PGOBuilder.SetInsertPoint(I);

  // Declare (or reuse) the runtime hook:
  //   void llvm_pgo_emit(const char *, uint32_t, uint64_t *)
  llvm::Type *Int64PtrTy = llvm::Type::getInt64PtrTy(Ctx);
  llvm::Type *Args[] = {
    Int8PtrTy,                       // const char *MangledName
    Int32Ty,                         // uint32_t NumCounters
    Int64PtrTy                       // uint64_t *Counters
  };
  llvm::FunctionType *FTy =
    llvm::FunctionType::get(PGOBuilder.getVoidTy(), Args, false);
  llvm::Constant *EmitFunc =
    CGM.getModule().getOrInsertFunction("llvm_pgo_emit", FTy);

  // Pass this function's mangled name as a constant C string, plus the
  // counter array decayed to a plain i64 pointer.
  llvm::Constant *MangledName =
    CGM.GetAddrOfConstantCString(CGM.getMangledName(GD), "__llvm_pgo_name");
  MangledName = llvm::ConstantExpr::getBitCast(MangledName, Int8PtrTy);
  PGOBuilder.CreateCall3(EmitFunc, MangledName,
                         PGOBuilder.getInt32(NumRegionCounters),
                         PGOBuilder.CreateBitCast(RegionCounters, Int64PtrTy));
}
216 
217 llvm::Function *CodeGenPGO::emitInitialization(CodeGenModule &CGM) {
218   llvm::Function *WriteoutF =
219     CGM.getModule().getFunction("__llvm_pgo_writeout");
220   if (!WriteoutF)
221     return NULL;
222 
223   // Create a small bit of code that registers the "__llvm_pgo_writeout" to
224   // be executed at exit.
225   llvm::Function *F = CGM.getModule().getFunction("__llvm_pgo_init");
226   if (F)
227     return NULL;
228 
229   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
230   llvm::FunctionType *FTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx),
231                                                     false);
232   F = llvm::Function::Create(FTy, llvm::GlobalValue::InternalLinkage,
233                              "__llvm_pgo_init", &CGM.getModule());
234   F->setUnnamedAddr(true);
235   F->setLinkage(llvm::GlobalValue::InternalLinkage);
236   F->addFnAttr(llvm::Attribute::NoInline);
237   if (CGM.getCodeGenOpts().DisableRedZone)
238     F->addFnAttr(llvm::Attribute::NoRedZone);
239 
240   llvm::BasicBlock *BB = llvm::BasicBlock::Create(CGM.getLLVMContext(), "", F);
241   CGBuilderTy PGOBuilder(BB);
242 
243   FTy = llvm::FunctionType::get(PGOBuilder.getVoidTy(), false);
244   llvm::Type *Params[] = {
245     llvm::PointerType::get(FTy, 0)
246   };
247   FTy = llvm::FunctionType::get(PGOBuilder.getVoidTy(), Params, false);
248 
249   // Inialize the environment and register the local writeout function.
250   llvm::Constant *PGOInit =
251     CGM.getModule().getOrInsertFunction("llvm_pgo_init", FTy);
252   PGOBuilder.CreateCall(PGOInit, WriteoutF);
253   PGOBuilder.CreateRetVoid();
254 
255   return F;
256 }
257 
258 namespace {
  /// A StmtVisitor that fills a map of statements to PGO counters.
  ///
  /// Counters are numbered 0, 1, 2, ... in the order the walk reaches each
  /// counted statement. A node's counter is assigned before its children are
  /// visited. emitCounterVariables() allocates one slot per counter, and
  /// loadRegionCounts() expects exactly NumRegionCounters values from the
  /// profile. The visiting order of every method here therefore decides
  /// which statement each profile slot belongs to, and must stay stable.
  struct MapRegionCounters : public ConstStmtVisitor<MapRegionCounters> {
    /// The next counter value to assign.
    unsigned NextCounter;
    /// The map of statements to counters.
    llvm::DenseMap<const Stmt*, unsigned> *CounterMap;

    /// \p CounterMap is filled in place; the walker does not free it.
    MapRegionCounters(llvm::DenseMap<const Stmt*, unsigned> *CounterMap) :
      NextCounter(0), CounterMap(CounterMap) {
    }

    /// Visit every non-null child of \p S.
    void VisitChildren(const Stmt *S) {
      for (Stmt::const_child_range I = S->children(); I; ++I)
        if (*I)
         this->Visit(*I);
    }
    /// Statements without a dedicated Visit method get no counter of their
    /// own; just recurse into their children.
    void VisitStmt(const Stmt *S) { VisitChildren(S); }

    /// Assign a counter to track entry to the function body. This is the entry
    /// point: a FunctionDecl is not a Stmt, so it is called directly (see
    /// CodeGenPGO::mapRegionCounters) rather than through Visit().
    void VisitFunctionDecl(const FunctionDecl *S) {
      (*CounterMap)[S->getBody()] = NextCounter++;
      Visit(S->getBody());
    }
    /// Assign a counter to track the block following a label.
    void VisitLabelStmt(const LabelStmt *S) {
      (*CounterMap)[S] = NextCounter++;
      Visit(S->getSubStmt());
    }
    /// Assign a counter for the body of a while loop.
    void VisitWhileStmt(const WhileStmt *S) {
      (*CounterMap)[S] = NextCounter++;
      Visit(S->getCond());
      Visit(S->getBody());
    }
    /// Assign a counter for the body of a do-while loop.
    void VisitDoStmt(const DoStmt *S) {
      (*CounterMap)[S] = NextCounter++;
      Visit(S->getBody());
      Visit(S->getCond());
    }
    /// Assign a counter for the body of a for loop. The init, condition and
    /// increment are optional and are visited only when present.
    void VisitForStmt(const ForStmt *S) {
      (*CounterMap)[S] = NextCounter++;
      if (S->getInit())
        Visit(S->getInit());
      const Expr *E;
      if ((E = S->getCond()))
        Visit(E);
      if ((E = S->getInc()))
        Visit(E);
      Visit(S->getBody());
    }
    /// Assign a counter for the body of a for-range loop.
    void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
      (*CounterMap)[S] = NextCounter++;
      Visit(S->getRangeStmt());
      Visit(S->getBeginEndStmt());
      Visit(S->getCond());
      Visit(S->getLoopVarStmt());
      Visit(S->getBody());
      Visit(S->getInc());
    }
    /// Assign a counter for the body of a for-collection loop.
    void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
      (*CounterMap)[S] = NextCounter++;
      Visit(S->getElement());
      Visit(S->getBody());
    }
    /// Assign a counter for the exit block of the switch statement.
    void VisitSwitchStmt(const SwitchStmt *S) {
      (*CounterMap)[S] = NextCounter++;
      Visit(S->getCond());
      Visit(S->getBody());
    }
    /// Assign a counter for a particular case in a switch. This counts jumps
    /// from the switch header as well as fallthrough from the case before this
    /// one.
    void VisitCaseStmt(const CaseStmt *S) {
      (*CounterMap)[S] = NextCounter++;
      Visit(S->getSubStmt());
    }
    /// Assign a counter for the default case of a switch statement. The count
    /// is the number of branches from the switch header to the default, and
    /// does not include fallthrough from previous cases. If we have multiple
    /// conditional branch blocks from the switch instruction to the default
    /// block, as with large GNU case ranges, this is the counter for the last
    /// edge in that series, rather than the first.
    void VisitDefaultStmt(const DefaultStmt *S) {
      (*CounterMap)[S] = NextCounter++;
      Visit(S->getSubStmt());
    }
    /// Assign a counter for the "then" part of an if statement. The count for
    /// the "else" part, if it exists, will be calculated from this counter.
    void VisitIfStmt(const IfStmt *S) {
      (*CounterMap)[S] = NextCounter++;
      Visit(S->getCond());
      Visit(S->getThen());
      if (S->getElse())
        Visit(S->getElse());
    }
    /// Assign a counter for the continuation block of a C++ try statement.
    /// Each handler is visited too, and gets its own counter from
    /// VisitCXXCatchStmt.
    void VisitCXXTryStmt(const CXXTryStmt *S) {
      (*CounterMap)[S] = NextCounter++;
      Visit(S->getTryBlock());
      for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
        Visit(S->getHandler(I));
    }
    /// Assign a counter for a catch statement's handler block.
    void VisitCXXCatchStmt(const CXXCatchStmt *S) {
      (*CounterMap)[S] = NextCounter++;
      Visit(S->getHandlerBlock());
    }
    /// Assign a counter for the "true" part of a conditional operator. The
    /// count in the "false" part will be calculated from this counter.
    void VisitConditionalOperator(const ConditionalOperator *E) {
      (*CounterMap)[E] = NextCounter++;
      Visit(E->getCond());
      Visit(E->getTrueExpr());
      Visit(E->getFalseExpr());
    }
    /// Assign a counter for the right hand side of a logical and operator.
    void VisitBinLAnd(const BinaryOperator *E) {
      (*CounterMap)[E] = NextCounter++;
      Visit(E->getLHS());
      Visit(E->getRHS());
    }
    /// Assign a counter for the right hand side of a logical or operator.
    void VisitBinLOr(const BinaryOperator *E) {
      (*CounterMap)[E] = NextCounter++;
      Visit(E->getLHS());
      Visit(E->getRHS());
    }
  };
392 
  /// A StmtVisitor that propagates the raw counts through the AST and
  /// records the count at statements where the value may change.
  ///
  /// The walk tracks a "current region count" held by the CodeGenPGO object
  /// (PGO.getCurrentRegionCount()). Each construct that owns a counter from
  /// MapRegionCounters uses a RegionCounter to start a region from that
  /// counter and later fold control-flow adjustments back in. The exact
  /// RegionCounter semantics are defined with its declaration, which is not
  /// shown here. Results are written in place into CountMap.
  struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
    /// PGO state.
    CodeGenPGO &PGO;

    /// A flag that is set when the current count should be recorded on the
    /// next statement, such as at the exit of a loop.
    bool RecordNextStmtCount;

    /// The map of statements to count values.
    llvm::DenseMap<const Stmt*, uint64_t> *CountMap;

    /// BreakContinueStack - Keep counts of breaks and continues inside loops
    /// and switch statements. A switch pushes an entry too; its continue total
    /// is forwarded to the enclosing loop when the switch ends.
    struct BreakContinue {
      uint64_t BreakCount;
      uint64_t ContinueCount;
      BreakContinue() : BreakCount(0), ContinueCount(0) {}
    };
    SmallVector<BreakContinue, 8> BreakContinueStack;

    /// \p CountMap is filled in place; the walker does not free it.
    ComputeRegionCounts(llvm::DenseMap<const Stmt*, uint64_t> *CountMap,
                        CodeGenPGO &PGO) :
      PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {
    }

    /// If a previous statement requested it, record the current count for
    /// \p S and clear the request.
    void RecordStmtCount(const Stmt *S) {
      if (RecordNextStmtCount) {
        (*CountMap)[S] = PGO.getCurrentRegionCount();
        RecordNextStmtCount = false;
      }
    }

    /// Statements without special control flow leave the count unchanged;
    /// record it if requested and recurse into the non-null children.
    void VisitStmt(const Stmt *S) {
      RecordStmtCount(S);
      for (Stmt::const_child_range I = S->children(); I; ++I) {
        if (*I)
         this->Visit(*I);
      }
    }

    /// Entry point, called directly by CodeGenPGO::computeRegionCounts. The
    /// body's count is that of the function-entry counter.
    void VisitFunctionDecl(const FunctionDecl *S) {
      RegionCounter Cnt(PGO, S->getBody());
      Cnt.beginRegion();
      (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount();
      Visit(S->getBody());
    }

    /// After a return the current region is unreachable. Ask the next
    /// statement to record whatever count is current when it is reached.
    void VisitReturnStmt(const ReturnStmt *S) {
      RecordStmtCount(S);
      if (S->getRetValue())
        Visit(S->getRetValue());
      PGO.setCurrentRegionUnreachable();
      RecordNextStmtCount = true;
    }

    /// Same treatment as return: control does not fall through a goto.
    void VisitGotoStmt(const GotoStmt *S) {
      RecordStmtCount(S);
      PGO.setCurrentRegionUnreachable();
      RecordNextStmtCount = true;
    }

    /// A label starts a fresh region from its own counter, and its count is
    /// always recorded. A pending request is dropped because the label's
    /// count replaces it.
    void VisitLabelStmt(const LabelStmt *S) {
      RecordNextStmtCount = false;
      RegionCounter Cnt(PGO, S);
      Cnt.beginRegion();
      (*CountMap)[S] = PGO.getCurrentRegionCount();
      Visit(S->getSubStmt());
    }

    /// Add the count reaching this break to the innermost loop/switch's break
    /// total; the code after it is unreachable.
    void VisitBreakStmt(const BreakStmt *S) {
      RecordStmtCount(S);
      assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
      BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount();
      PGO.setCurrentRegionUnreachable();
      RecordNextStmtCount = true;
    }

    /// Add the count reaching this continue to the innermost entry's continue
    /// total; the code after it is unreachable.
    void VisitContinueStmt(const ContinueStmt *S) {
      RecordStmtCount(S);
      assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
      BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount();
      PGO.setCurrentRegionUnreachable();
      RecordNextStmtCount = true;
    }

    void VisitWhileStmt(const WhileStmt *S) {
      RecordStmtCount(S);
      RegionCounter Cnt(PGO, S);
      BreakContinueStack.push_back(BreakContinue());
      // Visit the body region first so the break/continue adjustments can be
      // included when visiting the condition.
      Cnt.beginRegion();
      (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount();
      Visit(S->getBody());
      Cnt.adjustForControlFlow();

      // ...then go back and propagate counts through the condition. The count
      // at the start of the condition is the sum of the incoming edges,
      // the backedge from the end of the loop body, and the edges from
      // continue statements.
      BreakContinue BC = BreakContinueStack.pop_back_val();
      Cnt.setCurrentRegionCount(Cnt.getParentCount() +
                                Cnt.getAdjustedCount() + BC.ContinueCount);
      (*CountMap)[S->getCond()] = PGO.getCurrentRegionCount();
      Visit(S->getCond());
      Cnt.adjustForControlFlow();
      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
      RecordNextStmtCount = true;
    }

    void VisitDoStmt(const DoStmt *S) {
      RecordStmtCount(S);
      RegionCounter Cnt(PGO, S);
      BreakContinueStack.push_back(BreakContinue());
      // The body is entered directly from the code before the loop, so the
      // region is begun with AddIncomingFallThrough set.
      Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
      (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount();
      Visit(S->getBody());
      Cnt.adjustForControlFlow();

      BreakContinue BC = BreakContinueStack.pop_back_val();
      // The count at the start of the condition is equal to the count at the
      // end of the body. The adjusted count does not include either the
      // fall-through count coming into the loop or the continue count, so add
      // both of those separately. This is coincidentally the same equation as
      // with while loops but for different reasons.
      Cnt.setCurrentRegionCount(Cnt.getParentCount() +
                                Cnt.getAdjustedCount() + BC.ContinueCount);
      (*CountMap)[S->getCond()] = PGO.getCurrentRegionCount();
      Visit(S->getCond());
      Cnt.adjustForControlFlow();
      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
      RecordNextStmtCount = true;
    }

    void VisitForStmt(const ForStmt *S) {
      RecordStmtCount(S);
      // The init statement runs once, under the enclosing region's count.
      if (S->getInit())
        Visit(S->getInit());
      RegionCounter Cnt(PGO, S);
      BreakContinueStack.push_back(BreakContinue());
      // Visit the body region first. (This is basically the same as a while
      // loop; see further comments in VisitWhileStmt.)
      Cnt.beginRegion();
      (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount();
      Visit(S->getBody());
      Cnt.adjustForControlFlow();

      // The increment is essentially part of the body but it needs to include
      // the count for all the continue statements.
      if (S->getInc()) {
        Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
                                  BreakContinueStack.back().ContinueCount);
        (*CountMap)[S->getInc()] = PGO.getCurrentRegionCount();
        Visit(S->getInc());
        Cnt.adjustForControlFlow();
      }

      BreakContinue BC = BreakContinueStack.pop_back_val();

      // ...then go back and propagate counts through the condition.
      if (S->getCond()) {
        Cnt.setCurrentRegionCount(Cnt.getParentCount() +
                                  Cnt.getAdjustedCount() +
                                  BC.ContinueCount);
        (*CountMap)[S->getCond()] = PGO.getCurrentRegionCount();
        Visit(S->getCond());
        Cnt.adjustForControlFlow();
      }
      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
      RecordNextStmtCount = true;
    }

    void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
      RecordStmtCount(S);
      // The range and begin/end declarations run once, before the loop.
      Visit(S->getRangeStmt());
      Visit(S->getBeginEndStmt());
      RegionCounter Cnt(PGO, S);
      BreakContinueStack.push_back(BreakContinue());
      // Visit the body region first. (This is basically the same as a while
      // loop; see further comments in VisitWhileStmt.)
      Cnt.beginRegion();
      (*CountMap)[S->getLoopVarStmt()] = PGO.getCurrentRegionCount();
      Visit(S->getLoopVarStmt());
      Visit(S->getBody());
      Cnt.adjustForControlFlow();

      // The increment is essentially part of the body but it needs to include
      // the count for all the continue statements.
      Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
                                BreakContinueStack.back().ContinueCount);
      (*CountMap)[S->getInc()] = PGO.getCurrentRegionCount();
      Visit(S->getInc());
      Cnt.adjustForControlFlow();

      BreakContinue BC = BreakContinueStack.pop_back_val();

      // ...then go back and propagate counts through the condition.
      Cnt.setCurrentRegionCount(Cnt.getParentCount() +
                                Cnt.getAdjustedCount() +
                                BC.ContinueCount);
      (*CountMap)[S->getCond()] = PGO.getCurrentRegionCount();
      Visit(S->getCond());
      Cnt.adjustForControlFlow();
      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
      RecordNextStmtCount = true;
    }

    void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
      RecordStmtCount(S);
      Visit(S->getElement());
      RegionCounter Cnt(PGO, S);
      BreakContinueStack.push_back(BreakContinue());
      // There is no separate condition/increment to revisit here; only the
      // body region is computed.
      Cnt.beginRegion();
      (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount();
      Visit(S->getBody());
      BreakContinue BC = BreakContinueStack.pop_back_val();
      Cnt.adjustForControlFlow();
      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
      RecordNextStmtCount = true;
    }

    void VisitSwitchStmt(const SwitchStmt *S) {
      RecordStmtCount(S);
      Visit(S->getCond());
      // Control only enters the body through case/default labels, which begin
      // their own regions, so code before the first label is unreachable.
      PGO.setCurrentRegionUnreachable();
      BreakContinueStack.push_back(BreakContinue());
      Visit(S->getBody());
      // If the switch is inside a loop, add the continue counts.
      // (BC.BreakCount is not used; the count after the switch comes from the
      // switch's own exit counter below.)
      BreakContinue BC = BreakContinueStack.pop_back_val();
      if (!BreakContinueStack.empty())
        BreakContinueStack.back().ContinueCount += BC.ContinueCount;
      RegionCounter ExitCnt(PGO, S);
      ExitCnt.beginRegion();
      RecordNextStmtCount = true;
    }

    /// A case label begins a region from its own counter (with incoming
    /// fall-through added). The label stores the counter value, and the first
    /// statement under it records the resulting current count.
    void VisitCaseStmt(const CaseStmt *S) {
      RecordNextStmtCount = false;
      RegionCounter Cnt(PGO, S);
      Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
      (*CountMap)[S] = Cnt.getCount();
      RecordNextStmtCount = true;
      Visit(S->getSubStmt());
    }

    /// Handled exactly like VisitCaseStmt.
    void VisitDefaultStmt(const DefaultStmt *S) {
      RecordNextStmtCount = false;
      RegionCounter Cnt(PGO, S);
      Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
      (*CountMap)[S] = Cnt.getCount();
      RecordNextStmtCount = true;
      Visit(S->getSubStmt());
    }

    /// The "then" branch uses the if statement's counter. beginElseRegion()
    /// derives the "else" count from it (presumably parent count minus the
    /// counter; see RegionCounter).
    void VisitIfStmt(const IfStmt *S) {
      RecordStmtCount(S);
      RegionCounter Cnt(PGO, S);
      Visit(S->getCond());

      Cnt.beginRegion();
      (*CountMap)[S->getThen()] = PGO.getCurrentRegionCount();
      Visit(S->getThen());
      Cnt.adjustForControlFlow();

      if (S->getElse()) {
        Cnt.beginElseRegion();
        (*CountMap)[S->getElse()] = PGO.getCurrentRegionCount();
        Visit(S->getElse());
        Cnt.adjustForControlFlow();
      }
      Cnt.applyAdjustmentsToRegion(0);
      RecordNextStmtCount = true;
    }

    /// The try block runs under the enclosing count. Each handler begins its
    /// own region in VisitCXXCatchStmt. The code after the statement starts a
    /// region from the try statement's continuation counter.
    void VisitCXXTryStmt(const CXXTryStmt *S) {
      RecordStmtCount(S);
      Visit(S->getTryBlock());
      for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
        Visit(S->getHandler(I));
      RegionCounter Cnt(PGO, S);
      Cnt.beginRegion();
      RecordNextStmtCount = true;
    }

    /// A handler is entered only via an exception, so it begins a fresh
    /// region from its own counter and always records the count.
    void VisitCXXCatchStmt(const CXXCatchStmt *S) {
      RecordNextStmtCount = false;
      RegionCounter Cnt(PGO, S);
      Cnt.beginRegion();
      (*CountMap)[S] = PGO.getCurrentRegionCount();
      Visit(S->getHandlerBlock());
    }

    /// Same shape as an if/else: the true arm uses the operator's counter and
    /// the false arm's count is derived from it.
    void VisitConditionalOperator(const ConditionalOperator *E) {
      RecordStmtCount(E);
      RegionCounter Cnt(PGO, E);
      Visit(E->getCond());

      Cnt.beginRegion();
      (*CountMap)[E->getTrueExpr()] = PGO.getCurrentRegionCount();
      Visit(E->getTrueExpr());
      Cnt.adjustForControlFlow();

      Cnt.beginElseRegion();
      (*CountMap)[E->getFalseExpr()] = PGO.getCurrentRegionCount();
      Visit(E->getFalseExpr());
      Cnt.adjustForControlFlow();

      Cnt.applyAdjustmentsToRegion(0);
      RecordNextStmtCount = true;
    }

    /// The LHS always runs; the RHS runs under the operator's own counter.
    void VisitBinLAnd(const BinaryOperator *E) {
      RecordStmtCount(E);
      RegionCounter Cnt(PGO, E);
      Visit(E->getLHS());
      Cnt.beginRegion();
      (*CountMap)[E->getRHS()] = PGO.getCurrentRegionCount();
      Visit(E->getRHS());
      Cnt.adjustForControlFlow();
      Cnt.applyAdjustmentsToRegion(0);
      RecordNextStmtCount = true;
    }

    /// The LHS always runs; the RHS runs under the operator's own counter.
    void VisitBinLOr(const BinaryOperator *E) {
      RecordStmtCount(E);
      RegionCounter Cnt(PGO, E);
      Visit(E->getLHS());
      Cnt.beginRegion();
      (*CountMap)[E->getRHS()] = PGO.getCurrentRegionCount();
      Visit(E->getRHS());
      Cnt.adjustForControlFlow();
      Cnt.applyAdjustmentsToRegion(0);
      RecordNextStmtCount = true;
    }
  };
729 }
730 
731 void CodeGenPGO::assignRegionCounters(GlobalDecl &GD) {
732   bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
733   PGOProfileData *PGOData = CGM.getPGOData();
734   if (!InstrumentRegions && !PGOData)
735     return;
736   const Decl *D = GD.getDecl();
737   if (!D)
738     return;
739   mapRegionCounters(D);
740   if (InstrumentRegions)
741     emitCounterVariables();
742   if (PGOData) {
743     loadRegionCounts(GD, PGOData);
744     computeRegionCounts(D);
745   }
746 }
747 
748 void CodeGenPGO::mapRegionCounters(const Decl *D) {
749   RegionCounterMap = new llvm::DenseMap<const Stmt*, unsigned>();
750   MapRegionCounters Walker(RegionCounterMap);
751   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
752     Walker.VisitFunctionDecl(FD);
753   NumRegionCounters = Walker.NextCounter;
754 }
755 
756 void CodeGenPGO::computeRegionCounts(const Decl *D) {
757   StmtCountMap = new llvm::DenseMap<const Stmt*, uint64_t>();
758   ComputeRegionCounts Walker(StmtCountMap, *this);
759   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
760     Walker.VisitFunctionDecl(FD);
761 }
762 
763 void CodeGenPGO::emitCounterVariables() {
764   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
765   llvm::ArrayType *CounterTy = llvm::ArrayType::get(llvm::Type::getInt64Ty(Ctx),
766                                                     NumRegionCounters);
767   RegionCounters =
768     new llvm::GlobalVariable(CGM.getModule(), CounterTy, false,
769                              llvm::GlobalVariable::PrivateLinkage,
770                              llvm::Constant::getNullValue(CounterTy),
771                              "__llvm_pgo_ctr");
772 }
773 
774 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) {
775   if (!CGM.getCodeGenOpts().ProfileInstrGenerate)
776     return;
777   llvm::Value *Addr =
778     Builder.CreateConstInBoundsGEP2_64(RegionCounters, 0, Counter);
779   llvm::Value *Count = Builder.CreateLoad(Addr, "pgocount");
780   Count = Builder.CreateAdd(Count, Builder.getInt64(1));
781   Builder.CreateStore(Count, Addr);
782 }
783 
784 void CodeGenPGO::loadRegionCounts(GlobalDecl &GD, PGOProfileData *PGOData) {
785   // For now, ignore the counts from the PGO data file only if the number of
786   // counters does not match. This could be tightened down in the future to
787   // ignore counts when the input changes in various ways, e.g., by comparing a
788   // hash value based on some characteristics of the input.
789   RegionCounts = new std::vector<uint64_t>();
790   if (PGOData->getFunctionCounts(CGM.getMangledName(GD), *RegionCounts) ||
791       RegionCounts->size() != NumRegionCounters) {
792     delete RegionCounts;
793     RegionCounts = 0;
794   }
795 }
796 
797 void CodeGenPGO::destroyRegionCounters() {
798   if (RegionCounterMap != 0)
799     delete RegionCounterMap;
800   if (StmtCountMap != 0)
801     delete StmtCountMap;
802   if (RegionCounts != 0)
803     delete RegionCounts;
804 }
805 
806 llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount,
807                                               uint64_t FalseCount) {
808   if (!TrueCount && !FalseCount)
809     return 0;
810 
811   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
812   // TODO: need to scale down to 32-bits
813   // According to Laplace's Rule of Succession, it is better to compute the
814   // weight based on the count plus 1.
815   return MDHelper.createBranchWeights(TrueCount + 1, FalseCount + 1);
816 }
817 
818 llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) {
819   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
820   // TODO: need to scale down to 32-bits, instead of just truncating.
821   // According to Laplace's Rule of Succession, it is better to compute the
822   // weight based on the count plus 1.
823   SmallVector<uint32_t, 16> ScaledWeights;
824   ScaledWeights.reserve(Weights.size());
825   for (ArrayRef<uint64_t>::iterator WI = Weights.begin(), WE = Weights.end();
826        WI != WE; ++WI) {
827     ScaledWeights.push_back(*WI + 1);
828   }
829   return MDHelper.createBranchWeights(ScaledWeights);
830 }
831 
832 llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond,
833                                             RegionCounter &Cnt) {
834   if (!haveRegionCounts())
835     return 0;
836   uint64_t LoopCount = Cnt.getCount();
837   uint64_t CondCount = 0;
838   bool Found = getStmtCount(Cond, CondCount);
839   assert(Found && "missing expected loop condition count");
840   (void)Found;
841   if (CondCount == 0)
842     return 0;
843   return createBranchWeights(LoopCount,
844                              std::max(CondCount, LoopCount) - LoopCount);
845 }
846