1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Instrumentation-based profile-guided optimization
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/Config/config.h" // for strtoull()/strtoll() define
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/Support/FileSystem.h"
21 
22 using namespace clang;
23 using namespace CodeGen;
24 
25 static void ReportBadPGOData(CodeGenModule &CGM, const char *Message) {
26   DiagnosticsEngine &Diags = CGM.getDiags();
27   unsigned diagID = Diags.getCustomDiagID(DiagnosticsEngine::Error, "%0");
28   Diags.Report(diagID) << Message;
29 }
30 
31 PGOProfileData::PGOProfileData(CodeGenModule &CGM, std::string Path)
32   : CGM(CGM) {
33   if (llvm::MemoryBuffer::getFile(Path, DataBuffer)) {
34     ReportBadPGOData(CGM, "failed to open pgo data file");
35     return;
36   }
37 
38   if (DataBuffer->getBufferSize() > std::numeric_limits<unsigned>::max()) {
39     ReportBadPGOData(CGM, "pgo data file too big");
40     return;
41   }
42 
43   // Scan through the data file and map each function to the corresponding
44   // file offset where its counts are stored.
45   const char *BufferStart = DataBuffer->getBufferStart();
46   const char *BufferEnd = DataBuffer->getBufferEnd();
47   const char *CurPtr = BufferStart;
48   uint64_t MaxCount = 0;
49   while (CurPtr < BufferEnd) {
50     // Read the function name.
51     const char *FuncStart = CurPtr;
52     // For Objective-C methods, the name may include whitespace, so search
53     // backward from the end of the line to find the space that separates the
54     // name from the number of counters. (This is a temporary hack since we are
55     // going to completely replace this file format in the near future.)
56     CurPtr = strchr(CurPtr, '\n');
57     if (!CurPtr) {
58       ReportBadPGOData(CGM, "pgo data file has malformed function entry");
59       return;
60     }
61     StringRef FuncName(FuncStart, CurPtr - FuncStart);
62 
63     // Skip over the function hash.
64     CurPtr = strchr(++CurPtr, '\n');
65     if (!CurPtr) {
66       ReportBadPGOData(CGM, "pgo data file is missing the function hash");
67       return;
68     }
69 
70     // Read the number of counters.
71     char *EndPtr;
72     unsigned NumCounters = strtol(++CurPtr, &EndPtr, 10);
73     if (EndPtr == CurPtr || *EndPtr != '\n' || NumCounters <= 0) {
74       ReportBadPGOData(CGM, "pgo data file has unexpected number of counters");
75       return;
76     }
77     CurPtr = EndPtr;
78 
79     // Read function count.
80     uint64_t Count = strtoll(CurPtr, &EndPtr, 10);
81     if (EndPtr == CurPtr || *EndPtr != '\n') {
82       ReportBadPGOData(CGM, "pgo-data file has bad count value");
83       return;
84     }
85     CurPtr = EndPtr; // Point to '\n'.
86     FunctionCounts[FuncName] = Count;
87     MaxCount = Count > MaxCount ? Count : MaxCount;
88 
89     // There is one line for each counter; skip over those lines.
90     // Since function count is already read, we start the loop from 1.
91     for (unsigned N = 1; N < NumCounters; ++N) {
92       CurPtr = strchr(++CurPtr, '\n');
93       if (!CurPtr) {
94         ReportBadPGOData(CGM, "pgo data file is missing some counter info");
95         return;
96       }
97     }
98 
99     // Skip over the blank line separating functions.
100     CurPtr += 2;
101 
102     DataOffsets[FuncName] = FuncStart - BufferStart;
103   }
104   MaxFunctionCount = MaxCount;
105 }
106 
107 bool PGOProfileData::getFunctionCounts(StringRef FuncName, uint64_t &FuncHash,
108                                        std::vector<uint64_t> &Counts) {
109   // Find the relevant section of the pgo-data file.
110   llvm::StringMap<unsigned>::const_iterator OffsetIter =
111     DataOffsets.find(FuncName);
112   if (OffsetIter == DataOffsets.end())
113     return true;
114   const char *CurPtr = DataBuffer->getBufferStart() + OffsetIter->getValue();
115 
116   // Skip over the function name.
117   CurPtr = strchr(CurPtr, '\n');
118   assert(CurPtr && "pgo-data has corrupted function entry");
119 
120   char *EndPtr;
121   // Read the function hash.
122   FuncHash = strtoll(++CurPtr, &EndPtr, 10);
123   assert(EndPtr != CurPtr && *EndPtr == '\n' &&
124          "pgo-data file has corrupted function hash");
125   CurPtr = EndPtr;
126 
127   // Read the number of counters.
128   unsigned NumCounters = strtol(++CurPtr, &EndPtr, 10);
129   assert(EndPtr != CurPtr && *EndPtr == '\n' && NumCounters > 0 &&
130          "pgo-data file has corrupted number of counters");
131   CurPtr = EndPtr;
132 
133   Counts.reserve(NumCounters);
134 
135   for (unsigned N = 0; N < NumCounters; ++N) {
136     // Read the count value.
137     uint64_t Count = strtoll(CurPtr, &EndPtr, 10);
138     if (EndPtr == CurPtr || *EndPtr != '\n') {
139       ReportBadPGOData(CGM, "pgo-data file has bad count value");
140       return true;
141     }
142     Counts.push_back(Count);
143     CurPtr = EndPtr + 1;
144   }
145 
146   // Make sure the number of counters matches up.
147   if (Counts.size() != NumCounters) {
148     ReportBadPGOData(CGM, "pgo-data file has inconsistent counters");
149     return true;
150   }
151 
152   return false;
153 }
154 
155 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
156   RawFuncName = Fn->getName();
157 
158   // Function names may be prefixed with a binary '1' to indicate
159   // that the backend should not modify the symbols due to any platform
160   // naming convention. Do not include that '1' in the PGO profile name.
161   if (RawFuncName[0] == '\1')
162     RawFuncName = RawFuncName.substr(1);
163 
164   if (!Fn->hasLocalLinkage()) {
165     PrefixedFuncName = new std::string(RawFuncName);
166     return;
167   }
168 
169   // For local symbols, prepend the main file name to distinguish them.
170   // Do not include the full path in the file name since there's no guarantee
171   // that it will stay the same, e.g., if the files are checked out from
172   // version control in different locations.
173   PrefixedFuncName = new std::string(CGM.getCodeGenOpts().MainFileName);
174   if (PrefixedFuncName->empty())
175     PrefixedFuncName->assign("<unknown>");
176   PrefixedFuncName->append(":");
177   PrefixedFuncName->append(RawFuncName);
178 }
179 
180 static llvm::Function *getRegisterFunc(CodeGenModule &CGM) {
181   return CGM.getModule().getFunction("__llvm_pgo_register_functions");
182 }
183 
184 static llvm::BasicBlock *getOrInsertRegisterBB(CodeGenModule &CGM) {
185   // Only need to insert this once per module.
186   if (llvm::Function *RegisterF = getRegisterFunc(CGM))
187     return &RegisterF->getEntryBlock();
188 
189   // Construct the function.
190   auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
191   auto *RegisterFTy = llvm::FunctionType::get(VoidTy, false);
192   auto *RegisterF = llvm::Function::Create(RegisterFTy,
193                                            llvm::GlobalValue::InternalLinkage,
194                                            "__llvm_pgo_register_functions",
195                                            &CGM.getModule());
196   RegisterF->setUnnamedAddr(true);
197   RegisterF->addFnAttr(llvm::Attribute::NoInline);
198   if (CGM.getCodeGenOpts().DisableRedZone)
199     RegisterF->addFnAttr(llvm::Attribute::NoRedZone);
200 
201   // Construct and return the entry block.
202   auto *BB = llvm::BasicBlock::Create(CGM.getLLVMContext(), "", RegisterF);
203   CGBuilderTy Builder(BB);
204   Builder.CreateRetVoid();
205   return BB;
206 }
207 
208 static llvm::Constant *getOrInsertRuntimeRegister(CodeGenModule &CGM) {
209   auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
210   auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
211   auto *RuntimeRegisterTy = llvm::FunctionType::get(VoidTy, VoidPtrTy, false);
212   return CGM.getModule().getOrInsertFunction("__llvm_pgo_register_function",
213                                              RuntimeRegisterTy);
214 }
215 
216 static llvm::Constant *getOrInsertRuntimeWriteAtExit(CodeGenModule &CGM) {
217   // TODO: make this depend on a command-line option.
218   auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
219   auto *WriteAtExitTy = llvm::FunctionType::get(VoidTy, false);
220   return CGM.getModule().getOrInsertFunction("__llvm_pgo_register_write_atexit",
221                                              WriteAtExitTy);
222 }
223 
224 static StringRef getCountersSection(const CodeGenModule &CGM) {
225   if (CGM.getTarget().getTriple().isOSBinFormatMachO())
226     return "__DATA,__llvm_pgo_cnts";
227   else
228     return "__llvm_pgo_cnts";
229 }
230 
231 static StringRef getNameSection(const CodeGenModule &CGM) {
232   if (CGM.getTarget().getTriple().isOSBinFormatMachO())
233     return "__DATA,__llvm_pgo_names";
234   else
235     return "__llvm_pgo_names";
236 }
237 
238 static StringRef getDataSection(const CodeGenModule &CGM) {
239   if (CGM.getTarget().getTriple().isOSBinFormatMachO())
240     return "__DATA,__llvm_pgo_data";
241   else
242     return "__llvm_pgo_data";
243 }
244 
245 llvm::GlobalVariable *CodeGenPGO::buildDataVar() {
246   // Create name variable.
247   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
248   auto *VarName = llvm::ConstantDataArray::getString(Ctx, getFuncName(),
249                                                      false);
250   auto *Name = new llvm::GlobalVariable(CGM.getModule(), VarName->getType(),
251                                         true, FuncLinkage, VarName,
252                                         getFuncVarName("name"));
253   Name->setSection(getNameSection(CGM));
254   Name->setAlignment(1);
255 
256   // Create data variable.
257   auto *Int32Ty = llvm::Type::getInt32Ty(Ctx);
258   auto *Int64Ty = llvm::Type::getInt64Ty(Ctx);
259   auto *Int8PtrTy = llvm::Type::getInt8PtrTy(Ctx);
260   auto *Int64PtrTy = llvm::Type::getInt64PtrTy(Ctx);
261   llvm::Type *DataTypes[] = {
262     Int32Ty, Int32Ty, Int64Ty, Int8PtrTy, Int64PtrTy
263   };
264   auto *DataTy = llvm::StructType::get(Ctx, makeArrayRef(DataTypes));
265   llvm::Constant *DataVals[] = {
266     llvm::ConstantInt::get(Int32Ty, getFuncName().size()),
267     llvm::ConstantInt::get(Int32Ty, NumRegionCounters),
268     llvm::ConstantInt::get(Int64Ty, FunctionHash),
269     llvm::ConstantExpr::getBitCast(Name, Int8PtrTy),
270     llvm::ConstantExpr::getBitCast(RegionCounters, Int64PtrTy)
271   };
272   auto *Data =
273     new llvm::GlobalVariable(CGM.getModule(), DataTy, true, FuncLinkage,
274                              llvm::ConstantStruct::get(DataTy, DataVals),
275                              getFuncVarName("data"));
276 
277   // All the data should be packed into an array in its own section.
278   Data->setSection(getDataSection(CGM));
279   Data->setAlignment(8);
280 
281   // Make sure the data doesn't get deleted.
282   CGM.addUsedGlobal(Data);
283   return Data;
284 }
285 
286 void CodeGenPGO::emitInstrumentationData() {
287   if (!CGM.getCodeGenOpts().ProfileInstrGenerate)
288     return;
289 
290   // Build the data.
291   auto *Data = buildDataVar();
292 
293   // Register the data.
294   //
295   // TODO: only register when static initialization is required.
296   CGBuilderTy Builder(getOrInsertRegisterBB(CGM)->getTerminator());
297   auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
298   Builder.CreateCall(getOrInsertRuntimeRegister(CGM),
299                      Builder.CreateBitCast(Data, VoidPtrTy));
300 }
301 
302 llvm::Function *CodeGenPGO::emitInitialization(CodeGenModule &CGM) {
303   if (!CGM.getCodeGenOpts().ProfileInstrGenerate)
304     return 0;
305 
306   // Only need to create this once per module.
307   if (CGM.getModule().getFunction("__llvm_pgo_init"))
308     return 0;
309 
310   // Get the functions to call at initialization.
311   llvm::Constant *RegisterF = getRegisterFunc(CGM);
312   llvm::Constant *WriteAtExitF = getOrInsertRuntimeWriteAtExit(CGM);
313   if (!RegisterF && !WriteAtExitF)
314     return 0;
315 
316   // Create the initialization function.
317   auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
318   auto *F = llvm::Function::Create(llvm::FunctionType::get(VoidTy, false),
319                                    llvm::GlobalValue::InternalLinkage,
320                                    "__llvm_pgo_init", &CGM.getModule());
321   F->setUnnamedAddr(true);
322   F->addFnAttr(llvm::Attribute::NoInline);
323   if (CGM.getCodeGenOpts().DisableRedZone)
324     F->addFnAttr(llvm::Attribute::NoRedZone);
325 
326   // Add the basic block and the necessary calls.
327   CGBuilderTy Builder(llvm::BasicBlock::Create(CGM.getLLVMContext(), "", F));
328   if (RegisterF)
329     Builder.CreateCall(RegisterF);
330   if (WriteAtExitF)
331     Builder.CreateCall(WriteAtExitF);
332   Builder.CreateRetVoid();
333 
334   return F;
335 }
336 
337 namespace {
338   /// A StmtVisitor that fills a map of statements to PGO counters.
339   struct MapRegionCounters : public ConstStmtVisitor<MapRegionCounters> {
340     /// The next counter value to assign.
341     unsigned NextCounter;
342     /// The map of statements to counters.
343     llvm::DenseMap<const Stmt*, unsigned> *CounterMap;
344 
345     MapRegionCounters(llvm::DenseMap<const Stmt*, unsigned> *CounterMap) :
346       NextCounter(0), CounterMap(CounterMap) {
347     }
348 
349     void VisitChildren(const Stmt *S) {
350       for (Stmt::const_child_range I = S->children(); I; ++I)
351         if (*I)
352          this->Visit(*I);
353     }
354     void VisitStmt(const Stmt *S) { VisitChildren(S); }
355 
356     /// Assign a counter to track entry to the function body.
357     void VisitFunctionDecl(const FunctionDecl *S) {
358       (*CounterMap)[S->getBody()] = NextCounter++;
359       Visit(S->getBody());
360     }
361     void VisitObjCMethodDecl(const ObjCMethodDecl *S) {
362       (*CounterMap)[S->getBody()] = NextCounter++;
363       Visit(S->getBody());
364     }
365     void VisitBlockDecl(const BlockDecl *S) {
366       (*CounterMap)[S->getBody()] = NextCounter++;
367       Visit(S->getBody());
368     }
369     /// Assign a counter to track the block following a label.
370     void VisitLabelStmt(const LabelStmt *S) {
371       (*CounterMap)[S] = NextCounter++;
372       Visit(S->getSubStmt());
373     }
374     /// Assign a counter for the body of a while loop.
375     void VisitWhileStmt(const WhileStmt *S) {
376       (*CounterMap)[S] = NextCounter++;
377       Visit(S->getCond());
378       Visit(S->getBody());
379     }
380     /// Assign a counter for the body of a do-while loop.
381     void VisitDoStmt(const DoStmt *S) {
382       (*CounterMap)[S] = NextCounter++;
383       Visit(S->getBody());
384       Visit(S->getCond());
385     }
386     /// Assign a counter for the body of a for loop.
387     void VisitForStmt(const ForStmt *S) {
388       (*CounterMap)[S] = NextCounter++;
389       if (S->getInit())
390         Visit(S->getInit());
391       const Expr *E;
392       if ((E = S->getCond()))
393         Visit(E);
394       if ((E = S->getInc()))
395         Visit(E);
396       Visit(S->getBody());
397     }
398     /// Assign a counter for the body of a for-range loop.
399     void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
400       (*CounterMap)[S] = NextCounter++;
401       Visit(S->getRangeStmt());
402       Visit(S->getBeginEndStmt());
403       Visit(S->getCond());
404       Visit(S->getLoopVarStmt());
405       Visit(S->getBody());
406       Visit(S->getInc());
407     }
408     /// Assign a counter for the body of a for-collection loop.
409     void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
410       (*CounterMap)[S] = NextCounter++;
411       Visit(S->getElement());
412       Visit(S->getBody());
413     }
414     /// Assign a counter for the exit block of the switch statement.
415     void VisitSwitchStmt(const SwitchStmt *S) {
416       (*CounterMap)[S] = NextCounter++;
417       Visit(S->getCond());
418       Visit(S->getBody());
419     }
420     /// Assign a counter for a particular case in a switch. This counts jumps
421     /// from the switch header as well as fallthrough from the case before this
422     /// one.
423     void VisitCaseStmt(const CaseStmt *S) {
424       (*CounterMap)[S] = NextCounter++;
425       Visit(S->getSubStmt());
426     }
427     /// Assign a counter for the default case of a switch statement. The count
428     /// is the number of branches from the loop header to the default, and does
429     /// not include fallthrough from previous cases. If we have multiple
430     /// conditional branch blocks from the switch instruction to the default
431     /// block, as with large GNU case ranges, this is the counter for the last
432     /// edge in that series, rather than the first.
433     void VisitDefaultStmt(const DefaultStmt *S) {
434       (*CounterMap)[S] = NextCounter++;
435       Visit(S->getSubStmt());
436     }
437     /// Assign a counter for the "then" part of an if statement. The count for
438     /// the "else" part, if it exists, will be calculated from this counter.
439     void VisitIfStmt(const IfStmt *S) {
440       (*CounterMap)[S] = NextCounter++;
441       Visit(S->getCond());
442       Visit(S->getThen());
443       if (S->getElse())
444         Visit(S->getElse());
445     }
446     /// Assign a counter for the continuation block of a C++ try statement.
447     void VisitCXXTryStmt(const CXXTryStmt *S) {
448       (*CounterMap)[S] = NextCounter++;
449       Visit(S->getTryBlock());
450       for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
451         Visit(S->getHandler(I));
452     }
453     /// Assign a counter for a catch statement's handler block.
454     void VisitCXXCatchStmt(const CXXCatchStmt *S) {
455       (*CounterMap)[S] = NextCounter++;
456       Visit(S->getHandlerBlock());
457     }
458     /// Assign a counter for the "true" part of a conditional operator. The
459     /// count in the "false" part will be calculated from this counter.
460     void VisitConditionalOperator(const ConditionalOperator *E) {
461       (*CounterMap)[E] = NextCounter++;
462       Visit(E->getCond());
463       Visit(E->getTrueExpr());
464       Visit(E->getFalseExpr());
465     }
466     /// Assign a counter for the right hand side of a logical and operator.
467     void VisitBinLAnd(const BinaryOperator *E) {
468       (*CounterMap)[E] = NextCounter++;
469       Visit(E->getLHS());
470       Visit(E->getRHS());
471     }
472     /// Assign a counter for the right hand side of a logical or operator.
473     void VisitBinLOr(const BinaryOperator *E) {
474       (*CounterMap)[E] = NextCounter++;
475       Visit(E->getLHS());
476       Visit(E->getRHS());
477     }
478   };
479 
480   /// A StmtVisitor that propagates the raw counts through the AST and
481   /// records the count at statements where the value may change.
482   struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
483     /// PGO state.
484     CodeGenPGO &PGO;
485 
486     /// A flag that is set when the current count should be recorded on the
487     /// next statement, such as at the exit of a loop.
488     bool RecordNextStmtCount;
489 
490     /// The map of statements to count values.
491     llvm::DenseMap<const Stmt*, uint64_t> *CountMap;
492 
493     /// BreakContinueStack - Keep counts of breaks and continues inside loops.
494     struct BreakContinue {
495       uint64_t BreakCount;
496       uint64_t ContinueCount;
497       BreakContinue() : BreakCount(0), ContinueCount(0) {}
498     };
499     SmallVector<BreakContinue, 8> BreakContinueStack;
500 
501     ComputeRegionCounts(llvm::DenseMap<const Stmt*, uint64_t> *CountMap,
502                         CodeGenPGO &PGO) :
503       PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {
504     }
505 
506     void RecordStmtCount(const Stmt *S) {
507       if (RecordNextStmtCount) {
508         (*CountMap)[S] = PGO.getCurrentRegionCount();
509         RecordNextStmtCount = false;
510       }
511     }
512 
513     void VisitStmt(const Stmt *S) {
514       RecordStmtCount(S);
515       for (Stmt::const_child_range I = S->children(); I; ++I) {
516         if (*I)
517          this->Visit(*I);
518       }
519     }
520 
521     void VisitFunctionDecl(const FunctionDecl *S) {
522       RegionCounter Cnt(PGO, S->getBody());
523       Cnt.beginRegion();
524       (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount();
525       Visit(S->getBody());
526     }
527 
528     void VisitObjCMethodDecl(const ObjCMethodDecl *S) {
529       RegionCounter Cnt(PGO, S->getBody());
530       Cnt.beginRegion();
531       (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount();
532       Visit(S->getBody());
533     }
534 
535     void VisitBlockDecl(const BlockDecl *S) {
536       RegionCounter Cnt(PGO, S->getBody());
537       Cnt.beginRegion();
538       (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount();
539       Visit(S->getBody());
540     }
541 
542     void VisitReturnStmt(const ReturnStmt *S) {
543       RecordStmtCount(S);
544       if (S->getRetValue())
545         Visit(S->getRetValue());
546       PGO.setCurrentRegionUnreachable();
547       RecordNextStmtCount = true;
548     }
549 
550     void VisitGotoStmt(const GotoStmt *S) {
551       RecordStmtCount(S);
552       PGO.setCurrentRegionUnreachable();
553       RecordNextStmtCount = true;
554     }
555 
556     void VisitLabelStmt(const LabelStmt *S) {
557       RecordNextStmtCount = false;
558       RegionCounter Cnt(PGO, S);
559       Cnt.beginRegion();
560       (*CountMap)[S] = PGO.getCurrentRegionCount();
561       Visit(S->getSubStmt());
562     }
563 
564     void VisitBreakStmt(const BreakStmt *S) {
565       RecordStmtCount(S);
566       assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
567       BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount();
568       PGO.setCurrentRegionUnreachable();
569       RecordNextStmtCount = true;
570     }
571 
572     void VisitContinueStmt(const ContinueStmt *S) {
573       RecordStmtCount(S);
574       assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
575       BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount();
576       PGO.setCurrentRegionUnreachable();
577       RecordNextStmtCount = true;
578     }
579 
580     void VisitWhileStmt(const WhileStmt *S) {
581       RecordStmtCount(S);
582       RegionCounter Cnt(PGO, S);
583       BreakContinueStack.push_back(BreakContinue());
584       // Visit the body region first so the break/continue adjustments can be
585       // included when visiting the condition.
586       Cnt.beginRegion();
587       (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount();
588       Visit(S->getBody());
589       Cnt.adjustForControlFlow();
590 
591       // ...then go back and propagate counts through the condition. The count
592       // at the start of the condition is the sum of the incoming edges,
593       // the backedge from the end of the loop body, and the edges from
594       // continue statements.
595       BreakContinue BC = BreakContinueStack.pop_back_val();
596       Cnt.setCurrentRegionCount(Cnt.getParentCount() +
597                                 Cnt.getAdjustedCount() + BC.ContinueCount);
598       (*CountMap)[S->getCond()] = PGO.getCurrentRegionCount();
599       Visit(S->getCond());
600       Cnt.adjustForControlFlow();
601       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
602       RecordNextStmtCount = true;
603     }
604 
605     void VisitDoStmt(const DoStmt *S) {
606       RecordStmtCount(S);
607       RegionCounter Cnt(PGO, S);
608       BreakContinueStack.push_back(BreakContinue());
609       Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
610       (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount();
611       Visit(S->getBody());
612       Cnt.adjustForControlFlow();
613 
614       BreakContinue BC = BreakContinueStack.pop_back_val();
615       // The count at the start of the condition is equal to the count at the
616       // end of the body. The adjusted count does not include either the
617       // fall-through count coming into the loop or the continue count, so add
618       // both of those separately. This is coincidentally the same equation as
619       // with while loops but for different reasons.
620       Cnt.setCurrentRegionCount(Cnt.getParentCount() +
621                                 Cnt.getAdjustedCount() + BC.ContinueCount);
622       (*CountMap)[S->getCond()] = PGO.getCurrentRegionCount();
623       Visit(S->getCond());
624       Cnt.adjustForControlFlow();
625       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
626       RecordNextStmtCount = true;
627     }
628 
629     void VisitForStmt(const ForStmt *S) {
630       RecordStmtCount(S);
631       if (S->getInit())
632         Visit(S->getInit());
633       RegionCounter Cnt(PGO, S);
634       BreakContinueStack.push_back(BreakContinue());
635       // Visit the body region first. (This is basically the same as a while
636       // loop; see further comments in VisitWhileStmt.)
637       Cnt.beginRegion();
638       (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount();
639       Visit(S->getBody());
640       Cnt.adjustForControlFlow();
641 
642       // The increment is essentially part of the body but it needs to include
643       // the count for all the continue statements.
644       if (S->getInc()) {
645         Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
646                                   BreakContinueStack.back().ContinueCount);
647         (*CountMap)[S->getInc()] = PGO.getCurrentRegionCount();
648         Visit(S->getInc());
649         Cnt.adjustForControlFlow();
650       }
651 
652       BreakContinue BC = BreakContinueStack.pop_back_val();
653 
654       // ...then go back and propagate counts through the condition.
655       if (S->getCond()) {
656         Cnt.setCurrentRegionCount(Cnt.getParentCount() +
657                                   Cnt.getAdjustedCount() +
658                                   BC.ContinueCount);
659         (*CountMap)[S->getCond()] = PGO.getCurrentRegionCount();
660         Visit(S->getCond());
661         Cnt.adjustForControlFlow();
662       }
663       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
664       RecordNextStmtCount = true;
665     }
666 
667     void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
668       RecordStmtCount(S);
669       Visit(S->getRangeStmt());
670       Visit(S->getBeginEndStmt());
671       RegionCounter Cnt(PGO, S);
672       BreakContinueStack.push_back(BreakContinue());
673       // Visit the body region first. (This is basically the same as a while
674       // loop; see further comments in VisitWhileStmt.)
675       Cnt.beginRegion();
676       (*CountMap)[S->getLoopVarStmt()] = PGO.getCurrentRegionCount();
677       Visit(S->getLoopVarStmt());
678       Visit(S->getBody());
679       Cnt.adjustForControlFlow();
680 
681       // The increment is essentially part of the body but it needs to include
682       // the count for all the continue statements.
683       Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
684                                 BreakContinueStack.back().ContinueCount);
685       (*CountMap)[S->getInc()] = PGO.getCurrentRegionCount();
686       Visit(S->getInc());
687       Cnt.adjustForControlFlow();
688 
689       BreakContinue BC = BreakContinueStack.pop_back_val();
690 
691       // ...then go back and propagate counts through the condition.
692       Cnt.setCurrentRegionCount(Cnt.getParentCount() +
693                                 Cnt.getAdjustedCount() +
694                                 BC.ContinueCount);
695       (*CountMap)[S->getCond()] = PGO.getCurrentRegionCount();
696       Visit(S->getCond());
697       Cnt.adjustForControlFlow();
698       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
699       RecordNextStmtCount = true;
700     }
701 
702     void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
703       RecordStmtCount(S);
704       Visit(S->getElement());
705       RegionCounter Cnt(PGO, S);
706       BreakContinueStack.push_back(BreakContinue());
707       Cnt.beginRegion();
708       (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount();
709       Visit(S->getBody());
710       BreakContinue BC = BreakContinueStack.pop_back_val();
711       Cnt.adjustForControlFlow();
712       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
713       RecordNextStmtCount = true;
714     }
715 
716     void VisitSwitchStmt(const SwitchStmt *S) {
717       RecordStmtCount(S);
718       Visit(S->getCond());
719       PGO.setCurrentRegionUnreachable();
720       BreakContinueStack.push_back(BreakContinue());
721       Visit(S->getBody());
722       // If the switch is inside a loop, add the continue counts.
723       BreakContinue BC = BreakContinueStack.pop_back_val();
724       if (!BreakContinueStack.empty())
725         BreakContinueStack.back().ContinueCount += BC.ContinueCount;
726       RegionCounter ExitCnt(PGO, S);
727       ExitCnt.beginRegion();
728       RecordNextStmtCount = true;
729     }
730 
731     void VisitCaseStmt(const CaseStmt *S) {
732       RecordNextStmtCount = false;
733       RegionCounter Cnt(PGO, S);
734       Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
735       (*CountMap)[S] = Cnt.getCount();
736       RecordNextStmtCount = true;
737       Visit(S->getSubStmt());
738     }
739 
740     void VisitDefaultStmt(const DefaultStmt *S) {
741       RecordNextStmtCount = false;
742       RegionCounter Cnt(PGO, S);
743       Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
744       (*CountMap)[S] = Cnt.getCount();
745       RecordNextStmtCount = true;
746       Visit(S->getSubStmt());
747     }
748 
749     void VisitIfStmt(const IfStmt *S) {
750       RecordStmtCount(S);
751       RegionCounter Cnt(PGO, S);
752       Visit(S->getCond());
753 
754       Cnt.beginRegion();
755       (*CountMap)[S->getThen()] = PGO.getCurrentRegionCount();
756       Visit(S->getThen());
757       Cnt.adjustForControlFlow();
758 
759       if (S->getElse()) {
760         Cnt.beginElseRegion();
761         (*CountMap)[S->getElse()] = PGO.getCurrentRegionCount();
762         Visit(S->getElse());
763         Cnt.adjustForControlFlow();
764       }
765       Cnt.applyAdjustmentsToRegion(0);
766       RecordNextStmtCount = true;
767     }
768 
769     void VisitCXXTryStmt(const CXXTryStmt *S) {
770       RecordStmtCount(S);
771       Visit(S->getTryBlock());
772       for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
773         Visit(S->getHandler(I));
774       RegionCounter Cnt(PGO, S);
775       Cnt.beginRegion();
776       RecordNextStmtCount = true;
777     }
778 
779     void VisitCXXCatchStmt(const CXXCatchStmt *S) {
780       RecordNextStmtCount = false;
781       RegionCounter Cnt(PGO, S);
782       Cnt.beginRegion();
783       (*CountMap)[S] = PGO.getCurrentRegionCount();
784       Visit(S->getHandlerBlock());
785     }
786 
787     void VisitConditionalOperator(const ConditionalOperator *E) {
788       RecordStmtCount(E);
789       RegionCounter Cnt(PGO, E);
790       Visit(E->getCond());
791 
792       Cnt.beginRegion();
793       (*CountMap)[E->getTrueExpr()] = PGO.getCurrentRegionCount();
794       Visit(E->getTrueExpr());
795       Cnt.adjustForControlFlow();
796 
797       Cnt.beginElseRegion();
798       (*CountMap)[E->getFalseExpr()] = PGO.getCurrentRegionCount();
799       Visit(E->getFalseExpr());
800       Cnt.adjustForControlFlow();
801 
802       Cnt.applyAdjustmentsToRegion(0);
803       RecordNextStmtCount = true;
804     }
805 
806     void VisitBinLAnd(const BinaryOperator *E) {
807       RecordStmtCount(E);
808       RegionCounter Cnt(PGO, E);
809       Visit(E->getLHS());
810       Cnt.beginRegion();
811       (*CountMap)[E->getRHS()] = PGO.getCurrentRegionCount();
812       Visit(E->getRHS());
813       Cnt.adjustForControlFlow();
814       Cnt.applyAdjustmentsToRegion(0);
815       RecordNextStmtCount = true;
816     }
817 
818     void VisitBinLOr(const BinaryOperator *E) {
819       RecordStmtCount(E);
820       RegionCounter Cnt(PGO, E);
821       Visit(E->getLHS());
822       Cnt.beginRegion();
823       (*CountMap)[E->getRHS()] = PGO.getCurrentRegionCount();
824       Visit(E->getRHS());
825       Cnt.adjustForControlFlow();
826       Cnt.applyAdjustmentsToRegion(0);
827       RecordNextStmtCount = true;
828     }
829   };
830 }
831 
832 void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) {
833   bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
834   PGOProfileData *PGOData = CGM.getPGOData();
835   if (!InstrumentRegions && !PGOData)
836     return;
837   if (!D)
838     return;
839   setFuncName(Fn);
840   FuncLinkage = Fn->getLinkage();
841   mapRegionCounters(D);
842   if (InstrumentRegions)
843     emitCounterVariables();
844   if (PGOData) {
845     loadRegionCounts(PGOData);
846     computeRegionCounts(D);
847     applyFunctionAttributes(PGOData, Fn);
848   }
849 }
850 
851 void CodeGenPGO::mapRegionCounters(const Decl *D) {
852   RegionCounterMap = new llvm::DenseMap<const Stmt*, unsigned>();
853   MapRegionCounters Walker(RegionCounterMap);
854   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
855     Walker.VisitFunctionDecl(FD);
856   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
857     Walker.VisitObjCMethodDecl(MD);
858   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
859     Walker.VisitBlockDecl(BD);
860   NumRegionCounters = Walker.NextCounter;
861   // FIXME: The number of counters isn't sufficient for the hash
862   FunctionHash = NumRegionCounters;
863 }
864 
865 void CodeGenPGO::computeRegionCounts(const Decl *D) {
866   StmtCountMap = new llvm::DenseMap<const Stmt*, uint64_t>();
867   ComputeRegionCounts Walker(StmtCountMap, *this);
868   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
869     Walker.VisitFunctionDecl(FD);
870   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
871     Walker.VisitObjCMethodDecl(MD);
872   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
873     Walker.VisitBlockDecl(BD);
874 }
875 
876 void CodeGenPGO::applyFunctionAttributes(PGOProfileData *PGOData,
877                                          llvm::Function *Fn) {
878   if (!haveRegionCounts())
879     return;
880 
881   uint64_t MaxFunctionCount = PGOData->getMaximumFunctionCount();
882   uint64_t FunctionCount = getRegionCount(0);
883   if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount))
884     // Turn on InlineHint attribute for hot functions.
885     // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal.
886     Fn->addFnAttr(llvm::Attribute::InlineHint);
887   else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount))
888     // Turn on Cold attribute for cold functions.
889     // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal.
890     Fn->addFnAttr(llvm::Attribute::Cold);
891 }
892 
893 void CodeGenPGO::emitCounterVariables() {
894   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
895   llvm::ArrayType *CounterTy = llvm::ArrayType::get(llvm::Type::getInt64Ty(Ctx),
896                                                     NumRegionCounters);
897   RegionCounters =
898     new llvm::GlobalVariable(CGM.getModule(), CounterTy, false, FuncLinkage,
899                              llvm::Constant::getNullValue(CounterTy),
900                              getFuncVarName("counters"));
901   RegionCounters->setAlignment(8);
902   RegionCounters->setSection(getCountersSection(CGM));
903 }
904 
905 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) {
906   if (!RegionCounters)
907     return;
908   llvm::Value *Addr =
909     Builder.CreateConstInBoundsGEP2_64(RegionCounters, 0, Counter);
910   llvm::Value *Count = Builder.CreateLoad(Addr, "pgocount");
911   Count = Builder.CreateAdd(Count, Builder.getInt64(1));
912   Builder.CreateStore(Count, Addr);
913 }
914 
915 void CodeGenPGO::loadRegionCounts(PGOProfileData *PGOData) {
916   // For now, ignore the counts from the PGO data file only if the number of
917   // counters does not match. This could be tightened down in the future to
918   // ignore counts when the input changes in various ways, e.g., by comparing a
919   // hash value based on some characteristics of the input.
920   RegionCounts = new std::vector<uint64_t>();
921   uint64_t Hash;
922   if (PGOData->getFunctionCounts(getFuncName(), Hash, *RegionCounts) ||
923       Hash != FunctionHash || RegionCounts->size() != NumRegionCounters) {
924     delete RegionCounts;
925     RegionCounts = 0;
926   }
927 }
928 
929 void CodeGenPGO::destroyRegionCounters() {
930   if (RegionCounterMap != 0)
931     delete RegionCounterMap;
932   if (StmtCountMap != 0)
933     delete StmtCountMap;
934   if (RegionCounts != 0)
935     delete RegionCounts;
936 }
937 
/// \brief Compute the divisor used to scale weights down to 32 bits.
///
/// Returns 1 when \p MaxWeight is already below UINT32_MAX. Otherwise returns
/// a divisor large enough that every weight up to \p MaxWeight, once divided
/// by it, is strictly less than UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
945 
/// \brief Scale a single branch weight down to 32 bits, then add 1.
///
/// Divides the 64-bit \p Weight by \p Scale and adds 1 to the quotient.
///
/// The 1 is added universally because, by Laplace's Rule of Succession, it
/// is better to base the weight on the count plus 1.
///
/// \pre \c Scale is non-zero and was calculated by \a calculateWeightScale()
/// from a maximum weight that is no less than \c Weight, so that the result
/// fits in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Result = Weight / Scale + 1;
  assert(Result <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Result);
}
961 
962 llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount,
963                                               uint64_t FalseCount) {
964   // Check for empty weights.
965   if (!TrueCount && !FalseCount)
966     return 0;
967 
968   // Calculate how to scale down to 32-bits.
969   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
970 
971   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
972   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
973                                       scaleBranchWeight(FalseCount, Scale));
974 }
975 
976 llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) {
977   // We need at least two elements to create meaningful weights.
978   if (Weights.size() < 2)
979     return 0;
980 
981   // Calculate how to scale down to 32-bits.
982   uint64_t Scale = calculateWeightScale(*std::max_element(Weights.begin(),
983                                                           Weights.end()));
984 
985   SmallVector<uint32_t, 16> ScaledWeights;
986   ScaledWeights.reserve(Weights.size());
987   for (uint64_t W : Weights)
988     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
989 
990   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
991   return MDHelper.createBranchWeights(ScaledWeights);
992 }
993 
994 llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond,
995                                             RegionCounter &Cnt) {
996   if (!haveRegionCounts())
997     return 0;
998   uint64_t LoopCount = Cnt.getCount();
999   uint64_t CondCount = 0;
1000   bool Found = getStmtCount(Cond, CondCount);
1001   assert(Found && "missing expected loop condition count");
1002   (void)Found;
1003   if (CondCount == 0)
1004     return 0;
1005   return createBranchWeights(LoopCount,
1006                              std::max(CondCount, LoopCount) - LoopCount);
1007 }
1008