1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Instrumentation-based profile-guided optimization
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/IR/MDBuilder.h"
19 #include "llvm/ProfileData/InstrProfReader.h"
20 #include "llvm/Support/Endian.h"
21 #include "llvm/Support/FileSystem.h"
22 #include "llvm/Support/MD5.h"
23 
24 using namespace clang;
25 using namespace CodeGen;
26 
27 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
28   RawFuncName = Fn->getName();
29 
30   // Function names may be prefixed with a binary '1' to indicate
31   // that the backend should not modify the symbols due to any platform
32   // naming convention. Do not include that '1' in the PGO profile name.
33   if (RawFuncName[0] == '\1')
34     RawFuncName = RawFuncName.substr(1);
35 
36   if (!Fn->hasLocalLinkage()) {
37     PrefixedFuncName.reset(new std::string(RawFuncName));
38     return;
39   }
40 
41   // For local symbols, prepend the main file name to distinguish them.
42   // Do not include the full path in the file name since there's no guarantee
43   // that it will stay the same, e.g., if the files are checked out from
44   // version control in different locations.
45   PrefixedFuncName.reset(new std::string(CGM.getCodeGenOpts().MainFileName));
46   if (PrefixedFuncName->empty())
47     PrefixedFuncName->assign("<unknown>");
48   PrefixedFuncName->append(":");
49   PrefixedFuncName->append(RawFuncName);
50 }
51 
52 static llvm::Function *getRegisterFunc(CodeGenModule &CGM) {
53   return CGM.getModule().getFunction("__llvm_profile_register_functions");
54 }
55 
56 static llvm::BasicBlock *getOrInsertRegisterBB(CodeGenModule &CGM) {
57   // Don't do this for Darwin.  compiler-rt uses linker magic.
58   if (CGM.getTarget().getTriple().isOSDarwin())
59     return nullptr;
60 
61   // Only need to insert this once per module.
62   if (llvm::Function *RegisterF = getRegisterFunc(CGM))
63     return &RegisterF->getEntryBlock();
64 
65   // Construct the function.
66   auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
67   auto *RegisterFTy = llvm::FunctionType::get(VoidTy, false);
68   auto *RegisterF = llvm::Function::Create(RegisterFTy,
69                                            llvm::GlobalValue::InternalLinkage,
70                                            "__llvm_profile_register_functions",
71                                            &CGM.getModule());
72   RegisterF->setUnnamedAddr(true);
73   if (CGM.getCodeGenOpts().DisableRedZone)
74     RegisterF->addFnAttr(llvm::Attribute::NoRedZone);
75 
76   // Construct and return the entry block.
77   auto *BB = llvm::BasicBlock::Create(CGM.getLLVMContext(), "", RegisterF);
78   CGBuilderTy Builder(BB);
79   Builder.CreateRetVoid();
80   return BB;
81 }
82 
83 static llvm::Constant *getOrInsertRuntimeRegister(CodeGenModule &CGM) {
84   auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
85   auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
86   auto *RuntimeRegisterTy = llvm::FunctionType::get(VoidTy, VoidPtrTy, false);
87   return CGM.getModule().getOrInsertFunction("__llvm_profile_register_function",
88                                              RuntimeRegisterTy);
89 }
90 
91 static bool isMachO(const CodeGenModule &CGM) {
92   return CGM.getTarget().getTriple().isOSBinFormatMachO();
93 }
94 
95 static StringRef getCountersSection(const CodeGenModule &CGM) {
96   return isMachO(CGM) ? "__DATA,__llvm_prf_cnts" : "__llvm_prf_cnts";
97 }
98 
99 static StringRef getNameSection(const CodeGenModule &CGM) {
100   return isMachO(CGM) ? "__DATA,__llvm_prf_names" : "__llvm_prf_names";
101 }
102 
103 static StringRef getDataSection(const CodeGenModule &CGM) {
104   return isMachO(CGM) ? "__DATA,__llvm_prf_data" : "__llvm_prf_data";
105 }
106 
107 llvm::GlobalVariable *CodeGenPGO::buildDataVar() {
108   // Create name variable.
109   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
110   auto *VarName = llvm::ConstantDataArray::getString(Ctx, getFuncName(),
111                                                      false);
112   auto *Name = new llvm::GlobalVariable(CGM.getModule(), VarName->getType(),
113                                         true, VarLinkage, VarName,
114                                         getFuncVarName("name"));
115   Name->setSection(getNameSection(CGM));
116   Name->setAlignment(1);
117 
118   // Create data variable.
119   auto *Int32Ty = llvm::Type::getInt32Ty(Ctx);
120   auto *Int64Ty = llvm::Type::getInt64Ty(Ctx);
121   auto *Int8PtrTy = llvm::Type::getInt8PtrTy(Ctx);
122   auto *Int64PtrTy = llvm::Type::getInt64PtrTy(Ctx);
123   llvm::Type *DataTypes[] = {
124     Int32Ty, Int32Ty, Int64Ty, Int8PtrTy, Int64PtrTy
125   };
126   auto *DataTy = llvm::StructType::get(Ctx, makeArrayRef(DataTypes));
127   llvm::Constant *DataVals[] = {
128     llvm::ConstantInt::get(Int32Ty, getFuncName().size()),
129     llvm::ConstantInt::get(Int32Ty, NumRegionCounters),
130     llvm::ConstantInt::get(Int64Ty, FunctionHash),
131     llvm::ConstantExpr::getBitCast(Name, Int8PtrTy),
132     llvm::ConstantExpr::getBitCast(RegionCounters, Int64PtrTy)
133   };
134   auto *Data =
135     new llvm::GlobalVariable(CGM.getModule(), DataTy, true, VarLinkage,
136                              llvm::ConstantStruct::get(DataTy, DataVals),
137                              getFuncVarName("data"));
138 
139   // All the data should be packed into an array in its own section.
140   Data->setSection(getDataSection(CGM));
141   Data->setAlignment(8);
142 
143   // Make sure the data doesn't get deleted.
144   CGM.addUsedGlobal(Data);
145   return Data;
146 }
147 
148 void CodeGenPGO::emitInstrumentationData() {
149   if (!RegionCounters)
150     return;
151 
152   // Build the data.
153   auto *Data = buildDataVar();
154 
155   // Register the data.
156   auto *RegisterBB = getOrInsertRegisterBB(CGM);
157   if (!RegisterBB)
158     return;
159   CGBuilderTy Builder(RegisterBB->getTerminator());
160   auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
161   Builder.CreateCall(getOrInsertRuntimeRegister(CGM),
162                      Builder.CreateBitCast(Data, VoidPtrTy));
163 }
164 
165 llvm::Function *CodeGenPGO::emitInitialization(CodeGenModule &CGM) {
166   if (!CGM.getCodeGenOpts().ProfileInstrGenerate)
167     return nullptr;
168 
169   assert(CGM.getModule().getFunction("__llvm_profile_init") == nullptr &&
170          "profile initialization already emitted");
171 
172   // Get the function to call at initialization.
173   llvm::Constant *RegisterF = getRegisterFunc(CGM);
174   if (!RegisterF)
175     return nullptr;
176 
177   // Create the initialization function.
178   auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
179   auto *F = llvm::Function::Create(llvm::FunctionType::get(VoidTy, false),
180                                    llvm::GlobalValue::InternalLinkage,
181                                    "__llvm_profile_init", &CGM.getModule());
182   F->setUnnamedAddr(true);
183   F->addFnAttr(llvm::Attribute::NoInline);
184   if (CGM.getCodeGenOpts().DisableRedZone)
185     F->addFnAttr(llvm::Attribute::NoRedZone);
186 
187   // Add the basic block and the necessary calls.
188   CGBuilderTy Builder(llvm::BasicBlock::Create(CGM.getLLVMContext(), "", F));
189   Builder.CreateCall(RegisterF);
190   Builder.CreateRetVoid();
191 
192   return F;
193 }
194 
195 namespace {
196 /// \brief Stable hasher for PGO region counters.
197 ///
198 /// PGOHash produces a stable hash of a given function's control flow.
199 ///
200 /// Changing the output of this hash will invalidate all previously generated
201 /// profiles -- i.e., don't do it.
202 ///
203 /// \note  When this hash does eventually change (years?), we still need to
204 /// support old hashes.  We'll need to pull in the version number from the
205 /// profile data format and use the matching hash function.
206 class PGOHash {
207   uint64_t Working;
208   unsigned Count;
209   llvm::MD5 MD5;
210 
211   static const int NumBitsPerType = 6;
212   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
213   static const unsigned TooBig = 1u << NumBitsPerType;
214 
215 public:
216   /// \brief Hash values for AST nodes.
217   ///
218   /// Distinct values for AST nodes that have region counters attached.
219   ///
220   /// These values must be stable.  All new members must be added at the end,
221   /// and no members should be removed.  Changing the enumeration value for an
222   /// AST node will affect the hash of every function that contains that node.
223   enum HashType : unsigned char {
224     None = 0,
225     LabelStmt = 1,
226     WhileStmt,
227     DoStmt,
228     ForStmt,
229     CXXForRangeStmt,
230     ObjCForCollectionStmt,
231     SwitchStmt,
232     CaseStmt,
233     DefaultStmt,
234     IfStmt,
235     CXXTryStmt,
236     CXXCatchStmt,
237     ConditionalOperator,
238     BinaryOperatorLAnd,
239     BinaryOperatorLOr,
240     BinaryConditionalOperator,
241 
242     // Keep this last.  It's for the static assert that follows.
243     LastHashType
244   };
245   static_assert(LastHashType <= TooBig, "Too many types in HashType");
246 
247   // TODO: When this format changes, take in a version number here, and use the
248   // old hash calculation for file formats that used the old hash.
249   PGOHash() : Working(0), Count(0) {}
250   void combine(HashType Type);
251   uint64_t finalize();
252 };
253 const int PGOHash::NumBitsPerType;
254 const unsigned PGOHash::NumTypesPerWord;
255 const unsigned PGOHash::TooBig;
256 
257   /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
258   struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
259     /// The next counter value to assign.
260     unsigned NextCounter;
261     /// The function hash.
262     PGOHash Hash;
263     /// The map of statements to counters.
264     llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
265 
266     MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
267         : NextCounter(0), CounterMap(CounterMap) {}
268 
269     // Blocks and lambdas are handled as separate functions, so we need not
270     // traverse them in the parent context.
271     bool TraverseBlockExpr(BlockExpr *BE) { return true; }
272     bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
273     bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
274 
275     bool VisitDecl(const Decl *D) {
276       switch (D->getKind()) {
277       default:
278         break;
279       case Decl::Function:
280       case Decl::CXXMethod:
281       case Decl::CXXConstructor:
282       case Decl::CXXDestructor:
283       case Decl::CXXConversion:
284       case Decl::ObjCMethod:
285       case Decl::Block:
286       case Decl::Captured:
287         CounterMap[D->getBody()] = NextCounter++;
288         break;
289       }
290       return true;
291     }
292 
293     bool VisitStmt(const Stmt *S) {
294       auto Type = getHashType(S);
295       if (Type == PGOHash::None)
296         return true;
297 
298       CounterMap[S] = NextCounter++;
299       Hash.combine(Type);
300       return true;
301     }
302     PGOHash::HashType getHashType(const Stmt *S) {
303       switch (S->getStmtClass()) {
304       default:
305         break;
306       case Stmt::LabelStmtClass:
307         return PGOHash::LabelStmt;
308       case Stmt::WhileStmtClass:
309         return PGOHash::WhileStmt;
310       case Stmt::DoStmtClass:
311         return PGOHash::DoStmt;
312       case Stmt::ForStmtClass:
313         return PGOHash::ForStmt;
314       case Stmt::CXXForRangeStmtClass:
315         return PGOHash::CXXForRangeStmt;
316       case Stmt::ObjCForCollectionStmtClass:
317         return PGOHash::ObjCForCollectionStmt;
318       case Stmt::SwitchStmtClass:
319         return PGOHash::SwitchStmt;
320       case Stmt::CaseStmtClass:
321         return PGOHash::CaseStmt;
322       case Stmt::DefaultStmtClass:
323         return PGOHash::DefaultStmt;
324       case Stmt::IfStmtClass:
325         return PGOHash::IfStmt;
326       case Stmt::CXXTryStmtClass:
327         return PGOHash::CXXTryStmt;
328       case Stmt::CXXCatchStmtClass:
329         return PGOHash::CXXCatchStmt;
330       case Stmt::ConditionalOperatorClass:
331         return PGOHash::ConditionalOperator;
332       case Stmt::BinaryConditionalOperatorClass:
333         return PGOHash::BinaryConditionalOperator;
334       case Stmt::BinaryOperatorClass: {
335         const BinaryOperator *BO = cast<BinaryOperator>(S);
336         if (BO->getOpcode() == BO_LAnd)
337           return PGOHash::BinaryOperatorLAnd;
338         if (BO->getOpcode() == BO_LOr)
339           return PGOHash::BinaryOperatorLOr;
340         break;
341       }
342       }
343       return PGOHash::None;
344     }
345   };
346 
347   /// A StmtVisitor that propagates the raw counts through the AST and
348   /// records the count at statements where the value may change.
349   struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
350     /// PGO state.
351     CodeGenPGO &PGO;
352 
353     /// A flag that is set when the current count should be recorded on the
354     /// next statement, such as at the exit of a loop.
355     bool RecordNextStmtCount;
356 
357     /// The map of statements to count values.
358     llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
359 
360     /// BreakContinueStack - Keep counts of breaks and continues inside loops.
361     struct BreakContinue {
362       uint64_t BreakCount;
363       uint64_t ContinueCount;
364       BreakContinue() : BreakCount(0), ContinueCount(0) {}
365     };
366     SmallVector<BreakContinue, 8> BreakContinueStack;
367 
368     ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
369                         CodeGenPGO &PGO)
370         : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
371 
372     void RecordStmtCount(const Stmt *S) {
373       if (RecordNextStmtCount) {
374         CountMap[S] = PGO.getCurrentRegionCount();
375         RecordNextStmtCount = false;
376       }
377     }
378 
379     void VisitStmt(const Stmt *S) {
380       RecordStmtCount(S);
381       for (Stmt::const_child_range I = S->children(); I; ++I) {
382         if (*I)
383          this->Visit(*I);
384       }
385     }
386 
387     void VisitFunctionDecl(const FunctionDecl *D) {
388       // Counter tracks entry to the function body.
389       RegionCounter Cnt(PGO, D->getBody());
390       Cnt.beginRegion();
391       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
392       Visit(D->getBody());
393     }
394 
395     // Skip lambda expressions. We visit these as FunctionDecls when we're
396     // generating them and aren't interested in the body when generating a
397     // parent context.
398     void VisitLambdaExpr(const LambdaExpr *LE) {}
399 
400     void VisitCapturedDecl(const CapturedDecl *D) {
401       // Counter tracks entry to the capture body.
402       RegionCounter Cnt(PGO, D->getBody());
403       Cnt.beginRegion();
404       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
405       Visit(D->getBody());
406     }
407 
408     void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
409       // Counter tracks entry to the method body.
410       RegionCounter Cnt(PGO, D->getBody());
411       Cnt.beginRegion();
412       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
413       Visit(D->getBody());
414     }
415 
416     void VisitBlockDecl(const BlockDecl *D) {
417       // Counter tracks entry to the block body.
418       RegionCounter Cnt(PGO, D->getBody());
419       Cnt.beginRegion();
420       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
421       Visit(D->getBody());
422     }
423 
424     void VisitReturnStmt(const ReturnStmt *S) {
425       RecordStmtCount(S);
426       if (S->getRetValue())
427         Visit(S->getRetValue());
428       PGO.setCurrentRegionUnreachable();
429       RecordNextStmtCount = true;
430     }
431 
432     void VisitGotoStmt(const GotoStmt *S) {
433       RecordStmtCount(S);
434       PGO.setCurrentRegionUnreachable();
435       RecordNextStmtCount = true;
436     }
437 
438     void VisitLabelStmt(const LabelStmt *S) {
439       RecordNextStmtCount = false;
440       // Counter tracks the block following the label.
441       RegionCounter Cnt(PGO, S);
442       Cnt.beginRegion();
443       CountMap[S] = PGO.getCurrentRegionCount();
444       Visit(S->getSubStmt());
445     }
446 
447     void VisitBreakStmt(const BreakStmt *S) {
448       RecordStmtCount(S);
449       assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
450       BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount();
451       PGO.setCurrentRegionUnreachable();
452       RecordNextStmtCount = true;
453     }
454 
455     void VisitContinueStmt(const ContinueStmt *S) {
456       RecordStmtCount(S);
457       assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
458       BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount();
459       PGO.setCurrentRegionUnreachable();
460       RecordNextStmtCount = true;
461     }
462 
463     void VisitWhileStmt(const WhileStmt *S) {
464       RecordStmtCount(S);
465       // Counter tracks the body of the loop.
466       RegionCounter Cnt(PGO, S);
467       BreakContinueStack.push_back(BreakContinue());
468       // Visit the body region first so the break/continue adjustments can be
469       // included when visiting the condition.
470       Cnt.beginRegion();
471       CountMap[S->getBody()] = PGO.getCurrentRegionCount();
472       Visit(S->getBody());
473       Cnt.adjustForControlFlow();
474 
475       // ...then go back and propagate counts through the condition. The count
476       // at the start of the condition is the sum of the incoming edges,
477       // the backedge from the end of the loop body, and the edges from
478       // continue statements.
479       BreakContinue BC = BreakContinueStack.pop_back_val();
480       Cnt.setCurrentRegionCount(Cnt.getParentCount() +
481                                 Cnt.getAdjustedCount() + BC.ContinueCount);
482       CountMap[S->getCond()] = PGO.getCurrentRegionCount();
483       Visit(S->getCond());
484       Cnt.adjustForControlFlow();
485       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
486       RecordNextStmtCount = true;
487     }
488 
489     void VisitDoStmt(const DoStmt *S) {
490       RecordStmtCount(S);
491       // Counter tracks the body of the loop.
492       RegionCounter Cnt(PGO, S);
493       BreakContinueStack.push_back(BreakContinue());
494       Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
495       CountMap[S->getBody()] = PGO.getCurrentRegionCount();
496       Visit(S->getBody());
497       Cnt.adjustForControlFlow();
498 
499       BreakContinue BC = BreakContinueStack.pop_back_val();
500       // The count at the start of the condition is equal to the count at the
501       // end of the body. The adjusted count does not include either the
502       // fall-through count coming into the loop or the continue count, so add
503       // both of those separately. This is coincidentally the same equation as
504       // with while loops but for different reasons.
505       Cnt.setCurrentRegionCount(Cnt.getParentCount() +
506                                 Cnt.getAdjustedCount() + BC.ContinueCount);
507       CountMap[S->getCond()] = PGO.getCurrentRegionCount();
508       Visit(S->getCond());
509       Cnt.adjustForControlFlow();
510       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
511       RecordNextStmtCount = true;
512     }
513 
514     void VisitForStmt(const ForStmt *S) {
515       RecordStmtCount(S);
516       if (S->getInit())
517         Visit(S->getInit());
518       // Counter tracks the body of the loop.
519       RegionCounter Cnt(PGO, S);
520       BreakContinueStack.push_back(BreakContinue());
521       // Visit the body region first. (This is basically the same as a while
522       // loop; see further comments in VisitWhileStmt.)
523       Cnt.beginRegion();
524       CountMap[S->getBody()] = PGO.getCurrentRegionCount();
525       Visit(S->getBody());
526       Cnt.adjustForControlFlow();
527 
528       // The increment is essentially part of the body but it needs to include
529       // the count for all the continue statements.
530       if (S->getInc()) {
531         Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
532                                   BreakContinueStack.back().ContinueCount);
533         CountMap[S->getInc()] = PGO.getCurrentRegionCount();
534         Visit(S->getInc());
535         Cnt.adjustForControlFlow();
536       }
537 
538       BreakContinue BC = BreakContinueStack.pop_back_val();
539 
540       // ...then go back and propagate counts through the condition.
541       if (S->getCond()) {
542         Cnt.setCurrentRegionCount(Cnt.getParentCount() +
543                                   Cnt.getAdjustedCount() +
544                                   BC.ContinueCount);
545         CountMap[S->getCond()] = PGO.getCurrentRegionCount();
546         Visit(S->getCond());
547         Cnt.adjustForControlFlow();
548       }
549       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
550       RecordNextStmtCount = true;
551     }
552 
553     void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
554       RecordStmtCount(S);
555       Visit(S->getRangeStmt());
556       Visit(S->getBeginEndStmt());
557       // Counter tracks the body of the loop.
558       RegionCounter Cnt(PGO, S);
559       BreakContinueStack.push_back(BreakContinue());
560       // Visit the body region first. (This is basically the same as a while
561       // loop; see further comments in VisitWhileStmt.)
562       Cnt.beginRegion();
563       CountMap[S->getLoopVarStmt()] = PGO.getCurrentRegionCount();
564       Visit(S->getLoopVarStmt());
565       Visit(S->getBody());
566       Cnt.adjustForControlFlow();
567 
568       // The increment is essentially part of the body but it needs to include
569       // the count for all the continue statements.
570       Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
571                                 BreakContinueStack.back().ContinueCount);
572       CountMap[S->getInc()] = PGO.getCurrentRegionCount();
573       Visit(S->getInc());
574       Cnt.adjustForControlFlow();
575 
576       BreakContinue BC = BreakContinueStack.pop_back_val();
577 
578       // ...then go back and propagate counts through the condition.
579       Cnt.setCurrentRegionCount(Cnt.getParentCount() +
580                                 Cnt.getAdjustedCount() +
581                                 BC.ContinueCount);
582       CountMap[S->getCond()] = PGO.getCurrentRegionCount();
583       Visit(S->getCond());
584       Cnt.adjustForControlFlow();
585       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
586       RecordNextStmtCount = true;
587     }
588 
589     void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
590       RecordStmtCount(S);
591       Visit(S->getElement());
592       // Counter tracks the body of the loop.
593       RegionCounter Cnt(PGO, S);
594       BreakContinueStack.push_back(BreakContinue());
595       Cnt.beginRegion();
596       CountMap[S->getBody()] = PGO.getCurrentRegionCount();
597       Visit(S->getBody());
598       BreakContinue BC = BreakContinueStack.pop_back_val();
599       Cnt.adjustForControlFlow();
600       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
601       RecordNextStmtCount = true;
602     }
603 
604     void VisitSwitchStmt(const SwitchStmt *S) {
605       RecordStmtCount(S);
606       Visit(S->getCond());
607       PGO.setCurrentRegionUnreachable();
608       BreakContinueStack.push_back(BreakContinue());
609       Visit(S->getBody());
610       // If the switch is inside a loop, add the continue counts.
611       BreakContinue BC = BreakContinueStack.pop_back_val();
612       if (!BreakContinueStack.empty())
613         BreakContinueStack.back().ContinueCount += BC.ContinueCount;
614       // Counter tracks the exit block of the switch.
615       RegionCounter ExitCnt(PGO, S);
616       ExitCnt.beginRegion();
617       RecordNextStmtCount = true;
618     }
619 
620     void VisitCaseStmt(const CaseStmt *S) {
621       RecordNextStmtCount = false;
622       // Counter for this particular case. This counts only jumps from the
623       // switch header and does not include fallthrough from the case before
624       // this one.
625       RegionCounter Cnt(PGO, S);
626       Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
627       CountMap[S] = Cnt.getCount();
628       RecordNextStmtCount = true;
629       Visit(S->getSubStmt());
630     }
631 
632     void VisitDefaultStmt(const DefaultStmt *S) {
633       RecordNextStmtCount = false;
634       // Counter for this default case. This does not include fallthrough from
635       // the previous case.
636       RegionCounter Cnt(PGO, S);
637       Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
638       CountMap[S] = Cnt.getCount();
639       RecordNextStmtCount = true;
640       Visit(S->getSubStmt());
641     }
642 
643     void VisitIfStmt(const IfStmt *S) {
644       RecordStmtCount(S);
645       // Counter tracks the "then" part of an if statement. The count for
646       // the "else" part, if it exists, will be calculated from this counter.
647       RegionCounter Cnt(PGO, S);
648       Visit(S->getCond());
649 
650       Cnt.beginRegion();
651       CountMap[S->getThen()] = PGO.getCurrentRegionCount();
652       Visit(S->getThen());
653       Cnt.adjustForControlFlow();
654 
655       if (S->getElse()) {
656         Cnt.beginElseRegion();
657         CountMap[S->getElse()] = PGO.getCurrentRegionCount();
658         Visit(S->getElse());
659         Cnt.adjustForControlFlow();
660       }
661       Cnt.applyAdjustmentsToRegion(0);
662       RecordNextStmtCount = true;
663     }
664 
665     void VisitCXXTryStmt(const CXXTryStmt *S) {
666       RecordStmtCount(S);
667       Visit(S->getTryBlock());
668       for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
669         Visit(S->getHandler(I));
670       // Counter tracks the continuation block of the try statement.
671       RegionCounter Cnt(PGO, S);
672       Cnt.beginRegion();
673       RecordNextStmtCount = true;
674     }
675 
676     void VisitCXXCatchStmt(const CXXCatchStmt *S) {
677       RecordNextStmtCount = false;
678       // Counter tracks the catch statement's handler block.
679       RegionCounter Cnt(PGO, S);
680       Cnt.beginRegion();
681       CountMap[S] = PGO.getCurrentRegionCount();
682       Visit(S->getHandlerBlock());
683     }
684 
685     void VisitAbstractConditionalOperator(
686         const AbstractConditionalOperator *E) {
687       RecordStmtCount(E);
688       // Counter tracks the "true" part of a conditional operator. The
689       // count in the "false" part will be calculated from this counter.
690       RegionCounter Cnt(PGO, E);
691       Visit(E->getCond());
692 
693       Cnt.beginRegion();
694       CountMap[E->getTrueExpr()] = PGO.getCurrentRegionCount();
695       Visit(E->getTrueExpr());
696       Cnt.adjustForControlFlow();
697 
698       Cnt.beginElseRegion();
699       CountMap[E->getFalseExpr()] = PGO.getCurrentRegionCount();
700       Visit(E->getFalseExpr());
701       Cnt.adjustForControlFlow();
702 
703       Cnt.applyAdjustmentsToRegion(0);
704       RecordNextStmtCount = true;
705     }
706 
707     void VisitBinLAnd(const BinaryOperator *E) {
708       RecordStmtCount(E);
709       // Counter tracks the right hand side of a logical and operator.
710       RegionCounter Cnt(PGO, E);
711       Visit(E->getLHS());
712       Cnt.beginRegion();
713       CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
714       Visit(E->getRHS());
715       Cnt.adjustForControlFlow();
716       Cnt.applyAdjustmentsToRegion(0);
717       RecordNextStmtCount = true;
718     }
719 
720     void VisitBinLOr(const BinaryOperator *E) {
721       RecordStmtCount(E);
722       // Counter tracks the right hand side of a logical or operator.
723       RegionCounter Cnt(PGO, E);
724       Visit(E->getLHS());
725       Cnt.beginRegion();
726       CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
727       Visit(E->getRHS());
728       Cnt.adjustForControlFlow();
729       Cnt.applyAdjustmentsToRegion(0);
730       RecordNextStmtCount = true;
731     }
732   };
733 }
734 
735 void PGOHash::combine(HashType Type) {
736   // Check that we never combine 0 and only have six bits.
737   assert(Type && "Hash is invalid: unexpected type 0");
738   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
739 
740   // Pass through MD5 if enough work has built up.
741   if (Count && Count % NumTypesPerWord == 0) {
742     using namespace llvm::support;
743     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
744     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
745     Working = 0;
746   }
747 
748   // Accumulate the current type.
749   ++Count;
750   Working = Working << NumBitsPerType | Type;
751 }
752 
753 uint64_t PGOHash::finalize() {
754   // Use Working as the hash directly if we never used MD5.
755   if (Count <= NumTypesPerWord)
756     // No need to byte swap here, since none of the math was endian-dependent.
757     // This number will be byte-swapped as required on endianness transitions,
758     // so we will see the same value on the other side.
759     return Working;
760 
761   // Check for remaining work in Working.
762   if (Working)
763     MD5.update(Working);
764 
765   // Finalize the MD5 and return the hash.
766   llvm::MD5::MD5Result Result;
767   MD5.final(Result);
768   using namespace llvm::support;
769   return endian::read<uint64_t, little, unaligned>(Result);
770 }
771 
772 static void emitRuntimeHook(CodeGenModule &CGM) {
773   const char *const RuntimeVarName = "__llvm_profile_runtime";
774   const char *const RuntimeUserName = "__llvm_profile_runtime_user";
775   if (CGM.getModule().getGlobalVariable(RuntimeVarName))
776     return;
777 
778   // Declare the runtime hook.
779   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
780   auto *Int32Ty = llvm::Type::getInt32Ty(Ctx);
781   auto *Var = new llvm::GlobalVariable(CGM.getModule(), Int32Ty, false,
782                                        llvm::GlobalValue::ExternalLinkage,
783                                        nullptr, RuntimeVarName);
784 
785   // Make a function that uses it.
786   auto *User = llvm::Function::Create(llvm::FunctionType::get(Int32Ty, false),
787                                       llvm::GlobalValue::LinkOnceODRLinkage,
788                                       RuntimeUserName, &CGM.getModule());
789   User->addFnAttr(llvm::Attribute::NoInline);
790   if (CGM.getCodeGenOpts().DisableRedZone)
791     User->addFnAttr(llvm::Attribute::NoRedZone);
792   CGBuilderTy Builder(llvm::BasicBlock::Create(CGM.getLLVMContext(), "", User));
793   auto *Load = Builder.CreateLoad(Var);
794   Builder.CreateRet(Load);
795 
796   // Create a use of the function.  Now the definition of the runtime variable
797   // should get pulled in, along with any static initializears.
798   CGM.addUsedGlobal(User);
799 }
800 
801 void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) {
802   bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
803   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
804   if (!InstrumentRegions && !PGOReader)
805     return;
806   if (D->isImplicit())
807     return;
808   setFuncName(Fn);
809 
810   // Set the linkage for variables based on the function linkage.  Usually, we
811   // want to match it, but available_externally and extern_weak both have the
812   // wrong semantics.
813   VarLinkage = Fn->getLinkage();
814   switch (VarLinkage) {
815   case llvm::GlobalValue::ExternalWeakLinkage:
816     VarLinkage = llvm::GlobalValue::LinkOnceAnyLinkage;
817     break;
818   case llvm::GlobalValue::AvailableExternallyLinkage:
819     VarLinkage = llvm::GlobalValue::LinkOnceODRLinkage;
820     break;
821   default:
822     break;
823   }
824 
825   mapRegionCounters(D);
826   if (InstrumentRegions) {
827     emitRuntimeHook(CGM);
828     emitCounterVariables();
829   }
830   if (PGOReader) {
831     loadRegionCounts(PGOReader);
832     computeRegionCounts(D);
833     applyFunctionAttributes(PGOReader, Fn);
834   }
835 }
836 
837 void CodeGenPGO::mapRegionCounters(const Decl *D) {
838   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
839   MapRegionCounters Walker(*RegionCounterMap);
840   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
841     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
842   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
843     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
844   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
845     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
846   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
847     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
848   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
849   NumRegionCounters = Walker.NextCounter;
850   FunctionHash = Walker.Hash.finalize();
851 }
852 
853 void CodeGenPGO::computeRegionCounts(const Decl *D) {
854   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
855   ComputeRegionCounts Walker(*StmtCountMap, *this);
856   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
857     Walker.VisitFunctionDecl(FD);
858   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
859     Walker.VisitObjCMethodDecl(MD);
860   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
861     Walker.VisitBlockDecl(BD);
862   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
863     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
864 }
865 
866 void
867 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
868                                     llvm::Function *Fn) {
869   if (!haveRegionCounts())
870     return;
871 
872   uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount();
873   uint64_t FunctionCount = getRegionCount(0);
874   if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount))
875     // Turn on InlineHint attribute for hot functions.
876     // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal.
877     Fn->addFnAttr(llvm::Attribute::InlineHint);
878   else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount))
879     // Turn on Cold attribute for cold functions.
880     // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal.
881     Fn->addFnAttr(llvm::Attribute::Cold);
882 }
883 
884 void CodeGenPGO::emitCounterVariables() {
885   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
886   llvm::ArrayType *CounterTy = llvm::ArrayType::get(llvm::Type::getInt64Ty(Ctx),
887                                                     NumRegionCounters);
888   RegionCounters =
889     new llvm::GlobalVariable(CGM.getModule(), CounterTy, false, VarLinkage,
890                              llvm::Constant::getNullValue(CounterTy),
891                              getFuncVarName("counters"));
892   RegionCounters->setAlignment(8);
893   RegionCounters->setSection(getCountersSection(CGM));
894 }
895 
896 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) {
897   if (!RegionCounters)
898     return;
899   llvm::Value *Addr =
900     Builder.CreateConstInBoundsGEP2_64(RegionCounters, 0, Counter);
901   llvm::Value *Count = Builder.CreateLoad(Addr, "pgocount");
902   Count = Builder.CreateAdd(Count, Builder.getInt64(1));
903   Builder.CreateStore(Count, Addr);
904 }
905 
906 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader) {
907   CGM.getPGOStats().Visited++;
908   RegionCounts.reset(new std::vector<uint64_t>);
909   uint64_t Hash;
910   if (PGOReader->getFunctionCounts(getFuncName(), Hash, *RegionCounts)) {
911     CGM.getPGOStats().Missing++;
912     RegionCounts.reset();
913   } else if (Hash != FunctionHash ||
914              RegionCounts->size() != NumRegionCounters) {
915     CGM.getPGOStats().Mismatched++;
916     RegionCounts.reset();
917   }
918 }
919 
920 void CodeGenPGO::destroyRegionCounters() {
921   RegionCounterMap.reset();
922   StmtCountMap.reset();
923   RegionCounts.reset();
924   RegionCounters = nullptr;
925 }
926 
/// \brief Calculate what to divide by to scale weights.
///
/// Returns 1 when \p MaxWeight already fits below UINT32_MAX. Otherwise
/// returns the smallest divisor of the form (MaxWeight / UINT32_MAX + 1).
/// Dividing any weight no larger than \p MaxWeight by this divisor yields a
/// value strictly less than UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  const uint64_t Limit = UINT32_MAX;
  if (MaxWeight >= Limit)
    return MaxWeight / Limit + 1;
  return 1;
}
934 
/// \brief Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale is nonzero and was calculated by \a calculateWeightScale()
/// from a weight no less than \c Weight (typically the maximum of all the
/// weights being scaled together).
///
/// \returns \c Weight / \c Scale + 1. The precondition guarantees this
/// fits in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  // The value is known to fit (see \pre); make the narrowing explicit.
  return static_cast<uint32_t>(Scaled);
}
950 
951 llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount,
952                                               uint64_t FalseCount) {
953   // Check for empty weights.
954   if (!TrueCount && !FalseCount)
955     return nullptr;
956 
957   // Calculate how to scale down to 32-bits.
958   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
959 
960   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
961   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
962                                       scaleBranchWeight(FalseCount, Scale));
963 }
964 
965 llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) {
966   // We need at least two elements to create meaningful weights.
967   if (Weights.size() < 2)
968     return nullptr;
969 
970   // Check for empty weights.
971   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
972   if (MaxWeight == 0)
973     return nullptr;
974 
975   // Calculate how to scale down to 32-bits.
976   uint64_t Scale = calculateWeightScale(MaxWeight);
977 
978   SmallVector<uint32_t, 16> ScaledWeights;
979   ScaledWeights.reserve(Weights.size());
980   for (uint64_t W : Weights)
981     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
982 
983   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
984   return MDHelper.createBranchWeights(ScaledWeights);
985 }
986 
987 llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond,
988                                             RegionCounter &Cnt) {
989   if (!haveRegionCounts())
990     return nullptr;
991   uint64_t LoopCount = Cnt.getCount();
992   uint64_t CondCount = 0;
993   bool Found = getStmtCount(Cond, CondCount);
994   assert(Found && "missing expected loop condition count");
995   (void)Found;
996   if (CondCount == 0)
997     return nullptr;
998   return createBranchWeights(LoopCount,
999                              std::max(CondCount, LoopCount) - LoopCount);
1000 }
1001