1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Instrumentation-based profile-guided optimization
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "CoverageMappingGen.h"
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/AST/StmtVisitor.h"
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/ProfileData/InstrProfReader.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
24 
25 using namespace clang;
26 using namespace CodeGen;
27 
28 void CodeGenPGO::setFuncName(StringRef Name,
29                              llvm::GlobalValue::LinkageTypes Linkage) {
30   RawFuncName = Name;
31 
32   // Function names may be prefixed with a binary '1' to indicate
33   // that the backend should not modify the symbols due to any platform
34   // naming convention. Do not include that '1' in the PGO profile name.
35   if (RawFuncName[0] == '\1')
36     RawFuncName = RawFuncName.substr(1);
37 
38   if (!llvm::GlobalValue::isLocalLinkage(Linkage)) {
39     PrefixedFuncName.reset(new std::string(RawFuncName));
40     return;
41   }
42 
43   // For local symbols, prepend the main file name to distinguish them.
44   // Do not include the full path in the file name since there's no guarantee
45   // that it will stay the same, e.g., if the files are checked out from
46   // version control in different locations.
47   PrefixedFuncName.reset(new std::string(CGM.getCodeGenOpts().MainFileName));
48   if (PrefixedFuncName->empty())
49     PrefixedFuncName->assign("<unknown>");
50   PrefixedFuncName->append(":");
51   PrefixedFuncName->append(RawFuncName);
52 }
53 
54 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
55   setFuncName(Fn->getName(), Fn->getLinkage());
56 }
57 
58 void CodeGenPGO::setVarLinkage(llvm::GlobalValue::LinkageTypes Linkage) {
59   // Set the linkage for variables based on the function linkage.  Usually, we
60   // want to match it, but available_externally and extern_weak both have the
61   // wrong semantics.
62   VarLinkage = Linkage;
63   switch (VarLinkage) {
64   case llvm::GlobalValue::ExternalWeakLinkage:
65     VarLinkage = llvm::GlobalValue::LinkOnceAnyLinkage;
66     break;
67   case llvm::GlobalValue::AvailableExternallyLinkage:
68     VarLinkage = llvm::GlobalValue::LinkOnceODRLinkage;
69     break;
70   default:
71     break;
72   }
73 }
74 
75 static llvm::Function *getRegisterFunc(CodeGenModule &CGM) {
76   return CGM.getModule().getFunction("__llvm_profile_register_functions");
77 }
78 
79 static llvm::BasicBlock *getOrInsertRegisterBB(CodeGenModule &CGM) {
80   // Don't do this for Darwin.  compiler-rt uses linker magic.
81   if (CGM.getTarget().getTriple().isOSDarwin())
82     return nullptr;
83 
84   // Only need to insert this once per module.
85   if (llvm::Function *RegisterF = getRegisterFunc(CGM))
86     return &RegisterF->getEntryBlock();
87 
88   // Construct the function.
89   auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
90   auto *RegisterFTy = llvm::FunctionType::get(VoidTy, false);
91   auto *RegisterF = llvm::Function::Create(RegisterFTy,
92                                            llvm::GlobalValue::InternalLinkage,
93                                            "__llvm_profile_register_functions",
94                                            &CGM.getModule());
95   RegisterF->setUnnamedAddr(true);
96   if (CGM.getCodeGenOpts().DisableRedZone)
97     RegisterF->addFnAttr(llvm::Attribute::NoRedZone);
98 
99   // Construct and return the entry block.
100   auto *BB = llvm::BasicBlock::Create(CGM.getLLVMContext(), "", RegisterF);
101   CGBuilderTy Builder(BB);
102   Builder.CreateRetVoid();
103   return BB;
104 }
105 
106 static llvm::Constant *getOrInsertRuntimeRegister(CodeGenModule &CGM) {
107   auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
108   auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
109   auto *RuntimeRegisterTy = llvm::FunctionType::get(VoidTy, VoidPtrTy, false);
110   return CGM.getModule().getOrInsertFunction("__llvm_profile_register_function",
111                                              RuntimeRegisterTy);
112 }
113 
114 static bool isMachO(const CodeGenModule &CGM) {
115   return CGM.getTarget().getTriple().isOSBinFormatMachO();
116 }
117 
118 static StringRef getCountersSection(const CodeGenModule &CGM) {
119   return isMachO(CGM) ? "__DATA,__llvm_prf_cnts" : "__llvm_prf_cnts";
120 }
121 
122 static StringRef getNameSection(const CodeGenModule &CGM) {
123   return isMachO(CGM) ? "__DATA,__llvm_prf_names" : "__llvm_prf_names";
124 }
125 
126 static StringRef getDataSection(const CodeGenModule &CGM) {
127   return isMachO(CGM) ? "__DATA,__llvm_prf_data" : "__llvm_prf_data";
128 }
129 
130 llvm::GlobalVariable *CodeGenPGO::buildDataVar() {
131   // Create name variable.
132   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
133   auto *VarName = llvm::ConstantDataArray::getString(Ctx, getFuncName(),
134                                                      false);
135   auto *Name = new llvm::GlobalVariable(CGM.getModule(), VarName->getType(),
136                                         true, VarLinkage, VarName,
137                                         getFuncVarName("name"));
138   Name->setSection(getNameSection(CGM));
139   Name->setAlignment(1);
140 
141   // Create data variable.
142   auto *Int32Ty = llvm::Type::getInt32Ty(Ctx);
143   auto *Int64Ty = llvm::Type::getInt64Ty(Ctx);
144   auto *Int8PtrTy = llvm::Type::getInt8PtrTy(Ctx);
145   auto *Int64PtrTy = llvm::Type::getInt64PtrTy(Ctx);
146   llvm::GlobalVariable *Data = nullptr;
147   if (RegionCounters) {
148     llvm::Type *DataTypes[] = {
149       Int32Ty, Int32Ty, Int64Ty, Int8PtrTy, Int64PtrTy
150     };
151     auto *DataTy = llvm::StructType::get(Ctx, makeArrayRef(DataTypes));
152     llvm::Constant *DataVals[] = {
153       llvm::ConstantInt::get(Int32Ty, getFuncName().size()),
154       llvm::ConstantInt::get(Int32Ty, NumRegionCounters),
155       llvm::ConstantInt::get(Int64Ty, FunctionHash),
156       llvm::ConstantExpr::getBitCast(Name, Int8PtrTy),
157       llvm::ConstantExpr::getBitCast(RegionCounters, Int64PtrTy)
158     };
159     Data =
160       new llvm::GlobalVariable(CGM.getModule(), DataTy, true, VarLinkage,
161                                llvm::ConstantStruct::get(DataTy, DataVals),
162                                getFuncVarName("data"));
163 
164     // All the data should be packed into an array in its own section.
165     Data->setSection(getDataSection(CGM));
166     Data->setAlignment(8);
167   }
168 
169   // Create coverage mapping data variable.
170   if (!CoverageMapping.empty())
171     CGM.getCoverageMapping()->addFunctionMappingRecord(Name, getFuncName(),
172                                                        FunctionHash,
173                                                        CoverageMapping);
174 
175   // Hide all these symbols so that we correctly get a copy for each
176   // executable.  The profile format expects names and counters to be
177   // contiguous, so references into shared objects would be invalid.
178   if (!llvm::GlobalValue::isLocalLinkage(VarLinkage)) {
179     Name->setVisibility(llvm::GlobalValue::HiddenVisibility);
180     if (Data) {
181       Data->setVisibility(llvm::GlobalValue::HiddenVisibility);
182       RegionCounters->setVisibility(llvm::GlobalValue::HiddenVisibility);
183     }
184   }
185 
186   // Make sure the data doesn't get deleted.
187   if (Data) CGM.addUsedGlobal(Data);
188   return Data;
189 }
190 
191 void CodeGenPGO::emitInstrumentationData() {
192   if (!RegionCounters)
193     return;
194 
195   // Build the data.
196   auto *Data = buildDataVar();
197 
198   // Register the data.
199   auto *RegisterBB = getOrInsertRegisterBB(CGM);
200   if (!RegisterBB)
201     return;
202   CGBuilderTy Builder(RegisterBB->getTerminator());
203   auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
204   Builder.CreateCall(getOrInsertRuntimeRegister(CGM),
205                      Builder.CreateBitCast(Data, VoidPtrTy));
206 }
207 
208 llvm::Function *CodeGenPGO::emitInitialization(CodeGenModule &CGM) {
209   if (!CGM.getCodeGenOpts().ProfileInstrGenerate)
210     return nullptr;
211 
212   assert(CGM.getModule().getFunction("__llvm_profile_init") == nullptr &&
213          "profile initialization already emitted");
214 
215   // Get the function to call at initialization.
216   llvm::Constant *RegisterF = getRegisterFunc(CGM);
217   if (!RegisterF)
218     return nullptr;
219 
220   // Create the initialization function.
221   auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
222   auto *F = llvm::Function::Create(llvm::FunctionType::get(VoidTy, false),
223                                    llvm::GlobalValue::InternalLinkage,
224                                    "__llvm_profile_init", &CGM.getModule());
225   F->setUnnamedAddr(true);
226   F->addFnAttr(llvm::Attribute::NoInline);
227   if (CGM.getCodeGenOpts().DisableRedZone)
228     F->addFnAttr(llvm::Attribute::NoRedZone);
229 
230   // Add the basic block and the necessary calls.
231   CGBuilderTy Builder(llvm::BasicBlock::Create(CGM.getLLVMContext(), "", F));
232   Builder.CreateCall(RegisterF);
233   Builder.CreateRetVoid();
234 
235   return F;
236 }
237 
238 namespace {
239 /// \brief Stable hasher for PGO region counters.
240 ///
241 /// PGOHash produces a stable hash of a given function's control flow.
242 ///
243 /// Changing the output of this hash will invalidate all previously generated
244 /// profiles -- i.e., don't do it.
245 ///
246 /// \note  When this hash does eventually change (years?), we still need to
247 /// support old hashes.  We'll need to pull in the version number from the
248 /// profile data format and use the matching hash function.
249 class PGOHash {
250   uint64_t Working;
251   unsigned Count;
252   llvm::MD5 MD5;
253 
254   static const int NumBitsPerType = 6;
255   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
256   static const unsigned TooBig = 1u << NumBitsPerType;
257 
258 public:
259   /// \brief Hash values for AST nodes.
260   ///
261   /// Distinct values for AST nodes that have region counters attached.
262   ///
263   /// These values must be stable.  All new members must be added at the end,
264   /// and no members should be removed.  Changing the enumeration value for an
265   /// AST node will affect the hash of every function that contains that node.
266   enum HashType : unsigned char {
267     None = 0,
268     LabelStmt = 1,
269     WhileStmt,
270     DoStmt,
271     ForStmt,
272     CXXForRangeStmt,
273     ObjCForCollectionStmt,
274     SwitchStmt,
275     CaseStmt,
276     DefaultStmt,
277     IfStmt,
278     CXXTryStmt,
279     CXXCatchStmt,
280     ConditionalOperator,
281     BinaryOperatorLAnd,
282     BinaryOperatorLOr,
283     BinaryConditionalOperator,
284 
285     // Keep this last.  It's for the static assert that follows.
286     LastHashType
287   };
288   static_assert(LastHashType <= TooBig, "Too many types in HashType");
289 
290   // TODO: When this format changes, take in a version number here, and use the
291   // old hash calculation for file formats that used the old hash.
292   PGOHash() : Working(0), Count(0) {}
293   void combine(HashType Type);
294   uint64_t finalize();
295 };
296 const int PGOHash::NumBitsPerType;
297 const unsigned PGOHash::NumTypesPerWord;
298 const unsigned PGOHash::TooBig;
299 
300   /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
301   struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
302     /// The next counter value to assign.
303     unsigned NextCounter;
304     /// The function hash.
305     PGOHash Hash;
306     /// The map of statements to counters.
307     llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
308 
309     MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
310         : NextCounter(0), CounterMap(CounterMap) {}
311 
312     // Blocks and lambdas are handled as separate functions, so we need not
313     // traverse them in the parent context.
314     bool TraverseBlockExpr(BlockExpr *BE) { return true; }
315     bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
316     bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
317 
318     bool VisitDecl(const Decl *D) {
319       switch (D->getKind()) {
320       default:
321         break;
322       case Decl::Function:
323       case Decl::CXXMethod:
324       case Decl::CXXConstructor:
325       case Decl::CXXDestructor:
326       case Decl::CXXConversion:
327       case Decl::ObjCMethod:
328       case Decl::Block:
329       case Decl::Captured:
330         CounterMap[D->getBody()] = NextCounter++;
331         break;
332       }
333       return true;
334     }
335 
336     bool VisitStmt(const Stmt *S) {
337       auto Type = getHashType(S);
338       if (Type == PGOHash::None)
339         return true;
340 
341       CounterMap[S] = NextCounter++;
342       Hash.combine(Type);
343       return true;
344     }
345     PGOHash::HashType getHashType(const Stmt *S) {
346       switch (S->getStmtClass()) {
347       default:
348         break;
349       case Stmt::LabelStmtClass:
350         return PGOHash::LabelStmt;
351       case Stmt::WhileStmtClass:
352         return PGOHash::WhileStmt;
353       case Stmt::DoStmtClass:
354         return PGOHash::DoStmt;
355       case Stmt::ForStmtClass:
356         return PGOHash::ForStmt;
357       case Stmt::CXXForRangeStmtClass:
358         return PGOHash::CXXForRangeStmt;
359       case Stmt::ObjCForCollectionStmtClass:
360         return PGOHash::ObjCForCollectionStmt;
361       case Stmt::SwitchStmtClass:
362         return PGOHash::SwitchStmt;
363       case Stmt::CaseStmtClass:
364         return PGOHash::CaseStmt;
365       case Stmt::DefaultStmtClass:
366         return PGOHash::DefaultStmt;
367       case Stmt::IfStmtClass:
368         return PGOHash::IfStmt;
369       case Stmt::CXXTryStmtClass:
370         return PGOHash::CXXTryStmt;
371       case Stmt::CXXCatchStmtClass:
372         return PGOHash::CXXCatchStmt;
373       case Stmt::ConditionalOperatorClass:
374         return PGOHash::ConditionalOperator;
375       case Stmt::BinaryConditionalOperatorClass:
376         return PGOHash::BinaryConditionalOperator;
377       case Stmt::BinaryOperatorClass: {
378         const BinaryOperator *BO = cast<BinaryOperator>(S);
379         if (BO->getOpcode() == BO_LAnd)
380           return PGOHash::BinaryOperatorLAnd;
381         if (BO->getOpcode() == BO_LOr)
382           return PGOHash::BinaryOperatorLOr;
383         break;
384       }
385       }
386       return PGOHash::None;
387     }
388   };
389 
390   /// A StmtVisitor that propagates the raw counts through the AST and
391   /// records the count at statements where the value may change.
392   struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
393     /// PGO state.
394     CodeGenPGO &PGO;
395 
396     /// A flag that is set when the current count should be recorded on the
397     /// next statement, such as at the exit of a loop.
398     bool RecordNextStmtCount;
399 
400     /// The map of statements to count values.
401     llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
402 
403     /// BreakContinueStack - Keep counts of breaks and continues inside loops.
404     struct BreakContinue {
405       uint64_t BreakCount;
406       uint64_t ContinueCount;
407       BreakContinue() : BreakCount(0), ContinueCount(0) {}
408     };
409     SmallVector<BreakContinue, 8> BreakContinueStack;
410 
411     ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
412                         CodeGenPGO &PGO)
413         : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
414 
415     void RecordStmtCount(const Stmt *S) {
416       if (RecordNextStmtCount) {
417         CountMap[S] = PGO.getCurrentRegionCount();
418         RecordNextStmtCount = false;
419       }
420     }
421 
422     void VisitStmt(const Stmt *S) {
423       RecordStmtCount(S);
424       for (Stmt::const_child_range I = S->children(); I; ++I) {
425         if (*I)
426          this->Visit(*I);
427       }
428     }
429 
430     void VisitFunctionDecl(const FunctionDecl *D) {
431       // Counter tracks entry to the function body.
432       RegionCounter Cnt(PGO, D->getBody());
433       Cnt.beginRegion();
434       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
435       Visit(D->getBody());
436     }
437 
438     // Skip lambda expressions. We visit these as FunctionDecls when we're
439     // generating them and aren't interested in the body when generating a
440     // parent context.
441     void VisitLambdaExpr(const LambdaExpr *LE) {}
442 
443     void VisitCapturedDecl(const CapturedDecl *D) {
444       // Counter tracks entry to the capture body.
445       RegionCounter Cnt(PGO, D->getBody());
446       Cnt.beginRegion();
447       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
448       Visit(D->getBody());
449     }
450 
451     void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
452       // Counter tracks entry to the method body.
453       RegionCounter Cnt(PGO, D->getBody());
454       Cnt.beginRegion();
455       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
456       Visit(D->getBody());
457     }
458 
459     void VisitBlockDecl(const BlockDecl *D) {
460       // Counter tracks entry to the block body.
461       RegionCounter Cnt(PGO, D->getBody());
462       Cnt.beginRegion();
463       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
464       Visit(D->getBody());
465     }
466 
467     void VisitReturnStmt(const ReturnStmt *S) {
468       RecordStmtCount(S);
469       if (S->getRetValue())
470         Visit(S->getRetValue());
471       PGO.setCurrentRegionUnreachable();
472       RecordNextStmtCount = true;
473     }
474 
475     void VisitGotoStmt(const GotoStmt *S) {
476       RecordStmtCount(S);
477       PGO.setCurrentRegionUnreachable();
478       RecordNextStmtCount = true;
479     }
480 
481     void VisitLabelStmt(const LabelStmt *S) {
482       RecordNextStmtCount = false;
483       // Counter tracks the block following the label.
484       RegionCounter Cnt(PGO, S);
485       Cnt.beginRegion();
486       CountMap[S] = PGO.getCurrentRegionCount();
487       Visit(S->getSubStmt());
488     }
489 
490     void VisitBreakStmt(const BreakStmt *S) {
491       RecordStmtCount(S);
492       assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
493       BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount();
494       PGO.setCurrentRegionUnreachable();
495       RecordNextStmtCount = true;
496     }
497 
498     void VisitContinueStmt(const ContinueStmt *S) {
499       RecordStmtCount(S);
500       assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
501       BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount();
502       PGO.setCurrentRegionUnreachable();
503       RecordNextStmtCount = true;
504     }
505 
506     void VisitWhileStmt(const WhileStmt *S) {
507       RecordStmtCount(S);
508       // Counter tracks the body of the loop.
509       RegionCounter Cnt(PGO, S);
510       BreakContinueStack.push_back(BreakContinue());
511       // Visit the body region first so the break/continue adjustments can be
512       // included when visiting the condition.
513       Cnt.beginRegion();
514       CountMap[S->getBody()] = PGO.getCurrentRegionCount();
515       Visit(S->getBody());
516       Cnt.adjustForControlFlow();
517 
518       // ...then go back and propagate counts through the condition. The count
519       // at the start of the condition is the sum of the incoming edges,
520       // the backedge from the end of the loop body, and the edges from
521       // continue statements.
522       BreakContinue BC = BreakContinueStack.pop_back_val();
523       Cnt.setCurrentRegionCount(Cnt.getParentCount() +
524                                 Cnt.getAdjustedCount() + BC.ContinueCount);
525       CountMap[S->getCond()] = PGO.getCurrentRegionCount();
526       Visit(S->getCond());
527       Cnt.adjustForControlFlow();
528       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
529       RecordNextStmtCount = true;
530     }
531 
532     void VisitDoStmt(const DoStmt *S) {
533       RecordStmtCount(S);
534       // Counter tracks the body of the loop.
535       RegionCounter Cnt(PGO, S);
536       BreakContinueStack.push_back(BreakContinue());
537       Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
538       CountMap[S->getBody()] = PGO.getCurrentRegionCount();
539       Visit(S->getBody());
540       Cnt.adjustForControlFlow();
541 
542       BreakContinue BC = BreakContinueStack.pop_back_val();
543       // The count at the start of the condition is equal to the count at the
544       // end of the body. The adjusted count does not include either the
545       // fall-through count coming into the loop or the continue count, so add
546       // both of those separately. This is coincidentally the same equation as
547       // with while loops but for different reasons.
548       Cnt.setCurrentRegionCount(Cnt.getParentCount() +
549                                 Cnt.getAdjustedCount() + BC.ContinueCount);
550       CountMap[S->getCond()] = PGO.getCurrentRegionCount();
551       Visit(S->getCond());
552       Cnt.adjustForControlFlow();
553       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
554       RecordNextStmtCount = true;
555     }
556 
557     void VisitForStmt(const ForStmt *S) {
558       RecordStmtCount(S);
559       if (S->getInit())
560         Visit(S->getInit());
561       // Counter tracks the body of the loop.
562       RegionCounter Cnt(PGO, S);
563       BreakContinueStack.push_back(BreakContinue());
564       // Visit the body region first. (This is basically the same as a while
565       // loop; see further comments in VisitWhileStmt.)
566       Cnt.beginRegion();
567       CountMap[S->getBody()] = PGO.getCurrentRegionCount();
568       Visit(S->getBody());
569       Cnt.adjustForControlFlow();
570 
571       // The increment is essentially part of the body but it needs to include
572       // the count for all the continue statements.
573       if (S->getInc()) {
574         Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
575                                   BreakContinueStack.back().ContinueCount);
576         CountMap[S->getInc()] = PGO.getCurrentRegionCount();
577         Visit(S->getInc());
578         Cnt.adjustForControlFlow();
579       }
580 
581       BreakContinue BC = BreakContinueStack.pop_back_val();
582 
583       // ...then go back and propagate counts through the condition.
584       if (S->getCond()) {
585         Cnt.setCurrentRegionCount(Cnt.getParentCount() +
586                                   Cnt.getAdjustedCount() +
587                                   BC.ContinueCount);
588         CountMap[S->getCond()] = PGO.getCurrentRegionCount();
589         Visit(S->getCond());
590         Cnt.adjustForControlFlow();
591       }
592       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
593       RecordNextStmtCount = true;
594     }
595 
596     void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
597       RecordStmtCount(S);
598       Visit(S->getRangeStmt());
599       Visit(S->getBeginEndStmt());
600       // Counter tracks the body of the loop.
601       RegionCounter Cnt(PGO, S);
602       BreakContinueStack.push_back(BreakContinue());
603       // Visit the body region first. (This is basically the same as a while
604       // loop; see further comments in VisitWhileStmt.)
605       Cnt.beginRegion();
606       CountMap[S->getLoopVarStmt()] = PGO.getCurrentRegionCount();
607       Visit(S->getLoopVarStmt());
608       Visit(S->getBody());
609       Cnt.adjustForControlFlow();
610 
611       // The increment is essentially part of the body but it needs to include
612       // the count for all the continue statements.
613       Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
614                                 BreakContinueStack.back().ContinueCount);
615       CountMap[S->getInc()] = PGO.getCurrentRegionCount();
616       Visit(S->getInc());
617       Cnt.adjustForControlFlow();
618 
619       BreakContinue BC = BreakContinueStack.pop_back_val();
620 
621       // ...then go back and propagate counts through the condition.
622       Cnt.setCurrentRegionCount(Cnt.getParentCount() +
623                                 Cnt.getAdjustedCount() +
624                                 BC.ContinueCount);
625       CountMap[S->getCond()] = PGO.getCurrentRegionCount();
626       Visit(S->getCond());
627       Cnt.adjustForControlFlow();
628       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
629       RecordNextStmtCount = true;
630     }
631 
632     void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
633       RecordStmtCount(S);
634       Visit(S->getElement());
635       // Counter tracks the body of the loop.
636       RegionCounter Cnt(PGO, S);
637       BreakContinueStack.push_back(BreakContinue());
638       Cnt.beginRegion();
639       CountMap[S->getBody()] = PGO.getCurrentRegionCount();
640       Visit(S->getBody());
641       BreakContinue BC = BreakContinueStack.pop_back_val();
642       Cnt.adjustForControlFlow();
643       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
644       RecordNextStmtCount = true;
645     }
646 
647     void VisitSwitchStmt(const SwitchStmt *S) {
648       RecordStmtCount(S);
649       Visit(S->getCond());
650       PGO.setCurrentRegionUnreachable();
651       BreakContinueStack.push_back(BreakContinue());
652       Visit(S->getBody());
653       // If the switch is inside a loop, add the continue counts.
654       BreakContinue BC = BreakContinueStack.pop_back_val();
655       if (!BreakContinueStack.empty())
656         BreakContinueStack.back().ContinueCount += BC.ContinueCount;
657       // Counter tracks the exit block of the switch.
658       RegionCounter ExitCnt(PGO, S);
659       ExitCnt.beginRegion();
660       RecordNextStmtCount = true;
661     }
662 
663     void VisitCaseStmt(const CaseStmt *S) {
664       RecordNextStmtCount = false;
665       // Counter for this particular case. This counts only jumps from the
666       // switch header and does not include fallthrough from the case before
667       // this one.
668       RegionCounter Cnt(PGO, S);
669       Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
670       CountMap[S] = Cnt.getCount();
671       RecordNextStmtCount = true;
672       Visit(S->getSubStmt());
673     }
674 
675     void VisitDefaultStmt(const DefaultStmt *S) {
676       RecordNextStmtCount = false;
677       // Counter for this default case. This does not include fallthrough from
678       // the previous case.
679       RegionCounter Cnt(PGO, S);
680       Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
681       CountMap[S] = Cnt.getCount();
682       RecordNextStmtCount = true;
683       Visit(S->getSubStmt());
684     }
685 
686     void VisitIfStmt(const IfStmt *S) {
687       RecordStmtCount(S);
688       // Counter tracks the "then" part of an if statement. The count for
689       // the "else" part, if it exists, will be calculated from this counter.
690       RegionCounter Cnt(PGO, S);
691       Visit(S->getCond());
692 
693       Cnt.beginRegion();
694       CountMap[S->getThen()] = PGO.getCurrentRegionCount();
695       Visit(S->getThen());
696       Cnt.adjustForControlFlow();
697 
698       if (S->getElse()) {
699         Cnt.beginElseRegion();
700         CountMap[S->getElse()] = PGO.getCurrentRegionCount();
701         Visit(S->getElse());
702         Cnt.adjustForControlFlow();
703       }
704       Cnt.applyAdjustmentsToRegion(0);
705       RecordNextStmtCount = true;
706     }
707 
708     void VisitCXXTryStmt(const CXXTryStmt *S) {
709       RecordStmtCount(S);
710       Visit(S->getTryBlock());
711       for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
712         Visit(S->getHandler(I));
713       // Counter tracks the continuation block of the try statement.
714       RegionCounter Cnt(PGO, S);
715       Cnt.beginRegion();
716       RecordNextStmtCount = true;
717     }
718 
719     void VisitCXXCatchStmt(const CXXCatchStmt *S) {
720       RecordNextStmtCount = false;
721       // Counter tracks the catch statement's handler block.
722       RegionCounter Cnt(PGO, S);
723       Cnt.beginRegion();
724       CountMap[S] = PGO.getCurrentRegionCount();
725       Visit(S->getHandlerBlock());
726     }
727 
728     void VisitAbstractConditionalOperator(
729         const AbstractConditionalOperator *E) {
730       RecordStmtCount(E);
731       // Counter tracks the "true" part of a conditional operator. The
732       // count in the "false" part will be calculated from this counter.
733       RegionCounter Cnt(PGO, E);
734       Visit(E->getCond());
735 
736       Cnt.beginRegion();
737       CountMap[E->getTrueExpr()] = PGO.getCurrentRegionCount();
738       Visit(E->getTrueExpr());
739       Cnt.adjustForControlFlow();
740 
741       Cnt.beginElseRegion();
742       CountMap[E->getFalseExpr()] = PGO.getCurrentRegionCount();
743       Visit(E->getFalseExpr());
744       Cnt.adjustForControlFlow();
745 
746       Cnt.applyAdjustmentsToRegion(0);
747       RecordNextStmtCount = true;
748     }
749 
750     void VisitBinLAnd(const BinaryOperator *E) {
751       RecordStmtCount(E);
752       // Counter tracks the right hand side of a logical and operator.
753       RegionCounter Cnt(PGO, E);
754       Visit(E->getLHS());
755       Cnt.beginRegion();
756       CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
757       Visit(E->getRHS());
758       Cnt.adjustForControlFlow();
759       Cnt.applyAdjustmentsToRegion(0);
760       RecordNextStmtCount = true;
761     }
762 
763     void VisitBinLOr(const BinaryOperator *E) {
764       RecordStmtCount(E);
765       // Counter tracks the right hand side of a logical or operator.
766       RegionCounter Cnt(PGO, E);
767       Visit(E->getLHS());
768       Cnt.beginRegion();
769       CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
770       Visit(E->getRHS());
771       Cnt.adjustForControlFlow();
772       Cnt.applyAdjustmentsToRegion(0);
773       RecordNextStmtCount = true;
774     }
775   };
776 }
777 
778 void PGOHash::combine(HashType Type) {
779   // Check that we never combine 0 and only have six bits.
780   assert(Type && "Hash is invalid: unexpected type 0");
781   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
782 
783   // Pass through MD5 if enough work has built up.
784   if (Count && Count % NumTypesPerWord == 0) {
785     using namespace llvm::support;
786     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
787     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
788     Working = 0;
789   }
790 
791   // Accumulate the current type.
792   ++Count;
793   Working = Working << NumBitsPerType | Type;
794 }
795 
/// \brief Produce the final 64-bit hash value.
///
/// If no more than \c NumTypesPerWord types were combined, MD5 was never
/// used and the accumulated \c Working word is returned as-is. Otherwise any
/// non-zero remainder in \c Working is passed to MD5, the digest is
/// finalized, and a little-endian \c uint64_t read from the result buffer is
/// returned.
uint64_t PGOHash::finalize() {
  // Use Working as the hash directly if we never used MD5.
  if (Count <= NumTypesPerWord)
    // No need to byte swap here, since none of the math was endian-dependent.
    // This number will be byte-swapped as required on endianness transitions,
    // so we will see the same value on the other side.
    return Working;

  // Check for remaining work in Working.
  // NOTE(review): combine() byte-swaps Working and hashes all
  // sizeof(uint64_t) bytes via makeArrayRef, whereas this call hands the raw
  // uint64_t to MD5.update(). Confirm which overload/conversion is selected
  // and whether the whole word is hashed. Any change here alters the hash
  // values produced, so it would presumably affect matching against
  // previously generated profile data.
  if (Working)
    MD5.update(Working);

  // Finalize the MD5 and return the hash.
  llvm::MD5::MD5Result Result;
  MD5.final(Result);
  using namespace llvm::support;
  return endian::read<uint64_t, little, unaligned>(Result);
}
814 
815 static void emitRuntimeHook(CodeGenModule &CGM) {
816   const char *const RuntimeVarName = "__llvm_profile_runtime";
817   const char *const RuntimeUserName = "__llvm_profile_runtime_user";
818   if (CGM.getModule().getGlobalVariable(RuntimeVarName))
819     return;
820 
821   // Declare the runtime hook.
822   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
823   auto *Int32Ty = llvm::Type::getInt32Ty(Ctx);
824   auto *Var = new llvm::GlobalVariable(CGM.getModule(), Int32Ty, false,
825                                        llvm::GlobalValue::ExternalLinkage,
826                                        nullptr, RuntimeVarName);
827 
828   // Make a function that uses it.
829   auto *User = llvm::Function::Create(llvm::FunctionType::get(Int32Ty, false),
830                                       llvm::GlobalValue::LinkOnceODRLinkage,
831                                       RuntimeUserName, &CGM.getModule());
832   User->addFnAttr(llvm::Attribute::NoInline);
833   if (CGM.getCodeGenOpts().DisableRedZone)
834     User->addFnAttr(llvm::Attribute::NoRedZone);
835   CGBuilderTy Builder(llvm::BasicBlock::Create(CGM.getLLVMContext(), "", User));
836   auto *Load = Builder.CreateLoad(Var);
837   Builder.CreateRet(Load);
838 
839   // Create a use of the function.  Now the definition of the runtime variable
840   // should get pulled in, along with any static initializears.
841   CGM.addUsedGlobal(User);
842 }
843 
844 void CodeGenPGO::checkGlobalDecl(GlobalDecl GD) {
845   // Make sure we only emit coverage mapping for one constructor/destructor.
846   // Clang emits several functions for the constructor and the destructor of
847   // a class. Every function is instrumented, but we only want to provide
848   // coverage for one of them. Because of that we only emit the coverage mapping
849   // for the base constructor/destructor.
850   if ((isa<CXXConstructorDecl>(GD.getDecl()) &&
851        GD.getCtorType() != Ctor_Base) ||
852       (isa<CXXDestructorDecl>(GD.getDecl()) &&
853        GD.getDtorType() != Dtor_Base)) {
854     SkipCoverageMapping = true;
855   }
856 }
857 
858 void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) {
859   bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
860   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
861   if (!InstrumentRegions && !PGOReader)
862     return;
863   if (D->isImplicit())
864     return;
865   CGM.ClearUnusedCoverageMapping(D);
866   setFuncName(Fn);
867   setVarLinkage(Fn->getLinkage());
868 
869   mapRegionCounters(D);
870   if (InstrumentRegions) {
871     emitRuntimeHook(CGM);
872     emitCounterVariables();
873     if (CGM.getCodeGenOpts().CoverageMapping)
874       emitCounterRegionMapping(D);
875   }
876   if (PGOReader) {
877     SourceManager &SM = CGM.getContext().getSourceManager();
878     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
879     computeRegionCounts(D);
880     applyFunctionAttributes(PGOReader, Fn);
881   }
882 }
883 
884 void CodeGenPGO::mapRegionCounters(const Decl *D) {
885   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
886   MapRegionCounters Walker(*RegionCounterMap);
887   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
888     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
889   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
890     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
891   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
892     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
893   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
894     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
895   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
896   NumRegionCounters = Walker.NextCounter;
897   FunctionHash = Walker.Hash.finalize();
898 }
899 
900 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
901   if (SkipCoverageMapping)
902     return;
903   // Don't map the functions inside the system headers
904   auto Loc = D->getBody()->getLocStart();
905   if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
906     return;
907 
908   llvm::raw_string_ostream OS(CoverageMapping);
909   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
910                                 CGM.getContext().getSourceManager(),
911                                 CGM.getLangOpts(), RegionCounterMap.get());
912   MappingGen.emitCounterMapping(D, OS);
913   OS.flush();
914 }
915 
916 void
917 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef FuncName,
918                                     llvm::GlobalValue::LinkageTypes Linkage) {
919   if (SkipCoverageMapping)
920     return;
921   setFuncName(FuncName, Linkage);
922   setVarLinkage(Linkage);
923 
924   // Don't map the functions inside the system headers
925   auto Loc = D->getBody()->getLocStart();
926   if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
927     return;
928 
929   llvm::raw_string_ostream OS(CoverageMapping);
930   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
931                                 CGM.getContext().getSourceManager(),
932                                 CGM.getLangOpts());
933   MappingGen.emitEmptyMapping(D, OS);
934   OS.flush();
935   buildDataVar();
936 }
937 
938 void CodeGenPGO::computeRegionCounts(const Decl *D) {
939   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
940   ComputeRegionCounts Walker(*StmtCountMap, *this);
941   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
942     Walker.VisitFunctionDecl(FD);
943   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
944     Walker.VisitObjCMethodDecl(MD);
945   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
946     Walker.VisitBlockDecl(BD);
947   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
948     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
949 }
950 
951 void
952 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
953                                     llvm::Function *Fn) {
954   if (!haveRegionCounts())
955     return;
956 
957   uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount();
958   uint64_t FunctionCount = getRegionCount(0);
959   if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount))
960     // Turn on InlineHint attribute for hot functions.
961     // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal.
962     Fn->addFnAttr(llvm::Attribute::InlineHint);
963   else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount))
964     // Turn on Cold attribute for cold functions.
965     // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal.
966     Fn->addFnAttr(llvm::Attribute::Cold);
967 }
968 
969 void CodeGenPGO::emitCounterVariables() {
970   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
971   llvm::ArrayType *CounterTy = llvm::ArrayType::get(llvm::Type::getInt64Ty(Ctx),
972                                                     NumRegionCounters);
973   RegionCounters =
974     new llvm::GlobalVariable(CGM.getModule(), CounterTy, false, VarLinkage,
975                              llvm::Constant::getNullValue(CounterTy),
976                              getFuncVarName("counters"));
977   RegionCounters->setAlignment(8);
978   RegionCounters->setSection(getCountersSection(CGM));
979 }
980 
981 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) {
982   if (!RegionCounters)
983     return;
984   llvm::Value *Addr =
985     Builder.CreateConstInBoundsGEP2_64(RegionCounters, 0, Counter);
986   llvm::Value *Count = Builder.CreateLoad(Addr, "pgocount");
987   Count = Builder.CreateAdd(Count, Builder.getInt64(1));
988   Builder.CreateStore(Count, Addr);
989 }
990 
991 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
992                                   bool IsInMainFile) {
993   CGM.getPGOStats().addVisited(IsInMainFile);
994   RegionCounts.reset(new std::vector<uint64_t>);
995   if (std::error_code EC = PGOReader->getFunctionCounts(
996           getFuncName(), FunctionHash, *RegionCounts)) {
997     if (EC == llvm::instrprof_error::unknown_function)
998       CGM.getPGOStats().addMissing(IsInMainFile);
999     else if (EC == llvm::instrprof_error::hash_mismatch)
1000       CGM.getPGOStats().addMismatched(IsInMainFile);
1001     else if (EC == llvm::instrprof_error::malformed)
1002       // TODO: Consider a more specific warning for this case.
1003       CGM.getPGOStats().addMismatched(IsInMainFile);
1004     RegionCounts.reset();
1005   }
1006 }
1007 
1008 void CodeGenPGO::destroyRegionCounters() {
1009   RegionCounterMap.reset();
1010   StmtCountMap.reset();
1011   RegionCounts.reset();
1012   RegionCounters = nullptr;
1013 }
1014 
/// \brief Compute the divisor used to scale weights down to 32 bits.
///
/// \param MaxWeight the largest weight that will be scaled.
/// \returns 1 when \p MaxWeight is already below UINT32_MAX; otherwise a
/// divisor that brings every weight up to \p MaxWeight strictly below
/// UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  // Small enough already: no scaling needed.
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
1022 
/// \brief Scale a single branch weight to 32 bits (and add 1).
///
/// Divides the 64-bit \p Weight by \p Scale and adds one. Following
/// Laplace's Rule of Succession, weights are better computed from the count
/// plus 1, so the 1 is added unconditionally.
///
/// \pre \p Scale is non-zero and was calculated by \a calculateWeightScale()
/// from a maximum weight that is at least \p Weight, so the result fits in
/// 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Scaled);
}
1038 
1039 llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount,
1040                                               uint64_t FalseCount) {
1041   // Check for empty weights.
1042   if (!TrueCount && !FalseCount)
1043     return nullptr;
1044 
1045   // Calculate how to scale down to 32-bits.
1046   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1047 
1048   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1049   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1050                                       scaleBranchWeight(FalseCount, Scale));
1051 }
1052 
1053 llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) {
1054   // We need at least two elements to create meaningful weights.
1055   if (Weights.size() < 2)
1056     return nullptr;
1057 
1058   // Check for empty weights.
1059   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1060   if (MaxWeight == 0)
1061     return nullptr;
1062 
1063   // Calculate how to scale down to 32-bits.
1064   uint64_t Scale = calculateWeightScale(MaxWeight);
1065 
1066   SmallVector<uint32_t, 16> ScaledWeights;
1067   ScaledWeights.reserve(Weights.size());
1068   for (uint64_t W : Weights)
1069     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1070 
1071   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1072   return MDHelper.createBranchWeights(ScaledWeights);
1073 }
1074 
1075 llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond,
1076                                             RegionCounter &Cnt) {
1077   if (!haveRegionCounts())
1078     return nullptr;
1079   uint64_t LoopCount = Cnt.getCount();
1080   uint64_t CondCount = 0;
1081   bool Found = getStmtCount(Cond, CondCount);
1082   assert(Found && "missing expected loop condition count");
1083   (void)Found;
1084   if (CondCount == 0)
1085     return nullptr;
1086   return createBranchWeights(LoopCount,
1087                              std::max(CondCount, LoopCount) - LoopCount);
1088 }
1089