1ef512b99SJustin Bogner //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2ef512b99SJustin Bogner //
3ef512b99SJustin Bogner //                     The LLVM Compiler Infrastructure
4ef512b99SJustin Bogner //
5ef512b99SJustin Bogner // This file is distributed under the University of Illinois Open Source
6ef512b99SJustin Bogner // License. See LICENSE.TXT for details.
7ef512b99SJustin Bogner //
8ef512b99SJustin Bogner //===----------------------------------------------------------------------===//
9ef512b99SJustin Bogner //
10ef512b99SJustin Bogner // Instrumentation-based profile-guided optimization
11ef512b99SJustin Bogner //
12ef512b99SJustin Bogner //===----------------------------------------------------------------------===//
13ef512b99SJustin Bogner 
14ef512b99SJustin Bogner #include "CodeGenPGO.h"
15ef512b99SJustin Bogner #include "CodeGenFunction.h"
16ef512b99SJustin Bogner #include "clang/AST/RecursiveASTVisitor.h"
17ef512b99SJustin Bogner #include "clang/AST/StmtVisitor.h"
18529f6dd8SJustin Bogner #include "llvm/Config/config.h" // for strtoull()/strtoll() define
19ef512b99SJustin Bogner #include "llvm/IR/MDBuilder.h"
20ef512b99SJustin Bogner #include "llvm/Support/FileSystem.h"
21ef512b99SJustin Bogner 
22ef512b99SJustin Bogner using namespace clang;
23ef512b99SJustin Bogner using namespace CodeGen;
24ef512b99SJustin Bogner 
25ef512b99SJustin Bogner static void ReportBadPGOData(CodeGenModule &CGM, const char *Message) {
26ef512b99SJustin Bogner   DiagnosticsEngine &Diags = CGM.getDiags();
2729cb66baSAlp Toker   unsigned diagID = Diags.getCustomDiagID(DiagnosticsEngine::Error, "%0");
2829cb66baSAlp Toker   Diags.Report(diagID) << Message;
29ef512b99SJustin Bogner }
30ef512b99SJustin Bogner 
31ef512b99SJustin Bogner PGOProfileData::PGOProfileData(CodeGenModule &CGM, std::string Path)
32ef512b99SJustin Bogner   : CGM(CGM) {
33ef512b99SJustin Bogner   if (llvm::MemoryBuffer::getFile(Path, DataBuffer)) {
34ef512b99SJustin Bogner     ReportBadPGOData(CGM, "failed to open pgo data file");
35ef512b99SJustin Bogner     return;
36ef512b99SJustin Bogner   }
37ef512b99SJustin Bogner 
38ef512b99SJustin Bogner   if (DataBuffer->getBufferSize() > std::numeric_limits<unsigned>::max()) {
39ef512b99SJustin Bogner     ReportBadPGOData(CGM, "pgo data file too big");
40ef512b99SJustin Bogner     return;
41ef512b99SJustin Bogner   }
42ef512b99SJustin Bogner 
43ef512b99SJustin Bogner   // Scan through the data file and map each function to the corresponding
44ef512b99SJustin Bogner   // file offset where its counts are stored.
45ef512b99SJustin Bogner   const char *BufferStart = DataBuffer->getBufferStart();
46ef512b99SJustin Bogner   const char *BufferEnd = DataBuffer->getBufferEnd();
47ef512b99SJustin Bogner   const char *CurPtr = BufferStart;
4867a28136SManman Ren   uint64_t MaxCount = 0;
49ef512b99SJustin Bogner   while (CurPtr < BufferEnd) {
50d0b7824eSBob Wilson     // Read the function name.
51d0b7824eSBob Wilson     const char *FuncStart = CurPtr;
525ec8fe19SBob Wilson     // For Objective-C methods, the name may include whitespace, so search
535ec8fe19SBob Wilson     // backward from the end of the line to find the space that separates the
545ec8fe19SBob Wilson     // name from the number of counters. (This is a temporary hack since we are
555ec8fe19SBob Wilson     // going to completely replace this file format in the near future.)
565ec8fe19SBob Wilson     CurPtr = strchr(CurPtr, '\n');
57ef512b99SJustin Bogner     if (!CurPtr) {
58ef512b99SJustin Bogner       ReportBadPGOData(CGM, "pgo data file has malformed function entry");
59ef512b99SJustin Bogner       return;
60ef512b99SJustin Bogner     }
615ec8fe19SBob Wilson     while (*--CurPtr != ' ')
625ec8fe19SBob Wilson       ;
63d0b7824eSBob Wilson     StringRef FuncName(FuncStart, CurPtr - FuncStart);
64ef512b99SJustin Bogner 
65ef512b99SJustin Bogner     // Read the number of counters.
66ef512b99SJustin Bogner     char *EndPtr;
67ef512b99SJustin Bogner     unsigned NumCounters = strtol(++CurPtr, &EndPtr, 10);
68ef512b99SJustin Bogner     if (EndPtr == CurPtr || *EndPtr != '\n' || NumCounters <= 0) {
69ef512b99SJustin Bogner       ReportBadPGOData(CGM, "pgo data file has unexpected number of counters");
70ef512b99SJustin Bogner       return;
71ef512b99SJustin Bogner     }
72ef512b99SJustin Bogner     CurPtr = EndPtr;
73ef512b99SJustin Bogner 
7467a28136SManman Ren     // Read function count.
7567a28136SManman Ren     uint64_t Count = strtoll(CurPtr, &EndPtr, 10);
7667a28136SManman Ren     if (EndPtr == CurPtr || *EndPtr != '\n') {
7767a28136SManman Ren       ReportBadPGOData(CGM, "pgo-data file has bad count value");
7867a28136SManman Ren       return;
7967a28136SManman Ren     }
80f1a6a2d9SManman Ren     CurPtr = EndPtr; // Point to '\n'.
81d0b7824eSBob Wilson     FunctionCounts[FuncName] = Count;
8267a28136SManman Ren     MaxCount = Count > MaxCount ? Count : MaxCount;
8367a28136SManman Ren 
84ef512b99SJustin Bogner     // There is one line for each counter; skip over those lines.
8567a28136SManman Ren     // Since function count is already read, we start the loop from 1.
8667a28136SManman Ren     for (unsigned N = 1; N < NumCounters; ++N) {
87ef512b99SJustin Bogner       CurPtr = strchr(++CurPtr, '\n');
88ef512b99SJustin Bogner       if (!CurPtr) {
89ef512b99SJustin Bogner         ReportBadPGOData(CGM, "pgo data file is missing some counter info");
90ef512b99SJustin Bogner         return;
91ef512b99SJustin Bogner       }
92ef512b99SJustin Bogner     }
93ef512b99SJustin Bogner 
94ef512b99SJustin Bogner     // Skip over the blank line separating functions.
95ef512b99SJustin Bogner     CurPtr += 2;
96ef512b99SJustin Bogner 
97d0b7824eSBob Wilson     DataOffsets[FuncName] = FuncStart - BufferStart;
98ef512b99SJustin Bogner   }
9967a28136SManman Ren   MaxFunctionCount = MaxCount;
10067a28136SManman Ren }
10167a28136SManman Ren 
10267a28136SManman Ren /// Return true if a function is hot. If we know nothing about the function,
10367a28136SManman Ren /// return false.
104d0b7824eSBob Wilson bool PGOProfileData::isHotFunction(StringRef FuncName) {
10567a28136SManman Ren   llvm::StringMap<uint64_t>::const_iterator CountIter =
106d0b7824eSBob Wilson     FunctionCounts.find(FuncName);
10767a28136SManman Ren   // If we know nothing about the function, return false.
10867a28136SManman Ren   if (CountIter == FunctionCounts.end())
10967a28136SManman Ren     return false;
11067a28136SManman Ren   // FIXME: functions with >= 30% of the maximal function count are
11167a28136SManman Ren   // treated as hot. This number is from preliminary tuning on SPEC.
11267a28136SManman Ren   return CountIter->getValue() >= (uint64_t)(0.3 * (double)MaxFunctionCount);
11367a28136SManman Ren }
11467a28136SManman Ren 
11567a28136SManman Ren /// Return true if a function is cold. If we know nothing about the function,
11667a28136SManman Ren /// return false.
117d0b7824eSBob Wilson bool PGOProfileData::isColdFunction(StringRef FuncName) {
11867a28136SManman Ren   llvm::StringMap<uint64_t>::const_iterator CountIter =
119d0b7824eSBob Wilson     FunctionCounts.find(FuncName);
12067a28136SManman Ren   // If we know nothing about the function, return false.
12167a28136SManman Ren   if (CountIter == FunctionCounts.end())
12267a28136SManman Ren     return false;
12367a28136SManman Ren   // FIXME: functions with <= 1% of the maximal function count are treated as
12467a28136SManman Ren   // cold. This number is from preliminary tuning on SPEC.
12567a28136SManman Ren   return CountIter->getValue() <= (uint64_t)(0.01 * (double)MaxFunctionCount);
126ef512b99SJustin Bogner }
127ef512b99SJustin Bogner 
128d0b7824eSBob Wilson bool PGOProfileData::getFunctionCounts(StringRef FuncName,
129ef512b99SJustin Bogner                                        std::vector<uint64_t> &Counts) {
130ef512b99SJustin Bogner   // Find the relevant section of the pgo-data file.
131ef512b99SJustin Bogner   llvm::StringMap<unsigned>::const_iterator OffsetIter =
132d0b7824eSBob Wilson     DataOffsets.find(FuncName);
133ef512b99SJustin Bogner   if (OffsetIter == DataOffsets.end())
134ef512b99SJustin Bogner     return true;
135ef512b99SJustin Bogner   const char *CurPtr = DataBuffer->getBufferStart() + OffsetIter->getValue();
136ef512b99SJustin Bogner 
137ef512b99SJustin Bogner   // Skip over the function name.
1385ec8fe19SBob Wilson   CurPtr = strchr(CurPtr, '\n');
139ef512b99SJustin Bogner   assert(CurPtr && "pgo-data has corrupted function entry");
1405ec8fe19SBob Wilson   while (*--CurPtr != ' ')
1415ec8fe19SBob Wilson     ;
142ef512b99SJustin Bogner 
143ef512b99SJustin Bogner   // Read the number of counters.
144ef512b99SJustin Bogner   char *EndPtr;
145ef512b99SJustin Bogner   unsigned NumCounters = strtol(++CurPtr, &EndPtr, 10);
146ef512b99SJustin Bogner   assert(EndPtr != CurPtr && *EndPtr == '\n' && NumCounters > 0 &&
147ef512b99SJustin Bogner          "pgo-data file has corrupted number of counters");
148ef512b99SJustin Bogner   CurPtr = EndPtr;
149ef512b99SJustin Bogner 
150ef512b99SJustin Bogner   Counts.reserve(NumCounters);
151ef512b99SJustin Bogner 
152ef512b99SJustin Bogner   for (unsigned N = 0; N < NumCounters; ++N) {
153ef512b99SJustin Bogner     // Read the count value.
154ef512b99SJustin Bogner     uint64_t Count = strtoll(CurPtr, &EndPtr, 10);
155ef512b99SJustin Bogner     if (EndPtr == CurPtr || *EndPtr != '\n') {
156ef512b99SJustin Bogner       ReportBadPGOData(CGM, "pgo-data file has bad count value");
157ef512b99SJustin Bogner       return true;
158ef512b99SJustin Bogner     }
159ef512b99SJustin Bogner     Counts.push_back(Count);
160ef512b99SJustin Bogner     CurPtr = EndPtr + 1;
161ef512b99SJustin Bogner   }
162ef512b99SJustin Bogner 
163ef512b99SJustin Bogner   // Make sure the number of counters matches up.
164ef512b99SJustin Bogner   if (Counts.size() != NumCounters) {
165ef512b99SJustin Bogner     ReportBadPGOData(CGM, "pgo-data file has inconsistent counters");
166ef512b99SJustin Bogner     return true;
167ef512b99SJustin Bogner   }
168ef512b99SJustin Bogner 
169ef512b99SJustin Bogner   return false;
170ef512b99SJustin Bogner }
171ef512b99SJustin Bogner 
172da1ebedeSBob Wilson void CodeGenPGO::setFuncName(llvm::Function *Fn) {
173da1ebedeSBob Wilson   StringRef Func = Fn->getName();
174da1ebedeSBob Wilson 
175da1ebedeSBob Wilson   // Function names may be prefixed with a binary '1' to indicate
176da1ebedeSBob Wilson   // that the backend should not modify the symbols due to any platform
177da1ebedeSBob Wilson   // naming convention. Do not include that '1' in the PGO profile name.
178da1ebedeSBob Wilson   if (Func[0] == '\1')
179da1ebedeSBob Wilson     Func = Func.substr(1);
180da1ebedeSBob Wilson 
181da1ebedeSBob Wilson   if (!Fn->hasLocalLinkage()) {
182da1ebedeSBob Wilson     FuncName = new std::string(Func);
183da1ebedeSBob Wilson     return;
184da1ebedeSBob Wilson   }
185da1ebedeSBob Wilson 
186da1ebedeSBob Wilson   // For local symbols, prepend the main file name to distinguish them.
187da1ebedeSBob Wilson   // Do not include the full path in the file name since there's no guarantee
188da1ebedeSBob Wilson   // that it will stay the same, e.g., if the files are checked out from
189da1ebedeSBob Wilson   // version control in different locations.
190da1ebedeSBob Wilson   FuncName = new std::string(CGM.getCodeGenOpts().MainFileName);
191da1ebedeSBob Wilson   if (FuncName->empty())
192da1ebedeSBob Wilson     FuncName->assign("<unknown>");
193da1ebedeSBob Wilson   FuncName->append(":");
194da1ebedeSBob Wilson   FuncName->append(Func);
195da1ebedeSBob Wilson }
196da1ebedeSBob Wilson 
197da1ebedeSBob Wilson void CodeGenPGO::emitWriteoutFunction() {
198ef512b99SJustin Bogner   if (!CGM.getCodeGenOpts().ProfileInstrGenerate)
199ef512b99SJustin Bogner     return;
200ef512b99SJustin Bogner 
201ef512b99SJustin Bogner   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
202ef512b99SJustin Bogner 
203ef512b99SJustin Bogner   llvm::Type *Int32Ty = llvm::Type::getInt32Ty(Ctx);
204ef512b99SJustin Bogner   llvm::Type *Int8PtrTy = llvm::Type::getInt8PtrTy(Ctx);
205ef512b99SJustin Bogner 
206ef512b99SJustin Bogner   llvm::Function *WriteoutF =
207ef512b99SJustin Bogner     CGM.getModule().getFunction("__llvm_pgo_writeout");
208ef512b99SJustin Bogner   if (!WriteoutF) {
209ef512b99SJustin Bogner     llvm::FunctionType *WriteoutFTy =
210ef512b99SJustin Bogner       llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
211ef512b99SJustin Bogner     WriteoutF = llvm::Function::Create(WriteoutFTy,
212ef512b99SJustin Bogner                                        llvm::GlobalValue::InternalLinkage,
213ef512b99SJustin Bogner                                        "__llvm_pgo_writeout", &CGM.getModule());
214ef512b99SJustin Bogner   }
215ef512b99SJustin Bogner   WriteoutF->setUnnamedAddr(true);
216ef512b99SJustin Bogner   WriteoutF->addFnAttr(llvm::Attribute::NoInline);
217ef512b99SJustin Bogner   if (CGM.getCodeGenOpts().DisableRedZone)
218ef512b99SJustin Bogner     WriteoutF->addFnAttr(llvm::Attribute::NoRedZone);
219ef512b99SJustin Bogner 
220ef512b99SJustin Bogner   llvm::BasicBlock *BB = WriteoutF->empty() ?
221ef512b99SJustin Bogner     llvm::BasicBlock::Create(Ctx, "", WriteoutF) : &WriteoutF->getEntryBlock();
222ef512b99SJustin Bogner 
223ef512b99SJustin Bogner   CGBuilderTy PGOBuilder(BB);
224ef512b99SJustin Bogner 
225ef512b99SJustin Bogner   llvm::Instruction *I = BB->getTerminator();
226ef512b99SJustin Bogner   if (!I)
227ef512b99SJustin Bogner     I = PGOBuilder.CreateRetVoid();
228ef512b99SJustin Bogner   PGOBuilder.SetInsertPoint(I);
229ef512b99SJustin Bogner 
230ef512b99SJustin Bogner   llvm::Type *Int64PtrTy = llvm::Type::getInt64PtrTy(Ctx);
231ef512b99SJustin Bogner   llvm::Type *Args[] = {
232d0b7824eSBob Wilson     Int8PtrTy,                       // const char *FuncName
233ef512b99SJustin Bogner     Int32Ty,                         // uint32_t NumCounters
234ef512b99SJustin Bogner     Int64PtrTy                       // uint64_t *Counters
235ef512b99SJustin Bogner   };
236ef512b99SJustin Bogner   llvm::FunctionType *FTy =
237ef512b99SJustin Bogner     llvm::FunctionType::get(PGOBuilder.getVoidTy(), Args, false);
238ef512b99SJustin Bogner   llvm::Constant *EmitFunc =
239ef512b99SJustin Bogner     CGM.getModule().getOrInsertFunction("llvm_pgo_emit", FTy);
240ef512b99SJustin Bogner 
241d0b7824eSBob Wilson   llvm::Constant *NameString =
242da1ebedeSBob Wilson     CGM.GetAddrOfConstantCString(getFuncName(), "__llvm_pgo_name");
243d0b7824eSBob Wilson   NameString = llvm::ConstantExpr::getBitCast(NameString, Int8PtrTy);
244d0b7824eSBob Wilson   PGOBuilder.CreateCall3(EmitFunc, NameString,
245ef512b99SJustin Bogner                          PGOBuilder.getInt32(NumRegionCounters),
246ef512b99SJustin Bogner                          PGOBuilder.CreateBitCast(RegionCounters, Int64PtrTy));
247ef512b99SJustin Bogner }
248ef512b99SJustin Bogner 
249ef512b99SJustin Bogner llvm::Function *CodeGenPGO::emitInitialization(CodeGenModule &CGM) {
250ef512b99SJustin Bogner   llvm::Function *WriteoutF =
251ef512b99SJustin Bogner     CGM.getModule().getFunction("__llvm_pgo_writeout");
252ef512b99SJustin Bogner   if (!WriteoutF)
253ef512b99SJustin Bogner     return NULL;
254ef512b99SJustin Bogner 
255ef512b99SJustin Bogner   // Create a small bit of code that registers the "__llvm_pgo_writeout" to
256ef512b99SJustin Bogner   // be executed at exit.
257ef512b99SJustin Bogner   llvm::Function *F = CGM.getModule().getFunction("__llvm_pgo_init");
258ef512b99SJustin Bogner   if (F)
259ef512b99SJustin Bogner     return NULL;
260ef512b99SJustin Bogner 
261ef512b99SJustin Bogner   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
262ef512b99SJustin Bogner   llvm::FunctionType *FTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx),
263ef512b99SJustin Bogner                                                     false);
264ef512b99SJustin Bogner   F = llvm::Function::Create(FTy, llvm::GlobalValue::InternalLinkage,
265ef512b99SJustin Bogner                              "__llvm_pgo_init", &CGM.getModule());
266ef512b99SJustin Bogner   F->setUnnamedAddr(true);
267ef512b99SJustin Bogner   F->setLinkage(llvm::GlobalValue::InternalLinkage);
268ef512b99SJustin Bogner   F->addFnAttr(llvm::Attribute::NoInline);
269ef512b99SJustin Bogner   if (CGM.getCodeGenOpts().DisableRedZone)
270ef512b99SJustin Bogner     F->addFnAttr(llvm::Attribute::NoRedZone);
271ef512b99SJustin Bogner 
272ef512b99SJustin Bogner   llvm::BasicBlock *BB = llvm::BasicBlock::Create(CGM.getLLVMContext(), "", F);
273ef512b99SJustin Bogner   CGBuilderTy PGOBuilder(BB);
274ef512b99SJustin Bogner 
275ef512b99SJustin Bogner   FTy = llvm::FunctionType::get(PGOBuilder.getVoidTy(), false);
276ef512b99SJustin Bogner   llvm::Type *Params[] = {
277ef512b99SJustin Bogner     llvm::PointerType::get(FTy, 0)
278ef512b99SJustin Bogner   };
279ef512b99SJustin Bogner   FTy = llvm::FunctionType::get(PGOBuilder.getVoidTy(), Params, false);
280ef512b99SJustin Bogner 
281ef512b99SJustin Bogner   // Inialize the environment and register the local writeout function.
282ef512b99SJustin Bogner   llvm::Constant *PGOInit =
283ef512b99SJustin Bogner     CGM.getModule().getOrInsertFunction("llvm_pgo_init", FTy);
284ef512b99SJustin Bogner   PGOBuilder.CreateCall(PGOInit, WriteoutF);
285ef512b99SJustin Bogner   PGOBuilder.CreateRetVoid();
286ef512b99SJustin Bogner 
287ef512b99SJustin Bogner   return F;
288ef512b99SJustin Bogner }
289ef512b99SJustin Bogner 
290ef512b99SJustin Bogner namespace {
291ef512b99SJustin Bogner   /// A StmtVisitor that fills a map of statements to PGO counters.
292ef512b99SJustin Bogner   struct MapRegionCounters : public ConstStmtVisitor<MapRegionCounters> {
293ef512b99SJustin Bogner     /// The next counter value to assign.
294ef512b99SJustin Bogner     unsigned NextCounter;
295ef512b99SJustin Bogner     /// The map of statements to counters.
296ef512b99SJustin Bogner     llvm::DenseMap<const Stmt*, unsigned> *CounterMap;
297ef512b99SJustin Bogner 
298ef512b99SJustin Bogner     MapRegionCounters(llvm::DenseMap<const Stmt*, unsigned> *CounterMap) :
299ef512b99SJustin Bogner       NextCounter(0), CounterMap(CounterMap) {
300ef512b99SJustin Bogner     }
301ef512b99SJustin Bogner 
302ef512b99SJustin Bogner     void VisitChildren(const Stmt *S) {
303ef512b99SJustin Bogner       for (Stmt::const_child_range I = S->children(); I; ++I)
304ef512b99SJustin Bogner         if (*I)
305ef512b99SJustin Bogner          this->Visit(*I);
306ef512b99SJustin Bogner     }
307ef512b99SJustin Bogner     void VisitStmt(const Stmt *S) { VisitChildren(S); }
308ef512b99SJustin Bogner 
309ea278c32SJustin Bogner     /// Assign a counter to track entry to the function body.
310ef512b99SJustin Bogner     void VisitFunctionDecl(const FunctionDecl *S) {
311ef512b99SJustin Bogner       (*CounterMap)[S->getBody()] = NextCounter++;
312ef512b99SJustin Bogner       Visit(S->getBody());
313ef512b99SJustin Bogner     }
3145ec8fe19SBob Wilson     void VisitObjCMethodDecl(const ObjCMethodDecl *S) {
3155ec8fe19SBob Wilson       (*CounterMap)[S->getBody()] = NextCounter++;
3165ec8fe19SBob Wilson       Visit(S->getBody());
3175ec8fe19SBob Wilson     }
318*c845c00aSBob Wilson     void VisitBlockDecl(const BlockDecl *S) {
319*c845c00aSBob Wilson       (*CounterMap)[S->getBody()] = NextCounter++;
320*c845c00aSBob Wilson       Visit(S->getBody());
321*c845c00aSBob Wilson     }
322ea278c32SJustin Bogner     /// Assign a counter to track the block following a label.
323ef512b99SJustin Bogner     void VisitLabelStmt(const LabelStmt *S) {
324ef512b99SJustin Bogner       (*CounterMap)[S] = NextCounter++;
325ef512b99SJustin Bogner       Visit(S->getSubStmt());
326ef512b99SJustin Bogner     }
327bf854f0fSBob Wilson     /// Assign a counter for the body of a while loop.
328ef512b99SJustin Bogner     void VisitWhileStmt(const WhileStmt *S) {
329bf854f0fSBob Wilson       (*CounterMap)[S] = NextCounter++;
330ef512b99SJustin Bogner       Visit(S->getCond());
331ef512b99SJustin Bogner       Visit(S->getBody());
332ef512b99SJustin Bogner     }
333bf854f0fSBob Wilson     /// Assign a counter for the body of a do-while loop.
334ef512b99SJustin Bogner     void VisitDoStmt(const DoStmt *S) {
335bf854f0fSBob Wilson       (*CounterMap)[S] = NextCounter++;
336ef512b99SJustin Bogner       Visit(S->getBody());
337ef512b99SJustin Bogner       Visit(S->getCond());
338ef512b99SJustin Bogner     }
339bf854f0fSBob Wilson     /// Assign a counter for the body of a for loop.
340ef512b99SJustin Bogner     void VisitForStmt(const ForStmt *S) {
341bf854f0fSBob Wilson       (*CounterMap)[S] = NextCounter++;
342bf854f0fSBob Wilson       if (S->getInit())
343bf854f0fSBob Wilson         Visit(S->getInit());
344ef512b99SJustin Bogner       const Expr *E;
345ef512b99SJustin Bogner       if ((E = S->getCond()))
346ef512b99SJustin Bogner         Visit(E);
347ef512b99SJustin Bogner       if ((E = S->getInc()))
348ef512b99SJustin Bogner         Visit(E);
349bf854f0fSBob Wilson       Visit(S->getBody());
350ef512b99SJustin Bogner     }
351bf854f0fSBob Wilson     /// Assign a counter for the body of a for-range loop.
352ef512b99SJustin Bogner     void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
353bf854f0fSBob Wilson       (*CounterMap)[S] = NextCounter++;
354bf854f0fSBob Wilson       Visit(S->getRangeStmt());
355bf854f0fSBob Wilson       Visit(S->getBeginEndStmt());
356bf854f0fSBob Wilson       Visit(S->getCond());
357bf854f0fSBob Wilson       Visit(S->getLoopVarStmt());
358ef512b99SJustin Bogner       Visit(S->getBody());
359bf854f0fSBob Wilson       Visit(S->getInc());
360ef512b99SJustin Bogner     }
361bf854f0fSBob Wilson     /// Assign a counter for the body of a for-collection loop.
362ef512b99SJustin Bogner     void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
363bf854f0fSBob Wilson       (*CounterMap)[S] = NextCounter++;
364ef512b99SJustin Bogner       Visit(S->getElement());
365ef512b99SJustin Bogner       Visit(S->getBody());
366ef512b99SJustin Bogner     }
367ef512b99SJustin Bogner     /// Assign a counter for the exit block of the switch statement.
368ef512b99SJustin Bogner     void VisitSwitchStmt(const SwitchStmt *S) {
369ef512b99SJustin Bogner       (*CounterMap)[S] = NextCounter++;
370ef512b99SJustin Bogner       Visit(S->getCond());
371ef512b99SJustin Bogner       Visit(S->getBody());
372ef512b99SJustin Bogner     }
373ef512b99SJustin Bogner     /// Assign a counter for a particular case in a switch. This counts jumps
374ef512b99SJustin Bogner     /// from the switch header as well as fallthrough from the case before this
375ef512b99SJustin Bogner     /// one.
376ef512b99SJustin Bogner     void VisitCaseStmt(const CaseStmt *S) {
377ef512b99SJustin Bogner       (*CounterMap)[S] = NextCounter++;
378ef512b99SJustin Bogner       Visit(S->getSubStmt());
379ef512b99SJustin Bogner     }
380ef512b99SJustin Bogner     /// Assign a counter for the default case of a switch statement. The count
381ef512b99SJustin Bogner     /// is the number of branches from the loop header to the default, and does
382ef512b99SJustin Bogner     /// not include fallthrough from previous cases. If we have multiple
383ef512b99SJustin Bogner     /// conditional branch blocks from the switch instruction to the default
384ef512b99SJustin Bogner     /// block, as with large GNU case ranges, this is the counter for the last
385ef512b99SJustin Bogner     /// edge in that series, rather than the first.
386ef512b99SJustin Bogner     void VisitDefaultStmt(const DefaultStmt *S) {
387ef512b99SJustin Bogner       (*CounterMap)[S] = NextCounter++;
388ef512b99SJustin Bogner       Visit(S->getSubStmt());
389ef512b99SJustin Bogner     }
390ef512b99SJustin Bogner     /// Assign a counter for the "then" part of an if statement. The count for
391ef512b99SJustin Bogner     /// the "else" part, if it exists, will be calculated from this counter.
392ef512b99SJustin Bogner     void VisitIfStmt(const IfStmt *S) {
393ef512b99SJustin Bogner       (*CounterMap)[S] = NextCounter++;
394ef512b99SJustin Bogner       Visit(S->getCond());
395ef512b99SJustin Bogner       Visit(S->getThen());
396ef512b99SJustin Bogner       if (S->getElse())
397ef512b99SJustin Bogner         Visit(S->getElse());
398ef512b99SJustin Bogner     }
399ef512b99SJustin Bogner     /// Assign a counter for the continuation block of a C++ try statement.
400ef512b99SJustin Bogner     void VisitCXXTryStmt(const CXXTryStmt *S) {
401ef512b99SJustin Bogner       (*CounterMap)[S] = NextCounter++;
402ef512b99SJustin Bogner       Visit(S->getTryBlock());
403ef512b99SJustin Bogner       for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
404ef512b99SJustin Bogner         Visit(S->getHandler(I));
405ef512b99SJustin Bogner     }
406ef512b99SJustin Bogner     /// Assign a counter for a catch statement's handler block.
407ef512b99SJustin Bogner     void VisitCXXCatchStmt(const CXXCatchStmt *S) {
408ef512b99SJustin Bogner       (*CounterMap)[S] = NextCounter++;
409ef512b99SJustin Bogner       Visit(S->getHandlerBlock());
410ef512b99SJustin Bogner     }
411ef512b99SJustin Bogner     /// Assign a counter for the "true" part of a conditional operator. The
412ef512b99SJustin Bogner     /// count in the "false" part will be calculated from this counter.
413ef512b99SJustin Bogner     void VisitConditionalOperator(const ConditionalOperator *E) {
414ef512b99SJustin Bogner       (*CounterMap)[E] = NextCounter++;
415ef512b99SJustin Bogner       Visit(E->getCond());
416ef512b99SJustin Bogner       Visit(E->getTrueExpr());
417ef512b99SJustin Bogner       Visit(E->getFalseExpr());
418ef512b99SJustin Bogner     }
419ef512b99SJustin Bogner     /// Assign a counter for the right hand side of a logical and operator.
420ef512b99SJustin Bogner     void VisitBinLAnd(const BinaryOperator *E) {
421ef512b99SJustin Bogner       (*CounterMap)[E] = NextCounter++;
422ef512b99SJustin Bogner       Visit(E->getLHS());
423ef512b99SJustin Bogner       Visit(E->getRHS());
424ef512b99SJustin Bogner     }
425ef512b99SJustin Bogner     /// Assign a counter for the right hand side of a logical or operator.
426ef512b99SJustin Bogner     void VisitBinLOr(const BinaryOperator *E) {
427ef512b99SJustin Bogner       (*CounterMap)[E] = NextCounter++;
428ef512b99SJustin Bogner       Visit(E->getLHS());
429ef512b99SJustin Bogner       Visit(E->getRHS());
430ef512b99SJustin Bogner     }
431ef512b99SJustin Bogner   };
432bf854f0fSBob Wilson 
433bf854f0fSBob Wilson   /// A StmtVisitor that propagates the raw counts through the AST and
434bf854f0fSBob Wilson   /// records the count at statements where the value may change.
435bf854f0fSBob Wilson   struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
436bf854f0fSBob Wilson     /// PGO state.
437bf854f0fSBob Wilson     CodeGenPGO &PGO;
438bf854f0fSBob Wilson 
439bf854f0fSBob Wilson     /// A flag that is set when the current count should be recorded on the
440bf854f0fSBob Wilson     /// next statement, such as at the exit of a loop.
441bf854f0fSBob Wilson     bool RecordNextStmtCount;
442bf854f0fSBob Wilson 
443bf854f0fSBob Wilson     /// The map of statements to count values.
444bf854f0fSBob Wilson     llvm::DenseMap<const Stmt*, uint64_t> *CountMap;
445bf854f0fSBob Wilson 
446bf854f0fSBob Wilson     /// BreakContinueStack - Keep counts of breaks and continues inside loops.
447bf854f0fSBob Wilson     struct BreakContinue {
448bf854f0fSBob Wilson       uint64_t BreakCount;
449bf854f0fSBob Wilson       uint64_t ContinueCount;
450bf854f0fSBob Wilson       BreakContinue() : BreakCount(0), ContinueCount(0) {}
451bf854f0fSBob Wilson     };
452bf854f0fSBob Wilson     SmallVector<BreakContinue, 8> BreakContinueStack;
453bf854f0fSBob Wilson 
454bf854f0fSBob Wilson     ComputeRegionCounts(llvm::DenseMap<const Stmt*, uint64_t> *CountMap,
455bf854f0fSBob Wilson                         CodeGenPGO &PGO) :
456bf854f0fSBob Wilson       PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {
457bf854f0fSBob Wilson     }
458bf854f0fSBob Wilson 
459bf854f0fSBob Wilson     void RecordStmtCount(const Stmt *S) {
460bf854f0fSBob Wilson       if (RecordNextStmtCount) {
461bf854f0fSBob Wilson         (*CountMap)[S] = PGO.getCurrentRegionCount();
462bf854f0fSBob Wilson         RecordNextStmtCount = false;
463bf854f0fSBob Wilson       }
464bf854f0fSBob Wilson     }
465bf854f0fSBob Wilson 
466bf854f0fSBob Wilson     void VisitStmt(const Stmt *S) {
467bf854f0fSBob Wilson       RecordStmtCount(S);
468bf854f0fSBob Wilson       for (Stmt::const_child_range I = S->children(); I; ++I) {
469bf854f0fSBob Wilson         if (*I)
470bf854f0fSBob Wilson          this->Visit(*I);
471bf854f0fSBob Wilson       }
472bf854f0fSBob Wilson     }
473bf854f0fSBob Wilson 
474bf854f0fSBob Wilson     void VisitFunctionDecl(const FunctionDecl *S) {
475bf854f0fSBob Wilson       RegionCounter Cnt(PGO, S->getBody());
476bf854f0fSBob Wilson       Cnt.beginRegion();
477bf854f0fSBob Wilson       (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount();
478bf854f0fSBob Wilson       Visit(S->getBody());
479bf854f0fSBob Wilson     }
480bf854f0fSBob Wilson 
4815ec8fe19SBob Wilson     void VisitObjCMethodDecl(const ObjCMethodDecl *S) {
4825ec8fe19SBob Wilson       RegionCounter Cnt(PGO, S->getBody());
4835ec8fe19SBob Wilson       Cnt.beginRegion();
4845ec8fe19SBob Wilson       (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount();
4855ec8fe19SBob Wilson       Visit(S->getBody());
4865ec8fe19SBob Wilson     }
4875ec8fe19SBob Wilson 
488*c845c00aSBob Wilson     void VisitBlockDecl(const BlockDecl *S) {
489*c845c00aSBob Wilson       RegionCounter Cnt(PGO, S->getBody());
490*c845c00aSBob Wilson       Cnt.beginRegion();
491*c845c00aSBob Wilson       (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount();
492*c845c00aSBob Wilson       Visit(S->getBody());
493*c845c00aSBob Wilson     }
494*c845c00aSBob Wilson 
495bf854f0fSBob Wilson     void VisitReturnStmt(const ReturnStmt *S) {
496bf854f0fSBob Wilson       RecordStmtCount(S);
497bf854f0fSBob Wilson       if (S->getRetValue())
498bf854f0fSBob Wilson         Visit(S->getRetValue());
499bf854f0fSBob Wilson       PGO.setCurrentRegionUnreachable();
500bf854f0fSBob Wilson       RecordNextStmtCount = true;
501bf854f0fSBob Wilson     }
502bf854f0fSBob Wilson 
503bf854f0fSBob Wilson     void VisitGotoStmt(const GotoStmt *S) {
504bf854f0fSBob Wilson       RecordStmtCount(S);
505bf854f0fSBob Wilson       PGO.setCurrentRegionUnreachable();
506bf854f0fSBob Wilson       RecordNextStmtCount = true;
507bf854f0fSBob Wilson     }
508bf854f0fSBob Wilson 
509bf854f0fSBob Wilson     void VisitLabelStmt(const LabelStmt *S) {
510bf854f0fSBob Wilson       RecordNextStmtCount = false;
511bf854f0fSBob Wilson       RegionCounter Cnt(PGO, S);
512bf854f0fSBob Wilson       Cnt.beginRegion();
513bf854f0fSBob Wilson       (*CountMap)[S] = PGO.getCurrentRegionCount();
514bf854f0fSBob Wilson       Visit(S->getSubStmt());
515bf854f0fSBob Wilson     }
516bf854f0fSBob Wilson 
517bf854f0fSBob Wilson     void VisitBreakStmt(const BreakStmt *S) {
518bf854f0fSBob Wilson       RecordStmtCount(S);
519bf854f0fSBob Wilson       assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
520bf854f0fSBob Wilson       BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount();
521bf854f0fSBob Wilson       PGO.setCurrentRegionUnreachable();
522bf854f0fSBob Wilson       RecordNextStmtCount = true;
523bf854f0fSBob Wilson     }
524bf854f0fSBob Wilson 
525bf854f0fSBob Wilson     void VisitContinueStmt(const ContinueStmt *S) {
526bf854f0fSBob Wilson       RecordStmtCount(S);
527bf854f0fSBob Wilson       assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
528bf854f0fSBob Wilson       BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount();
529bf854f0fSBob Wilson       PGO.setCurrentRegionUnreachable();
530bf854f0fSBob Wilson       RecordNextStmtCount = true;
531bf854f0fSBob Wilson     }
532bf854f0fSBob Wilson 
    // Propagate counts through a while loop. The body is walked before the
    // condition, so the break/continue totals gathered from the body are
    // known when the condition's entry count is computed.
    void VisitWhileStmt(const WhileStmt *S) {
      RecordStmtCount(S);
      RegionCounter Cnt(PGO, S);
      // New entry that collects counts of break/continue statements that
      // target this loop.
      BreakContinueStack.push_back(BreakContinue());
      // Visit the body region first so the break/continue adjustments can be
      // included when visiting the condition.
      Cnt.beginRegion();
      (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount();
      Visit(S->getBody());
      Cnt.adjustForControlFlow();

      // ...then go back and propagate counts through the condition. The count
      // at the start of the condition is the sum of the incoming edges,
      // the backedge from the end of the loop body, and the edges from
      // continue statements.
      BreakContinue BC = BreakContinueStack.pop_back_val();
      Cnt.setCurrentRegionCount(Cnt.getParentCount() +
                                Cnt.getAdjustedCount() + BC.ContinueCount);
      (*CountMap)[S->getCond()] = PGO.getCurrentRegionCount();
      Visit(S->getCond());
      Cnt.adjustForControlFlow();
      // Pass the combined break+continue total to applyAdjustmentsToRegion,
      // which sets the count for the code after the loop.
      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
      RecordNextStmtCount = true;
    }
557bf854f0fSBob Wilson 
    // Propagate counts through a do-while loop. The body runs before the
    // condition, so it is entered both through the loop's counter and by
    // falling in from the preceding code.
    void VisitDoStmt(const DoStmt *S) {
      RecordStmtCount(S);
      RegionCounter Cnt(PGO, S);
      // New entry that collects counts of break/continue statements that
      // target this loop.
      BreakContinueStack.push_back(BreakContinue());
      // The body's entry count also takes in the fall-through into the loop
      // (AddIncomingFallThrough).
      Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
      (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount();
      Visit(S->getBody());
      Cnt.adjustForControlFlow();

      BreakContinue BC = BreakContinueStack.pop_back_val();
      // The count at the start of the condition is equal to the count at the
      // end of the body. The adjusted count does not include either the
      // fall-through count coming into the loop or the continue count, so add
      // both of those separately. This is coincidentally the same equation as
      // with while loops but for different reasons.
      Cnt.setCurrentRegionCount(Cnt.getParentCount() +
                                Cnt.getAdjustedCount() + BC.ContinueCount);
      (*CountMap)[S->getCond()] = PGO.getCurrentRegionCount();
      Visit(S->getCond());
      Cnt.adjustForControlFlow();
      // Pass the break+continue total to applyAdjustmentsToRegion, which sets
      // the count for the code after the loop.
      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
      RecordNextStmtCount = true;
    }
581bf854f0fSBob Wilson 
    // Propagate counts through a C-style for loop. The init statement runs
    // once in the enclosing region. The rest follows the while-loop scheme:
    // body first, then the optional increment, then the optional condition.
    void VisitForStmt(const ForStmt *S) {
      RecordStmtCount(S);
      if (S->getInit())
        Visit(S->getInit());
      RegionCounter Cnt(PGO, S);
      BreakContinueStack.push_back(BreakContinue());
      // Visit the body region first. (This is basically the same as a while
      // loop; see further comments in VisitWhileStmt.)
      Cnt.beginRegion();
      (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount();
      Visit(S->getBody());
      Cnt.adjustForControlFlow();

      // The increment is essentially part of the body but it needs to include
      // the count for all the continue statements. The loop's stack entry is
      // still live here, so the continue total is read from back().
      if (S->getInc()) {
        Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
                                  BreakContinueStack.back().ContinueCount);
        (*CountMap)[S->getInc()] = PGO.getCurrentRegionCount();
        Visit(S->getInc());
        Cnt.adjustForControlFlow();
      }

      BreakContinue BC = BreakContinueStack.pop_back_val();

      // ...then go back and propagate counts through the condition. A loop
      // without a condition (e.g. "for (;;)") records no condition count.
      if (S->getCond()) {
        Cnt.setCurrentRegionCount(Cnt.getParentCount() +
                                  Cnt.getAdjustedCount() +
                                  BC.ContinueCount);
        (*CountMap)[S->getCond()] = PGO.getCurrentRegionCount();
        Visit(S->getCond());
        Cnt.adjustForControlFlow();
      }
      // Pass the break+continue total to applyAdjustmentsToRegion, which sets
      // the count for the code after the loop.
      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
      RecordNextStmtCount = true;
    }
619bf854f0fSBob Wilson 
    // Propagate counts through a C++11 range-based for loop. The range and
    // begin/end statements run once in the enclosing region. The loop
    // variable statement opens the body region, and its CountMap entry holds
    // the body's entry count (the body itself gets no separate entry here).
    void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
      RecordStmtCount(S);
      Visit(S->getRangeStmt());
      Visit(S->getBeginEndStmt());
      RegionCounter Cnt(PGO, S);
      BreakContinueStack.push_back(BreakContinue());
      // Visit the body region first. (This is basically the same as a while
      // loop; see further comments in VisitWhileStmt.)
      Cnt.beginRegion();
      (*CountMap)[S->getLoopVarStmt()] = PGO.getCurrentRegionCount();
      Visit(S->getLoopVarStmt());
      Visit(S->getBody());
      Cnt.adjustForControlFlow();

      // The increment is essentially part of the body but it needs to include
      // the count for all the continue statements.
      Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
                                BreakContinueStack.back().ContinueCount);
      (*CountMap)[S->getInc()] = PGO.getCurrentRegionCount();
      Visit(S->getInc());
      Cnt.adjustForControlFlow();

      BreakContinue BC = BreakContinueStack.pop_back_val();

      // ...then go back and propagate counts through the condition.
      Cnt.setCurrentRegionCount(Cnt.getParentCount() +
                                Cnt.getAdjustedCount() +
                                BC.ContinueCount);
      (*CountMap)[S->getCond()] = PGO.getCurrentRegionCount();
      Visit(S->getCond());
      Cnt.adjustForControlFlow();
      // Pass the break+continue total to applyAdjustmentsToRegion, which sets
      // the count for the code after the loop.
      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
      RecordNextStmtCount = true;
    }
654bf854f0fSBob Wilson 
    // Propagate counts through an Objective-C fast-enumeration loop
    // ("for (x in collection)"). The element statement is walked in the
    // enclosing region, and the body is a single counted region.
    // NOTE(review): S->getCollection() is never visited here, so any counted
    // regions inside the collection expression get no computed counts.
    // Confirm this is intended.
    void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
      RecordStmtCount(S);
      Visit(S->getElement());
      RegionCounter Cnt(PGO, S);
      BreakContinueStack.push_back(BreakContinue());
      Cnt.beginRegion();
      (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount();
      Visit(S->getBody());
      BreakContinue BC = BreakContinueStack.pop_back_val();
      Cnt.adjustForControlFlow();
      // Pass the break+continue total to applyAdjustmentsToRegion, which sets
      // the count for the code after the loop.
      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
      RecordNextStmtCount = true;
    }
668bf854f0fSBob Wilson 
    // Propagate counts through a switch. The body is entered only through
    // case/default labels, so the region is made unreachable before the body
    // is walked. Each label then supplies its own counter (see
    // VisitCaseStmt and VisitDefaultStmt).
    void VisitSwitchStmt(const SwitchStmt *S) {
      RecordStmtCount(S);
      Visit(S->getCond());
      PGO.setCurrentRegionUnreachable();
      // Breaks in the body target this switch, but continues target an
      // enclosing loop (handled below).
      BreakContinueStack.push_back(BreakContinue());
      Visit(S->getBody());
      // If the switch is inside a loop, add the continue counts.
      BreakContinue BC = BreakContinueStack.pop_back_val();
      if (!BreakContinueStack.empty())
        BreakContinueStack.back().ContinueCount += BC.ContinueCount;
      // The switch statement's own counter provides the count for the code
      // after it. BC.BreakCount is not used here.
      RegionCounter ExitCnt(PGO, S);
      ExitCnt.beginRegion();
      RecordNextStmtCount = true;
    }
683bf854f0fSBob Wilson 
684bf854f0fSBob Wilson     void VisitCaseStmt(const CaseStmt *S) {
685bf854f0fSBob Wilson       RecordNextStmtCount = false;
686bf854f0fSBob Wilson       RegionCounter Cnt(PGO, S);
687bf854f0fSBob Wilson       Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
688bf854f0fSBob Wilson       (*CountMap)[S] = Cnt.getCount();
689bf854f0fSBob Wilson       RecordNextStmtCount = true;
690bf854f0fSBob Wilson       Visit(S->getSubStmt());
691bf854f0fSBob Wilson     }
692bf854f0fSBob Wilson 
693bf854f0fSBob Wilson     void VisitDefaultStmt(const DefaultStmt *S) {
694bf854f0fSBob Wilson       RecordNextStmtCount = false;
695bf854f0fSBob Wilson       RegionCounter Cnt(PGO, S);
696bf854f0fSBob Wilson       Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
697bf854f0fSBob Wilson       (*CountMap)[S] = Cnt.getCount();
698bf854f0fSBob Wilson       RecordNextStmtCount = true;
699bf854f0fSBob Wilson       Visit(S->getSubStmt());
700bf854f0fSBob Wilson     }
701bf854f0fSBob Wilson 
    // Propagate counts through an if statement. The condition runs in the
    // enclosing region, and the then-branch is counted by the statement's
    // counter. If present, the else-branch gets its count from
    // beginElseRegion() (presumably parent count minus then count; see
    // RegionCounter).
    void VisitIfStmt(const IfStmt *S) {
      RecordStmtCount(S);
      // Create the counter before the condition is walked, while the
      // enclosing region's count is still current.
      RegionCounter Cnt(PGO, S);
      Visit(S->getCond());

      Cnt.beginRegion();
      (*CountMap)[S->getThen()] = PGO.getCurrentRegionCount();
      Visit(S->getThen());
      Cnt.adjustForControlFlow();

      if (S->getElse()) {
        Cnt.beginElseRegion();
        (*CountMap)[S->getElse()] = PGO.getCurrentRegionCount();
        Visit(S->getElse());
        Cnt.adjustForControlFlow();
      }
      // No loop-specific adjustment: resume after the if statement.
      Cnt.applyAdjustmentsToRegion(0);
      RecordNextStmtCount = true;
    }
721bf854f0fSBob Wilson 
722bf854f0fSBob Wilson     void VisitCXXTryStmt(const CXXTryStmt *S) {
723bf854f0fSBob Wilson       RecordStmtCount(S);
724bf854f0fSBob Wilson       Visit(S->getTryBlock());
725bf854f0fSBob Wilson       for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
726bf854f0fSBob Wilson         Visit(S->getHandler(I));
727bf854f0fSBob Wilson       RegionCounter Cnt(PGO, S);
728bf854f0fSBob Wilson       Cnt.beginRegion();
729bf854f0fSBob Wilson       RecordNextStmtCount = true;
730bf854f0fSBob Wilson     }
731bf854f0fSBob Wilson 
732bf854f0fSBob Wilson     void VisitCXXCatchStmt(const CXXCatchStmt *S) {
733bf854f0fSBob Wilson       RecordNextStmtCount = false;
734bf854f0fSBob Wilson       RegionCounter Cnt(PGO, S);
735bf854f0fSBob Wilson       Cnt.beginRegion();
736bf854f0fSBob Wilson       (*CountMap)[S] = PGO.getCurrentRegionCount();
737bf854f0fSBob Wilson       Visit(S->getHandlerBlock());
738bf854f0fSBob Wilson     }
739bf854f0fSBob Wilson 
    // Propagate counts through "cond ? a : b". This mirrors VisitIfStmt,
    // except that both arms are always present. The true arm is counted by
    // the operator's counter, and the false arm gets its count from
    // beginElseRegion().
    void VisitConditionalOperator(const ConditionalOperator *E) {
      RecordStmtCount(E);
      // Create the counter before the condition is walked, while the
      // enclosing region's count is still current.
      RegionCounter Cnt(PGO, E);
      Visit(E->getCond());

      Cnt.beginRegion();
      (*CountMap)[E->getTrueExpr()] = PGO.getCurrentRegionCount();
      Visit(E->getTrueExpr());
      Cnt.adjustForControlFlow();

      Cnt.beginElseRegion();
      (*CountMap)[E->getFalseExpr()] = PGO.getCurrentRegionCount();
      Visit(E->getFalseExpr());
      Cnt.adjustForControlFlow();

      // No loop-specific adjustment: resume after the expression.
      Cnt.applyAdjustmentsToRegion(0);
      RecordNextStmtCount = true;
    }
758bf854f0fSBob Wilson 
759bf854f0fSBob Wilson     void VisitBinLAnd(const BinaryOperator *E) {
760bf854f0fSBob Wilson       RecordStmtCount(E);
761bf854f0fSBob Wilson       RegionCounter Cnt(PGO, E);
762bf854f0fSBob Wilson       Visit(E->getLHS());
763bf854f0fSBob Wilson       Cnt.beginRegion();
764bf854f0fSBob Wilson       (*CountMap)[E->getRHS()] = PGO.getCurrentRegionCount();
765bf854f0fSBob Wilson       Visit(E->getRHS());
766bf854f0fSBob Wilson       Cnt.adjustForControlFlow();
767bf854f0fSBob Wilson       Cnt.applyAdjustmentsToRegion(0);
768bf854f0fSBob Wilson       RecordNextStmtCount = true;
769bf854f0fSBob Wilson     }
770bf854f0fSBob Wilson 
771bf854f0fSBob Wilson     void VisitBinLOr(const BinaryOperator *E) {
772bf854f0fSBob Wilson       RecordStmtCount(E);
773bf854f0fSBob Wilson       RegionCounter Cnt(PGO, E);
774bf854f0fSBob Wilson       Visit(E->getLHS());
775bf854f0fSBob Wilson       Cnt.beginRegion();
776bf854f0fSBob Wilson       (*CountMap)[E->getRHS()] = PGO.getCurrentRegionCount();
777bf854f0fSBob Wilson       Visit(E->getRHS());
778bf854f0fSBob Wilson       Cnt.adjustForControlFlow();
779bf854f0fSBob Wilson       Cnt.applyAdjustmentsToRegion(0);
780bf854f0fSBob Wilson       RecordNextStmtCount = true;
781bf854f0fSBob Wilson     }
782bf854f0fSBob Wilson   };
783ef512b99SJustin Bogner }
784ef512b99SJustin Bogner 
785da1ebedeSBob Wilson void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) {
786ef512b99SJustin Bogner   bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
787ef512b99SJustin Bogner   PGOProfileData *PGOData = CGM.getPGOData();
788ef512b99SJustin Bogner   if (!InstrumentRegions && !PGOData)
789ef512b99SJustin Bogner     return;
790ef512b99SJustin Bogner   if (!D)
791ef512b99SJustin Bogner     return;
792da1ebedeSBob Wilson   setFuncName(Fn);
793ef512b99SJustin Bogner   mapRegionCounters(D);
794ef512b99SJustin Bogner   if (InstrumentRegions)
795ef512b99SJustin Bogner     emitCounterVariables();
796bf854f0fSBob Wilson   if (PGOData) {
797da1ebedeSBob Wilson     loadRegionCounts(PGOData);
798bf854f0fSBob Wilson     computeRegionCounts(D);
799da1ebedeSBob Wilson 
800da1ebedeSBob Wilson     // Turn on InlineHint attribute for hot functions.
801da1ebedeSBob Wilson     if (PGOData->isHotFunction(getFuncName()))
802da1ebedeSBob Wilson       Fn->addFnAttr(llvm::Attribute::InlineHint);
803da1ebedeSBob Wilson     // Turn on Cold attribute for cold functions.
804da1ebedeSBob Wilson     else if (PGOData->isColdFunction(getFuncName()))
805da1ebedeSBob Wilson       Fn->addFnAttr(llvm::Attribute::Cold);
806bf854f0fSBob Wilson   }
807ef512b99SJustin Bogner }
808ef512b99SJustin Bogner 
809ef512b99SJustin Bogner void CodeGenPGO::mapRegionCounters(const Decl *D) {
810ef512b99SJustin Bogner   RegionCounterMap = new llvm::DenseMap<const Stmt*, unsigned>();
811ef512b99SJustin Bogner   MapRegionCounters Walker(RegionCounterMap);
812ef512b99SJustin Bogner   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
813ef512b99SJustin Bogner     Walker.VisitFunctionDecl(FD);
8145ec8fe19SBob Wilson   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
8155ec8fe19SBob Wilson     Walker.VisitObjCMethodDecl(MD);
816*c845c00aSBob Wilson   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
817*c845c00aSBob Wilson     Walker.VisitBlockDecl(BD);
818ef512b99SJustin Bogner   NumRegionCounters = Walker.NextCounter;
819ef512b99SJustin Bogner }
820ef512b99SJustin Bogner 
821bf854f0fSBob Wilson void CodeGenPGO::computeRegionCounts(const Decl *D) {
822bf854f0fSBob Wilson   StmtCountMap = new llvm::DenseMap<const Stmt*, uint64_t>();
823bf854f0fSBob Wilson   ComputeRegionCounts Walker(StmtCountMap, *this);
824bf854f0fSBob Wilson   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
825bf854f0fSBob Wilson     Walker.VisitFunctionDecl(FD);
8265ec8fe19SBob Wilson   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
8275ec8fe19SBob Wilson     Walker.VisitObjCMethodDecl(MD);
828*c845c00aSBob Wilson   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
829*c845c00aSBob Wilson     Walker.VisitBlockDecl(BD);
830bf854f0fSBob Wilson }
831bf854f0fSBob Wilson 
832ef512b99SJustin Bogner void CodeGenPGO::emitCounterVariables() {
833ef512b99SJustin Bogner   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
834ef512b99SJustin Bogner   llvm::ArrayType *CounterTy = llvm::ArrayType::get(llvm::Type::getInt64Ty(Ctx),
835ef512b99SJustin Bogner                                                     NumRegionCounters);
836ef512b99SJustin Bogner   RegionCounters =
837ef512b99SJustin Bogner     new llvm::GlobalVariable(CGM.getModule(), CounterTy, false,
838ef512b99SJustin Bogner                              llvm::GlobalVariable::PrivateLinkage,
839ef512b99SJustin Bogner                              llvm::Constant::getNullValue(CounterTy),
840ef512b99SJustin Bogner                              "__llvm_pgo_ctr");
841ef512b99SJustin Bogner }
842ef512b99SJustin Bogner 
843ef512b99SJustin Bogner void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) {
844749ebc7fSBob Wilson   if (!RegionCounters)
845ef512b99SJustin Bogner     return;
846ef512b99SJustin Bogner   llvm::Value *Addr =
847ef512b99SJustin Bogner     Builder.CreateConstInBoundsGEP2_64(RegionCounters, 0, Counter);
848ef512b99SJustin Bogner   llvm::Value *Count = Builder.CreateLoad(Addr, "pgocount");
849ef512b99SJustin Bogner   Count = Builder.CreateAdd(Count, Builder.getInt64(1));
850ef512b99SJustin Bogner   Builder.CreateStore(Count, Addr);
851ef512b99SJustin Bogner }
852ef512b99SJustin Bogner 
853da1ebedeSBob Wilson void CodeGenPGO::loadRegionCounts(PGOProfileData *PGOData) {
854ef512b99SJustin Bogner   // For now, ignore the counts from the PGO data file only if the number of
855ef512b99SJustin Bogner   // counters does not match. This could be tightened down in the future to
856ef512b99SJustin Bogner   // ignore counts when the input changes in various ways, e.g., by comparing a
857ef512b99SJustin Bogner   // hash value based on some characteristics of the input.
858ef512b99SJustin Bogner   RegionCounts = new std::vector<uint64_t>();
859da1ebedeSBob Wilson   if (PGOData->getFunctionCounts(getFuncName(), *RegionCounts) ||
860ef512b99SJustin Bogner       RegionCounts->size() != NumRegionCounters) {
861ef512b99SJustin Bogner     delete RegionCounts;
862ef512b99SJustin Bogner     RegionCounts = 0;
863ef512b99SJustin Bogner   }
864ef512b99SJustin Bogner }
865ef512b99SJustin Bogner 
866ef512b99SJustin Bogner void CodeGenPGO::destroyRegionCounters() {
867ef512b99SJustin Bogner   if (RegionCounterMap != 0)
868ef512b99SJustin Bogner     delete RegionCounterMap;
869bf854f0fSBob Wilson   if (StmtCountMap != 0)
870bf854f0fSBob Wilson     delete StmtCountMap;
871ef512b99SJustin Bogner   if (RegionCounts != 0)
872ef512b99SJustin Bogner     delete RegionCounts;
873ef512b99SJustin Bogner }
874ef512b99SJustin Bogner 
875ef512b99SJustin Bogner llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount,
876ef512b99SJustin Bogner                                               uint64_t FalseCount) {
877ef512b99SJustin Bogner   if (!TrueCount && !FalseCount)
878ef512b99SJustin Bogner     return 0;
879ef512b99SJustin Bogner 
880ef512b99SJustin Bogner   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
881ef512b99SJustin Bogner   // TODO: need to scale down to 32-bits
882ef512b99SJustin Bogner   // According to Laplace's Rule of Succession, it is better to compute the
883ef512b99SJustin Bogner   // weight based on the count plus 1.
884ef512b99SJustin Bogner   return MDHelper.createBranchWeights(TrueCount + 1, FalseCount + 1);
885ef512b99SJustin Bogner }
886ef512b99SJustin Bogner 
88795a27b0eSBob Wilson llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) {
888ef512b99SJustin Bogner   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
889ef512b99SJustin Bogner   // TODO: need to scale down to 32-bits, instead of just truncating.
890ef512b99SJustin Bogner   // According to Laplace's Rule of Succession, it is better to compute the
891ef512b99SJustin Bogner   // weight based on the count plus 1.
892ef512b99SJustin Bogner   SmallVector<uint32_t, 16> ScaledWeights;
893ef512b99SJustin Bogner   ScaledWeights.reserve(Weights.size());
894ef512b99SJustin Bogner   for (ArrayRef<uint64_t>::iterator WI = Weights.begin(), WE = Weights.end();
895ef512b99SJustin Bogner        WI != WE; ++WI) {
896ef512b99SJustin Bogner     ScaledWeights.push_back(*WI + 1);
897ef512b99SJustin Bogner   }
898ef512b99SJustin Bogner   return MDHelper.createBranchWeights(ScaledWeights);
899ef512b99SJustin Bogner }
900bf854f0fSBob Wilson 
901bf854f0fSBob Wilson llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond,
902bf854f0fSBob Wilson                                             RegionCounter &Cnt) {
903bf854f0fSBob Wilson   if (!haveRegionCounts())
904bf854f0fSBob Wilson     return 0;
905bf854f0fSBob Wilson   uint64_t LoopCount = Cnt.getCount();
906bf854f0fSBob Wilson   uint64_t CondCount = 0;
907bf854f0fSBob Wilson   bool Found = getStmtCount(Cond, CondCount);
908bf854f0fSBob Wilson   assert(Found && "missing expected loop condition count");
909bf854f0fSBob Wilson   (void)Found;
910bf854f0fSBob Wilson   if (CondCount == 0)
911bf854f0fSBob Wilson     return 0;
912bf854f0fSBob Wilson   return createBranchWeights(LoopCount,
913bf854f0fSBob Wilson                              std::max(CondCount, LoopCount) - LoopCount);
914bf854f0fSBob Wilson }
915