1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Instrumentation-based profile-guided optimization
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/Config/config.h" // for strtoull()/strtoul() define
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/ProfileData/InstrProfReader.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
24 
25 using namespace clang;
26 using namespace CodeGen;
27 
28 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
29   RawFuncName = Fn->getName();
30 
31   // Function names may be prefixed with a binary '1' to indicate
32   // that the backend should not modify the symbols due to any platform
33   // naming convention. Do not include that '1' in the PGO profile name.
34   if (RawFuncName[0] == '\1')
35     RawFuncName = RawFuncName.substr(1);
36 
37   if (!Fn->hasLocalLinkage()) {
38     PrefixedFuncName.reset(new std::string(RawFuncName));
39     return;
40   }
41 
42   // For local symbols, prepend the main file name to distinguish them.
43   // Do not include the full path in the file name since there's no guarantee
44   // that it will stay the same, e.g., if the files are checked out from
45   // version control in different locations.
46   PrefixedFuncName.reset(new std::string(CGM.getCodeGenOpts().MainFileName));
47   if (PrefixedFuncName->empty())
48     PrefixedFuncName->assign("<unknown>");
49   PrefixedFuncName->append(":");
50   PrefixedFuncName->append(RawFuncName);
51 }
52 
53 static llvm::Function *getRegisterFunc(CodeGenModule &CGM) {
54   return CGM.getModule().getFunction("__llvm_profile_register_functions");
55 }
56 
57 static llvm::BasicBlock *getOrInsertRegisterBB(CodeGenModule &CGM) {
58   // Don't do this for Darwin.  compiler-rt uses linker magic.
59   if (CGM.getTarget().getTriple().isOSDarwin())
60     return nullptr;
61 
62   // Only need to insert this once per module.
63   if (llvm::Function *RegisterF = getRegisterFunc(CGM))
64     return &RegisterF->getEntryBlock();
65 
66   // Construct the function.
67   auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
68   auto *RegisterFTy = llvm::FunctionType::get(VoidTy, false);
69   auto *RegisterF = llvm::Function::Create(RegisterFTy,
70                                            llvm::GlobalValue::InternalLinkage,
71                                            "__llvm_profile_register_functions",
72                                            &CGM.getModule());
73   RegisterF->setUnnamedAddr(true);
74   if (CGM.getCodeGenOpts().DisableRedZone)
75     RegisterF->addFnAttr(llvm::Attribute::NoRedZone);
76 
77   // Construct and return the entry block.
78   auto *BB = llvm::BasicBlock::Create(CGM.getLLVMContext(), "", RegisterF);
79   CGBuilderTy Builder(BB);
80   Builder.CreateRetVoid();
81   return BB;
82 }
83 
84 static llvm::Constant *getOrInsertRuntimeRegister(CodeGenModule &CGM) {
85   auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
86   auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
87   auto *RuntimeRegisterTy = llvm::FunctionType::get(VoidTy, VoidPtrTy, false);
88   return CGM.getModule().getOrInsertFunction("__llvm_profile_register_function",
89                                              RuntimeRegisterTy);
90 }
91 
92 static bool isMachO(const CodeGenModule &CGM) {
93   return CGM.getTarget().getTriple().isOSBinFormatMachO();
94 }
95 
96 static StringRef getCountersSection(const CodeGenModule &CGM) {
97   return isMachO(CGM) ? "__DATA,__llvm_prf_cnts" : "__llvm_prf_cnts";
98 }
99 
100 static StringRef getNameSection(const CodeGenModule &CGM) {
101   return isMachO(CGM) ? "__DATA,__llvm_prf_names" : "__llvm_prf_names";
102 }
103 
104 static StringRef getDataSection(const CodeGenModule &CGM) {
105   return isMachO(CGM) ? "__DATA,__llvm_prf_data" : "__llvm_prf_data";
106 }
107 
108 llvm::GlobalVariable *CodeGenPGO::buildDataVar() {
109   // Create name variable.
110   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
111   auto *VarName = llvm::ConstantDataArray::getString(Ctx, getFuncName(),
112                                                      false);
113   auto *Name = new llvm::GlobalVariable(CGM.getModule(), VarName->getType(),
114                                         true, VarLinkage, VarName,
115                                         getFuncVarName("name"));
116   Name->setSection(getNameSection(CGM));
117   Name->setAlignment(1);
118 
119   // Create data variable.
120   auto *Int32Ty = llvm::Type::getInt32Ty(Ctx);
121   auto *Int64Ty = llvm::Type::getInt64Ty(Ctx);
122   auto *Int8PtrTy = llvm::Type::getInt8PtrTy(Ctx);
123   auto *Int64PtrTy = llvm::Type::getInt64PtrTy(Ctx);
124   llvm::Type *DataTypes[] = {
125     Int32Ty, Int32Ty, Int64Ty, Int8PtrTy, Int64PtrTy
126   };
127   auto *DataTy = llvm::StructType::get(Ctx, makeArrayRef(DataTypes));
128   llvm::Constant *DataVals[] = {
129     llvm::ConstantInt::get(Int32Ty, getFuncName().size()),
130     llvm::ConstantInt::get(Int32Ty, NumRegionCounters),
131     llvm::ConstantInt::get(Int64Ty, FunctionHash),
132     llvm::ConstantExpr::getBitCast(Name, Int8PtrTy),
133     llvm::ConstantExpr::getBitCast(RegionCounters, Int64PtrTy)
134   };
135   auto *Data =
136     new llvm::GlobalVariable(CGM.getModule(), DataTy, true, VarLinkage,
137                              llvm::ConstantStruct::get(DataTy, DataVals),
138                              getFuncVarName("data"));
139 
140   // All the data should be packed into an array in its own section.
141   Data->setSection(getDataSection(CGM));
142   Data->setAlignment(8);
143 
144   // Make sure the data doesn't get deleted.
145   CGM.addUsedGlobal(Data);
146   return Data;
147 }
148 
149 void CodeGenPGO::emitInstrumentationData() {
150   if (!CGM.getCodeGenOpts().ProfileInstrGenerate)
151     return;
152 
153   // Build the data.
154   auto *Data = buildDataVar();
155 
156   // Register the data.
157   auto *RegisterBB = getOrInsertRegisterBB(CGM);
158   if (!RegisterBB)
159     return;
160   CGBuilderTy Builder(RegisterBB->getTerminator());
161   auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
162   Builder.CreateCall(getOrInsertRuntimeRegister(CGM),
163                      Builder.CreateBitCast(Data, VoidPtrTy));
164 }
165 
166 llvm::Function *CodeGenPGO::emitInitialization(CodeGenModule &CGM) {
167   if (!CGM.getCodeGenOpts().ProfileInstrGenerate)
168     return nullptr;
169 
170   assert(CGM.getModule().getFunction("__llvm_profile_init") == nullptr &&
171          "profile initialization already emitted");
172 
173   // Get the function to call at initialization.
174   llvm::Constant *RegisterF = getRegisterFunc(CGM);
175   if (!RegisterF)
176     return nullptr;
177 
178   // Create the initialization function.
179   auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
180   auto *F = llvm::Function::Create(llvm::FunctionType::get(VoidTy, false),
181                                    llvm::GlobalValue::InternalLinkage,
182                                    "__llvm_profile_init", &CGM.getModule());
183   F->setUnnamedAddr(true);
184   F->addFnAttr(llvm::Attribute::NoInline);
185   if (CGM.getCodeGenOpts().DisableRedZone)
186     F->addFnAttr(llvm::Attribute::NoRedZone);
187 
188   // Add the basic block and the necessary calls.
189   CGBuilderTy Builder(llvm::BasicBlock::Create(CGM.getLLVMContext(), "", F));
190   Builder.CreateCall(RegisterF);
191   Builder.CreateRetVoid();
192 
193   return F;
194 }
195 
196 namespace {
/// \brief Stable hasher for PGO region counters.
///
/// PGOHash produces a stable hash of a given function's control flow.
///
/// Changing the output of this hash will invalidate all previously generated
/// profiles -- i.e., don't do it.
///
/// \note  When this hash does eventually change (years?), we still need to
/// support old hashes.  We'll need to pull in the version number from the
/// profile data format and use the matching hash function.
///
/// Usage: call combine() once per counted AST node, in traversal order, then
/// call finalize() once to obtain the 64-bit hash.
class PGOHash {
  /// Types packed so far into the current 64-bit word, NumBitsPerType bits
  /// each, most recent in the low bits.
  uint64_t Working;
  /// Total number of types combined so far.
  unsigned Count;
  /// Running digest. Full Working words are fed to it once more than
  /// NumTypesPerWord types have been combined.
  llvm::MD5 MD5;

  /// Number of bits used to encode one HashType inside Working.
  static const int NumBitsPerType = 6;
  /// How many HashTypes fit in one 64-bit word (64 / 6 = 10).
  static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
  /// Exclusive upper bound on encodable HashType values (1 << 6 = 64).
  static const unsigned TooBig = 1u << NumBitsPerType;

public:
  /// \brief Hash values for AST nodes.
  ///
  /// Distinct values for AST nodes that have region counters attached.
  ///
  /// These values must be stable.  All new members must be added at the end,
  /// and no members should be removed.  Changing the enumeration value for an
  /// AST node will affect the hash of every function that contains that node.
  enum HashType : unsigned char {
    None = 0,
    LabelStmt = 1,
    WhileStmt,
    DoStmt,
    ForStmt,
    CXXForRangeStmt,
    ObjCForCollectionStmt,
    SwitchStmt,
    CaseStmt,
    DefaultStmt,
    IfStmt,
    CXXTryStmt,
    CXXCatchStmt,
    ConditionalOperator,
    BinaryOperatorLAnd,
    BinaryOperatorLOr,
    BinaryConditionalOperator,

    // Keep this last.  It's for the static assert that follows.
    LastHashType
  };
  static_assert(LastHashType <= TooBig, "Too many types in HashType");

  // TODO: When this format changes, take in a version number here, and use the
  // old hash calculation for file formats that used the old hash.
  PGOHash() : Working(0), Count(0) {}
  /// Mix one node type into the hash. \p Type must not be None.
  void combine(HashType Type);
  /// Return the final hash. Intended to be called once, after all combine()
  /// calls.
  uint64_t finalize();
};
// Out-of-line definitions for the in-class-initialized static const members
// (needed pre-C++17 if any of them is odr-used).
const int PGOHash::NumBitsPerType;
const unsigned PGOHash::NumTypesPerWord;
const unsigned PGOHash::TooBig;
257 
  /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
  ///
  /// Counters are numbered 0, 1, 2, ... in visitation order. The body of each
  /// function-like declaration gets a counter that is NOT mixed into Hash.
  /// Every statement whose getHashType() is not None gets a counter AND is
  /// combined into Hash.
  /// NOTE(review): the numbering presumably has to match counter indices in
  /// existing profile data, so treat the visitation order as stable, like
  /// PGOHash itself.
  struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
    /// The next counter value to assign.
    unsigned NextCounter;
    /// The function hash.
    PGOHash Hash;
    /// The map of statements to counters.
    llvm::DenseMap<const Stmt *, unsigned> &CounterMap;

    /// \param CounterMap Map to fill. It is held by reference, not copied.
    MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
        : NextCounter(0), CounterMap(CounterMap) {}

    // Blocks, lambdas, and captured statements are handled as separate
    // functions, so we need not traverse them in the parent context.
    bool TraverseBlockExpr(BlockExpr *BE) { return true; }
    bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
    bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }

    /// Give the body of a function-like declaration (function, C++ method,
    /// constructor, destructor, conversion, ObjC method, block, or captured
    /// decl) the next counter. Always continues the traversal.
    bool VisitDecl(const Decl *D) {
      switch (D->getKind()) {
      default:
        break;
      case Decl::Function:
      case Decl::CXXMethod:
      case Decl::CXXConstructor:
      case Decl::CXXDestructor:
      case Decl::CXXConversion:
      case Decl::ObjCMethod:
      case Decl::Block:
      case Decl::Captured:
        CounterMap[D->getBody()] = NextCounter++;
        break;
      }
      return true;
    }

    /// If \p S is a counted statement kind, give it the next counter and mix
    /// its type into Hash. Always continues the traversal.
    bool VisitStmt(const Stmt *S) {
      auto Type = getHashType(S);
      if (Type == PGOHash::None)
        return true;

      CounterMap[S] = NextCounter++;
      Hash.combine(Type);
      return true;
    }
    /// Map \p S to its stable PGOHash::HashType, or PGOHash::None if it does
    /// not get a region counter. Among binary operators, only '&&' and '||'
    /// are counted.
    PGOHash::HashType getHashType(const Stmt *S) {
      switch (S->getStmtClass()) {
      default:
        break;
      case Stmt::LabelStmtClass:
        return PGOHash::LabelStmt;
      case Stmt::WhileStmtClass:
        return PGOHash::WhileStmt;
      case Stmt::DoStmtClass:
        return PGOHash::DoStmt;
      case Stmt::ForStmtClass:
        return PGOHash::ForStmt;
      case Stmt::CXXForRangeStmtClass:
        return PGOHash::CXXForRangeStmt;
      case Stmt::ObjCForCollectionStmtClass:
        return PGOHash::ObjCForCollectionStmt;
      case Stmt::SwitchStmtClass:
        return PGOHash::SwitchStmt;
      case Stmt::CaseStmtClass:
        return PGOHash::CaseStmt;
      case Stmt::DefaultStmtClass:
        return PGOHash::DefaultStmt;
      case Stmt::IfStmtClass:
        return PGOHash::IfStmt;
      case Stmt::CXXTryStmtClass:
        return PGOHash::CXXTryStmt;
      case Stmt::CXXCatchStmtClass:
        return PGOHash::CXXCatchStmt;
      case Stmt::ConditionalOperatorClass:
        return PGOHash::ConditionalOperator;
      case Stmt::BinaryConditionalOperatorClass:
        return PGOHash::BinaryConditionalOperator;
      case Stmt::BinaryOperatorClass: {
        const BinaryOperator *BO = cast<BinaryOperator>(S);
        if (BO->getOpcode() == BO_LAnd)
          return PGOHash::BinaryOperatorLAnd;
        if (BO->getOpcode() == BO_LOr)
          return PGOHash::BinaryOperatorLOr;
        break;
      }
      }
      return PGOHash::None;
    }
  };
347 
  /// A StmtVisitor that propagates the raw counts through the AST and
  /// records the count at statements where the value may change.
  ///
  /// The count "flowing" through the code being visited is PGO's current
  /// region count (PGO.getCurrentRegionCount()). A RegionCounter Cnt(PGO, S)
  /// presumably refers to the counter MapRegionCounters assigned to S.
  /// NOTE(review): the exact semantics of RegionCounter's beginRegion /
  /// beginElseRegion / adjustForControlFlow / applyAdjustmentsToRegion are
  /// defined outside this file and should be checked there.
  struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
    /// PGO state.
    CodeGenPGO &PGO;

    /// A flag that is set when the current count should be recorded on the
    /// next statement, such as at the exit of a loop.
    bool RecordNextStmtCount;

    /// The map of statements to count values.
    llvm::DenseMap<const Stmt *, uint64_t> &CountMap;

    /// BreakContinueStack - Keep counts of breaks and continues inside loops.
    /// Switches also push an entry so that 'break' inside a switch is
    /// attributed to the switch, not to an enclosing loop.
    struct BreakContinue {
      uint64_t BreakCount;
      uint64_t ContinueCount;
      BreakContinue() : BreakCount(0), ContinueCount(0) {}
    };
    SmallVector<BreakContinue, 8> BreakContinueStack;

    ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
                        CodeGenPGO &PGO)
        : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}

    /// If a record was requested (RecordNextStmtCount), store the current
    /// count for \p S and clear the request.
    void RecordStmtCount(const Stmt *S) {
      if (RecordNextStmtCount) {
        CountMap[S] = PGO.getCurrentRegionCount();
        RecordNextStmtCount = false;
      }
    }

    /// Default handling: honor a pending record request, then visit every
    /// non-null child in order without changing the count.
    void VisitStmt(const Stmt *S) {
      RecordStmtCount(S);
      for (Stmt::const_child_range I = S->children(); I; ++I) {
        if (*I)
         this->Visit(*I);
      }
    }

    // Entry points for function-like declarations: start the body's region
    // and record the body's entry count before walking it.
    void VisitFunctionDecl(const FunctionDecl *D) {
      // Counter tracks entry to the function body.
      RegionCounter Cnt(PGO, D->getBody());
      Cnt.beginRegion();
      CountMap[D->getBody()] = PGO.getCurrentRegionCount();
      Visit(D->getBody());
    }

    // Skip lambda expressions. We visit these as FunctionDecls when we're
    // generating them and aren't interested in the body when generating a
    // parent context.
    void VisitLambdaExpr(const LambdaExpr *LE) {}

    void VisitCapturedDecl(const CapturedDecl *D) {
      // Counter tracks entry to the capture body.
      RegionCounter Cnt(PGO, D->getBody());
      Cnt.beginRegion();
      CountMap[D->getBody()] = PGO.getCurrentRegionCount();
      Visit(D->getBody());
    }

    void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
      // Counter tracks entry to the method body.
      RegionCounter Cnt(PGO, D->getBody());
      Cnt.beginRegion();
      CountMap[D->getBody()] = PGO.getCurrentRegionCount();
      Visit(D->getBody());
    }

    void VisitBlockDecl(const BlockDecl *D) {
      // Counter tracks entry to the block body.
      RegionCounter Cnt(PGO, D->getBody());
      Cnt.beginRegion();
      CountMap[D->getBody()] = PGO.getCurrentRegionCount();
      Visit(D->getBody());
    }

    /// Code after a return cannot be reached by fallthrough: mark the region
    /// unreachable and record whatever count the next statement ends up with.
    void VisitReturnStmt(const ReturnStmt *S) {
      RecordStmtCount(S);
      if (S->getRetValue())
        Visit(S->getRetValue());
      PGO.setCurrentRegionUnreachable();
      RecordNextStmtCount = true;
    }

    /// Same treatment as return: control leaves, so fallthrough is dead.
    void VisitGotoStmt(const GotoStmt *S) {
      RecordStmtCount(S);
      PGO.setCurrentRegionUnreachable();
      RecordNextStmtCount = true;
    }

    /// A label starts a new counted region. Its own count replaces any
    /// pending record request.
    void VisitLabelStmt(const LabelStmt *S) {
      RecordNextStmtCount = false;
      // Counter tracks the block following the label.
      RegionCounter Cnt(PGO, S);
      Cnt.beginRegion();
      CountMap[S] = PGO.getCurrentRegionCount();
      Visit(S->getSubStmt());
    }

    /// Credit the current count to the innermost loop/switch's break total,
    /// then treat the following code as unreachable by fallthrough.
    void VisitBreakStmt(const BreakStmt *S) {
      RecordStmtCount(S);
      assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
      BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount();
      PGO.setCurrentRegionUnreachable();
      RecordNextStmtCount = true;
    }

    /// Credit the current count to the innermost entry's continue total,
    /// then treat the following code as unreachable by fallthrough.
    void VisitContinueStmt(const ContinueStmt *S) {
      RecordStmtCount(S);
      assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
      BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount();
      PGO.setCurrentRegionUnreachable();
      RecordNextStmtCount = true;
    }

    void VisitWhileStmt(const WhileStmt *S) {
      RecordStmtCount(S);
      // Counter tracks the body of the loop.
      RegionCounter Cnt(PGO, S);
      BreakContinueStack.push_back(BreakContinue());
      // Visit the body region first so the break/continue adjustments can be
      // included when visiting the condition.
      Cnt.beginRegion();
      CountMap[S->getBody()] = PGO.getCurrentRegionCount();
      Visit(S->getBody());
      Cnt.adjustForControlFlow();

      // ...then go back and propagate counts through the condition. The count
      // at the start of the condition is the sum of the incoming edges,
      // the backedge from the end of the loop body, and the edges from
      // continue statements.
      BreakContinue BC = BreakContinueStack.pop_back_val();
      Cnt.setCurrentRegionCount(Cnt.getParentCount() +
                                Cnt.getAdjustedCount() + BC.ContinueCount);
      CountMap[S->getCond()] = PGO.getCurrentRegionCount();
      Visit(S->getCond());
      Cnt.adjustForControlFlow();
      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
      RecordNextStmtCount = true;
    }

    void VisitDoStmt(const DoStmt *S) {
      RecordStmtCount(S);
      // Counter tracks the body of the loop.
      RegionCounter Cnt(PGO, S);
      BreakContinueStack.push_back(BreakContinue());
      // The body of a do-loop is also entered by falling into the loop, hence
      // AddIncomingFallThrough.
      Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
      CountMap[S->getBody()] = PGO.getCurrentRegionCount();
      Visit(S->getBody());
      Cnt.adjustForControlFlow();

      BreakContinue BC = BreakContinueStack.pop_back_val();
      // The count at the start of the condition is equal to the count at the
      // end of the body. The adjusted count does not include either the
      // fall-through count coming into the loop or the continue count, so add
      // both of those separately. This is coincidentally the same equation as
      // with while loops but for different reasons.
      Cnt.setCurrentRegionCount(Cnt.getParentCount() +
                                Cnt.getAdjustedCount() + BC.ContinueCount);
      CountMap[S->getCond()] = PGO.getCurrentRegionCount();
      Visit(S->getCond());
      Cnt.adjustForControlFlow();
      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
      RecordNextStmtCount = true;
    }

    void VisitForStmt(const ForStmt *S) {
      RecordStmtCount(S);
      // The init statement runs once, with the count of the enclosing region.
      if (S->getInit())
        Visit(S->getInit());
      // Counter tracks the body of the loop.
      RegionCounter Cnt(PGO, S);
      BreakContinueStack.push_back(BreakContinue());
      // Visit the body region first. (This is basically the same as a while
      // loop; see further comments in VisitWhileStmt.)
      Cnt.beginRegion();
      CountMap[S->getBody()] = PGO.getCurrentRegionCount();
      Visit(S->getBody());
      Cnt.adjustForControlFlow();

      // The increment is essentially part of the body but it needs to include
      // the count for all the continue statements.
      if (S->getInc()) {
        Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
                                  BreakContinueStack.back().ContinueCount);
        CountMap[S->getInc()] = PGO.getCurrentRegionCount();
        Visit(S->getInc());
        Cnt.adjustForControlFlow();
      }

      BreakContinue BC = BreakContinueStack.pop_back_val();

      // ...then go back and propagate counts through the condition.
      if (S->getCond()) {
        Cnt.setCurrentRegionCount(Cnt.getParentCount() +
                                  Cnt.getAdjustedCount() +
                                  BC.ContinueCount);
        CountMap[S->getCond()] = PGO.getCurrentRegionCount();
        Visit(S->getCond());
        Cnt.adjustForControlFlow();
      }
      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
      RecordNextStmtCount = true;
    }

    void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
      RecordStmtCount(S);
      // The range and begin/end statements run once before the loop.
      Visit(S->getRangeStmt());
      Visit(S->getBeginEndStmt());
      // Counter tracks the body of the loop.
      RegionCounter Cnt(PGO, S);
      BreakContinueStack.push_back(BreakContinue());
      // Visit the body region first. (This is basically the same as a while
      // loop; see further comments in VisitWhileStmt.)
      Cnt.beginRegion();
      CountMap[S->getLoopVarStmt()] = PGO.getCurrentRegionCount();
      Visit(S->getLoopVarStmt());
      Visit(S->getBody());
      Cnt.adjustForControlFlow();

      // The increment is essentially part of the body but it needs to include
      // the count for all the continue statements.
      Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
                                BreakContinueStack.back().ContinueCount);
      CountMap[S->getInc()] = PGO.getCurrentRegionCount();
      Visit(S->getInc());
      Cnt.adjustForControlFlow();

      BreakContinue BC = BreakContinueStack.pop_back_val();

      // ...then go back and propagate counts through the condition.
      Cnt.setCurrentRegionCount(Cnt.getParentCount() +
                                Cnt.getAdjustedCount() +
                                BC.ContinueCount);
      CountMap[S->getCond()] = PGO.getCurrentRegionCount();
      Visit(S->getCond());
      Cnt.adjustForControlFlow();
      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
      RecordNextStmtCount = true;
    }

    void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
      RecordStmtCount(S);
      Visit(S->getElement());
      // Counter tracks the body of the loop.
      RegionCounter Cnt(PGO, S);
      BreakContinueStack.push_back(BreakContinue());
      Cnt.beginRegion();
      CountMap[S->getBody()] = PGO.getCurrentRegionCount();
      Visit(S->getBody());
      BreakContinue BC = BreakContinueStack.pop_back_val();
      Cnt.adjustForControlFlow();
      Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
      RecordNextStmtCount = true;
    }

    void VisitSwitchStmt(const SwitchStmt *S) {
      RecordStmtCount(S);
      Visit(S->getCond());
      // The body is only entered through case/default labels, which begin
      // their own regions, so code before the first label is unreachable.
      PGO.setCurrentRegionUnreachable();
      BreakContinueStack.push_back(BreakContinue());
      Visit(S->getBody());
      // If the switch is inside a loop, add the continue counts.
      BreakContinue BC = BreakContinueStack.pop_back_val();
      if (!BreakContinueStack.empty())
        BreakContinueStack.back().ContinueCount += BC.ContinueCount;
      // Counter tracks the exit block of the switch.
      RegionCounter ExitCnt(PGO, S);
      ExitCnt.beginRegion();
      RecordNextStmtCount = true;
    }

    void VisitCaseStmt(const CaseStmt *S) {
      RecordNextStmtCount = false;
      // Counter for this particular case. This counts only jumps from the
      // switch header and does not include fallthrough from the case before
      // this one.
      RegionCounter Cnt(PGO, S);
      Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
      CountMap[S] = Cnt.getCount();
      // Also record the count on the first statement of the case body.
      RecordNextStmtCount = true;
      Visit(S->getSubStmt());
    }

    void VisitDefaultStmt(const DefaultStmt *S) {
      RecordNextStmtCount = false;
      // Counter for this default case. This does not include fallthrough from
      // the previous case.
      RegionCounter Cnt(PGO, S);
      Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
      CountMap[S] = Cnt.getCount();
      // Also record the count on the first statement of the default body.
      RecordNextStmtCount = true;
      Visit(S->getSubStmt());
    }

    void VisitIfStmt(const IfStmt *S) {
      RecordStmtCount(S);
      // Counter tracks the "then" part of an if statement. The count for
      // the "else" part, if it exists, will be calculated from this counter.
      RegionCounter Cnt(PGO, S);
      Visit(S->getCond());

      Cnt.beginRegion();
      CountMap[S->getThen()] = PGO.getCurrentRegionCount();
      Visit(S->getThen());
      Cnt.adjustForControlFlow();

      if (S->getElse()) {
        Cnt.beginElseRegion();
        CountMap[S->getElse()] = PGO.getCurrentRegionCount();
        Visit(S->getElse());
        Cnt.adjustForControlFlow();
      }
      Cnt.applyAdjustmentsToRegion(0);
      RecordNextStmtCount = true;
    }

    void VisitCXXTryStmt(const CXXTryStmt *S) {
      RecordStmtCount(S);
      Visit(S->getTryBlock());
      for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
        Visit(S->getHandler(I));
      // Counter tracks the continuation block of the try statement.
      RegionCounter Cnt(PGO, S);
      Cnt.beginRegion();
      RecordNextStmtCount = true;
    }

    /// A handler is entered only via an exception, so it starts its own
    /// counted region and discards any pending record request.
    void VisitCXXCatchStmt(const CXXCatchStmt *S) {
      RecordNextStmtCount = false;
      // Counter tracks the catch statement's handler block.
      RegionCounter Cnt(PGO, S);
      Cnt.beginRegion();
      CountMap[S] = PGO.getCurrentRegionCount();
      Visit(S->getHandlerBlock());
    }

    /// Handles both `c ? t : f` and the GNU `c ?: f` form.
    void VisitAbstractConditionalOperator(
        const AbstractConditionalOperator *E) {
      RecordStmtCount(E);
      // Counter tracks the "true" part of a conditional operator. The
      // count in the "false" part will be calculated from this counter.
      RegionCounter Cnt(PGO, E);
      Visit(E->getCond());

      Cnt.beginRegion();
      CountMap[E->getTrueExpr()] = PGO.getCurrentRegionCount();
      Visit(E->getTrueExpr());
      Cnt.adjustForControlFlow();

      Cnt.beginElseRegion();
      CountMap[E->getFalseExpr()] = PGO.getCurrentRegionCount();
      Visit(E->getFalseExpr());
      Cnt.adjustForControlFlow();

      Cnt.applyAdjustmentsToRegion(0);
      RecordNextStmtCount = true;
    }

    /// '&&': the LHS always runs; only the RHS is a separately counted region.
    void VisitBinLAnd(const BinaryOperator *E) {
      RecordStmtCount(E);
      // Counter tracks the right hand side of a logical and operator.
      RegionCounter Cnt(PGO, E);
      Visit(E->getLHS());
      Cnt.beginRegion();
      CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
      Visit(E->getRHS());
      Cnt.adjustForControlFlow();
      Cnt.applyAdjustmentsToRegion(0);
      RecordNextStmtCount = true;
    }

    /// '||': the LHS always runs; only the RHS is a separately counted region.
    void VisitBinLOr(const BinaryOperator *E) {
      RecordStmtCount(E);
      // Counter tracks the right hand side of a logical or operator.
      RegionCounter Cnt(PGO, E);
      Visit(E->getLHS());
      Cnt.beginRegion();
      CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
      Visit(E->getRHS());
      Cnt.adjustForControlFlow();
      Cnt.applyAdjustmentsToRegion(0);
      RecordNextStmtCount = true;
    }
  };
734 }
735 
736 void PGOHash::combine(HashType Type) {
737   // Check that we never combine 0 and only have six bits.
738   assert(Type && "Hash is invalid: unexpected type 0");
739   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
740 
741   // Pass through MD5 if enough work has built up.
742   if (Count && Count % NumTypesPerWord == 0) {
743     using namespace llvm::support;
744     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
745     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
746     Working = 0;
747   }
748 
749   // Accumulate the current type.
750   ++Count;
751   Working = Working << NumBitsPerType | Type;
752 }
753 
754 uint64_t PGOHash::finalize() {
755   // Use Working as the hash directly if we never used MD5.
756   if (Count <= NumTypesPerWord)
757     // No need to byte swap here, since none of the math was endian-dependent.
758     // This number will be byte-swapped as required on endianness transitions,
759     // so we will see the same value on the other side.
760     return Working;
761 
762   // Check for remaining work in Working.
763   if (Working)
764     MD5.update(Working);
765 
766   // Finalize the MD5 and return the hash.
767   llvm::MD5::MD5Result Result;
768   MD5.final(Result);
769   using namespace llvm::support;
770   return endian::read<uint64_t, little, unaligned>(Result);
771 }
772 
773 static void emitRuntimeHook(CodeGenModule &CGM) {
774   const char *const RuntimeVarName = "__llvm_profile_runtime";
775   const char *const RuntimeUserName = "__llvm_profile_runtime_user";
776   if (CGM.getModule().getGlobalVariable(RuntimeVarName))
777     return;
778 
779   // Declare the runtime hook.
780   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
781   auto *Int32Ty = llvm::Type::getInt32Ty(Ctx);
782   auto *Var = new llvm::GlobalVariable(CGM.getModule(), Int32Ty, false,
783                                        llvm::GlobalValue::ExternalLinkage,
784                                        nullptr, RuntimeVarName);
785 
786   // Make a function that uses it.
787   auto *User = llvm::Function::Create(llvm::FunctionType::get(Int32Ty, false),
788                                       llvm::GlobalValue::LinkOnceODRLinkage,
789                                       RuntimeUserName, &CGM.getModule());
790   User->addFnAttr(llvm::Attribute::NoInline);
791   if (CGM.getCodeGenOpts().DisableRedZone)
792     User->addFnAttr(llvm::Attribute::NoRedZone);
793   CGBuilderTy Builder(llvm::BasicBlock::Create(CGM.getLLVMContext(), "", User));
794   auto *Load = Builder.CreateLoad(Var);
795   Builder.CreateRet(Load);
796 
797   // Create a use of the function.  Now the definition of the runtime variable
798   // should get pulled in, along with any static initializears.
799   CGM.addUsedGlobal(User);
800 }
801 
802 void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) {
803   bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
804   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
805   if (!InstrumentRegions && !PGOReader)
806     return;
807   if (!D)
808     return;
809   setFuncName(Fn);
810 
811   // Set the linkage for variables based on the function linkage.  Usually, we
812   // want to match it, but available_externally and extern_weak both have the
813   // wrong semantics.
814   VarLinkage = Fn->getLinkage();
815   switch (VarLinkage) {
816   case llvm::GlobalValue::ExternalWeakLinkage:
817     VarLinkage = llvm::GlobalValue::LinkOnceAnyLinkage;
818     break;
819   case llvm::GlobalValue::AvailableExternallyLinkage:
820     VarLinkage = llvm::GlobalValue::LinkOnceODRLinkage;
821     break;
822   default:
823     break;
824   }
825 
826   mapRegionCounters(D);
827   if (InstrumentRegions) {
828     emitRuntimeHook(CGM);
829     emitCounterVariables();
830   }
831   if (PGOReader) {
832     loadRegionCounts(PGOReader);
833     computeRegionCounts(D);
834     applyFunctionAttributes(PGOReader, Fn);
835   }
836 }
837 
838 void CodeGenPGO::mapRegionCounters(const Decl *D) {
839   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
840   MapRegionCounters Walker(*RegionCounterMap);
841   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
842     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
843   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
844     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
845   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
846     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
847   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
848     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
849   NumRegionCounters = Walker.NextCounter;
850   FunctionHash = Walker.Hash.finalize();
851 }
852 
853 void CodeGenPGO::computeRegionCounts(const Decl *D) {
854   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
855   ComputeRegionCounts Walker(*StmtCountMap, *this);
856   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
857     Walker.VisitFunctionDecl(FD);
858   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
859     Walker.VisitObjCMethodDecl(MD);
860   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
861     Walker.VisitBlockDecl(BD);
862   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
863     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
864 }
865 
866 void
867 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
868                                     llvm::Function *Fn) {
869   if (!haveRegionCounts())
870     return;
871 
872   uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount();
873   uint64_t FunctionCount = getRegionCount(0);
874   if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount))
875     // Turn on InlineHint attribute for hot functions.
876     // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal.
877     Fn->addFnAttr(llvm::Attribute::InlineHint);
878   else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount))
879     // Turn on Cold attribute for cold functions.
880     // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal.
881     Fn->addFnAttr(llvm::Attribute::Cold);
882 }
883 
884 void CodeGenPGO::emitCounterVariables() {
885   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
886   llvm::ArrayType *CounterTy = llvm::ArrayType::get(llvm::Type::getInt64Ty(Ctx),
887                                                     NumRegionCounters);
888   RegionCounters =
889     new llvm::GlobalVariable(CGM.getModule(), CounterTy, false, VarLinkage,
890                              llvm::Constant::getNullValue(CounterTy),
891                              getFuncVarName("counters"));
892   RegionCounters->setAlignment(8);
893   RegionCounters->setSection(getCountersSection(CGM));
894 }
895 
896 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) {
897   if (!RegionCounters)
898     return;
899   llvm::Value *Addr =
900     Builder.CreateConstInBoundsGEP2_64(RegionCounters, 0, Counter);
901   llvm::Value *Count = Builder.CreateLoad(Addr, "pgocount");
902   Count = Builder.CreateAdd(Count, Builder.getInt64(1));
903   Builder.CreateStore(Count, Addr);
904 }
905 
906 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader) {
907   CGM.getPGOStats().Visited++;
908   RegionCounts.reset(new std::vector<uint64_t>);
909   uint64_t Hash;
910   if (PGOReader->getFunctionCounts(getFuncName(), Hash, *RegionCounts)) {
911     CGM.getPGOStats().Missing++;
912     RegionCounts.reset();
913   } else if (Hash != FunctionHash ||
914              RegionCounts->size() != NumRegionCounters) {
915     CGM.getPGOStats().Mismatched++;
916     RegionCounts.reset();
917   }
918 }
919 
920 void CodeGenPGO::destroyRegionCounters() {
921   RegionCounterMap.reset();
922   StmtCountMap.reset();
923   RegionCounts.reset();
924 }
925 
/// \brief Compute the divisor used to bring branch weights into 32 bits.
///
/// Returns 1 when \p MaxWeight is already below UINT32_MAX. Otherwise returns
/// MaxWeight / UINT32_MAX + 1, so dividing any weight no larger than
/// \p MaxWeight by the result yields a value strictly less than UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight >= UINT32_MAX)
    return MaxWeight / UINT32_MAX + 1;
  return 1;
}
933 
/// \brief Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale is non-zero and was calculated by \a calculateWeightScale()
/// from a weight no less than \c Weight (normally the maximum of all weights
/// being scaled together). Otherwise the result may not fit in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  // The precondition guarantees Scaled fits; make the narrowing explicit.
  return static_cast<uint32_t>(Scaled);
}
949 
950 llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount,
951                                               uint64_t FalseCount) {
952   // Check for empty weights.
953   if (!TrueCount && !FalseCount)
954     return nullptr;
955 
956   // Calculate how to scale down to 32-bits.
957   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
958 
959   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
960   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
961                                       scaleBranchWeight(FalseCount, Scale));
962 }
963 
964 llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) {
965   // We need at least two elements to create meaningful weights.
966   if (Weights.size() < 2)
967     return nullptr;
968 
969   // Check for empty weights.
970   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
971   if (MaxWeight == 0)
972     return nullptr;
973 
974   // Calculate how to scale down to 32-bits.
975   uint64_t Scale = calculateWeightScale(MaxWeight);
976 
977   SmallVector<uint32_t, 16> ScaledWeights;
978   ScaledWeights.reserve(Weights.size());
979   for (uint64_t W : Weights)
980     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
981 
982   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
983   return MDHelper.createBranchWeights(ScaledWeights);
984 }
985 
986 llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond,
987                                             RegionCounter &Cnt) {
988   if (!haveRegionCounts())
989     return nullptr;
990   uint64_t LoopCount = Cnt.getCount();
991   uint64_t CondCount = 0;
992   bool Found = getStmtCount(Cond, CondCount);
993   assert(Found && "missing expected loop condition count");
994   (void)Found;
995   if (CondCount == 0)
996     return nullptr;
997   return createBranchWeights(LoopCount,
998                              std::max(CondCount, LoopCount) - LoopCount);
999 }
1000