1 //===- BranchProbabilityInfo.cpp - Branch Probability Analysis ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Loops should be simplified before this analysis.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "llvm/Analysis/BranchProbabilityInfo.h"
14 #include "llvm/ADT/PostOrderIterator.h"
15 #include "llvm/ADT/SCCIterator.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/Analysis/LoopInfo.h"
19 #include "llvm/Analysis/PostDominators.h"
20 #include "llvm/Analysis/TargetLibraryInfo.h"
21 #include "llvm/IR/Attributes.h"
22 #include "llvm/IR/BasicBlock.h"
23 #include "llvm/IR/CFG.h"
24 #include "llvm/IR/Constants.h"
25 #include "llvm/IR/Dominators.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/InstrTypes.h"
28 #include "llvm/IR/Instruction.h"
29 #include "llvm/IR/Instructions.h"
30 #include "llvm/IR/LLVMContext.h"
31 #include "llvm/IR/Metadata.h"
32 #include "llvm/IR/PassManager.h"
33 #include "llvm/IR/Type.h"
34 #include "llvm/IR/Value.h"
35 #include "llvm/InitializePasses.h"
36 #include "llvm/Pass.h"
37 #include "llvm/Support/BranchProbability.h"
38 #include "llvm/Support/Casting.h"
39 #include "llvm/Support/CommandLine.h"
40 #include "llvm/Support/Debug.h"
41 #include "llvm/Support/raw_ostream.h"
42 #include <cassert>
43 #include <cstdint>
44 #include <iterator>
45 #include <utility>
46 
47 using namespace llvm;
48 
49 #define DEBUG_TYPE "branch-prob"
50 
// Debugging switch "print-bpi": asks for the computed branch probability info
// to be printed. Defaults to false and is marked cl::Hidden.
static cl::opt<bool> PrintBranchProb(
    "print-bpi", cl::init(false), cl::Hidden,
    cl::desc("Print the branch probability info."));

// Companion switch "print-bpi-func-name": names the function whose branch
// probability info is printed (see the description string below).
// NOTE(review): unlike PrintBranchProb this option is not static; presumably
// it is referenced from another translation unit -- confirm before changing
// its linkage.
cl::opt<std::string> PrintBranchProbFuncName(
    "print-bpi-func-name", cl::Hidden,
    cl::desc("The option to specify the name of the function "
             "whose branch probability info is printed."));
59 
// Register the legacy wrapper pass under the command line name "branch-prob"
// and list the passes it depends on: loop info, target library info and the
// post-dominator tree.
INITIALIZE_PASS_BEGIN(BranchProbabilityInfoWrapperPass, "branch-prob",
                      "Branch Probability Analysis", false, true)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
INITIALIZE_PASS_END(BranchProbabilityInfoWrapperPass, "branch-prob",
                    "Branch Probability Analysis", false, true)
67 
68 BranchProbabilityInfoWrapperPass::BranchProbabilityInfoWrapperPass()
69     : FunctionPass(ID) {
70   initializeBranchProbabilityInfoWrapperPassPass(
71       *PassRegistry::getPassRegistry());
72 }
73 
// Storage for the pass ID; the constructor hands it to the FunctionPass base
// class.
char BranchProbabilityInfoWrapperPass::ID = 0;
75 
// Weights are for internal use only. They are used by heuristics to help to
// estimate edges' probability. Example:
//
// Using "Loop Branch Heuristics" we predict weights of edges for the
// block BB2.
//         ...
//          |
//          V
//         BB1<-+
//          |   |
//          |   | (Weight = 124)
//          V   |
//         BB2--+
//          |
//          | (Weight = 4)
//          V
//         BB3
//
// Probability of the edge BB2->BB1 = 124 / (124 + 4) = 0.96875
// Probability of the edge BB2->BB3 = 4 / (124 + 4) = 0.03125
//
// Loop Branch Heuristics: weight of the edge that stays in the loop (the
// BB2->BB1 edge in the picture above) ...
static const uint32_t LBH_TAKEN_WEIGHT = 124;
// ... and weight of the edge that leaves the loop (BB2->BB3 above).
static const uint32_t LBH_NONTAKEN_WEIGHT = 4;
// Unlikely edges within a loop are half as likely as other edges
// (62 == LBH_TAKEN_WEIGHT / 2).
static const uint32_t LBH_UNLIKELY_WEIGHT = 62;

/// Unreachable-terminating branch taken probability.
///
/// This is the probability for a branch being taken to a block that terminates
/// (eventually) in unreachable. These are predicted as unlikely as possible.
/// All reachable probability will proportionally share the remaining part.
static const BranchProbability UR_TAKEN_PROB = BranchProbability::getRaw(1);

/// Weight for a branch taken going into a cold block.
///
/// This is the weight for a branch taken toward a block marked
/// cold.  A block is marked cold if it's postdominated by a
/// block containing a call to a cold function.  Cold functions
/// are those marked with attribute 'cold'.
static const uint32_t CC_TAKEN_WEIGHT = 4;

/// Weight for a branch not-taken into a cold block.
///
/// This is the weight for a branch not taken toward a block marked
/// cold.
static const uint32_t CC_NONTAKEN_WEIGHT = 64;

/// Pointer heuristics weights, used by calcPointerHeuristics() for a
/// conditional branch on a pointer equality comparison. For `!=` the first
/// successor gets PH_TAKEN_WEIGHT and the second PH_NONTAKEN_WEIGHT; for `==`
/// the two are swapped.
static const uint32_t PH_TAKEN_WEIGHT = 20;
static const uint32_t PH_NONTAKEN_WEIGHT = 12;

/// NOTE(review): no user of the ZH_* weights is visible in this part of the
/// file; presumably they belong to a "zero heuristics" for integer comparisons
/// against zero -- confirm against the function that uses them.
static const uint32_t ZH_TAKEN_WEIGHT = 20;
static const uint32_t ZH_NONTAKEN_WEIGHT = 12;

/// NOTE(review): no user of the FPH_* weights is visible in this part of the
/// file; judging by the FPH_ORD/FPH_UNO comments below they are the default
/// weights of a floating point comparison heuristic -- confirm.
static const uint32_t FPH_TAKEN_WEIGHT = 20;
static const uint32_t FPH_NONTAKEN_WEIGHT = 12;

/// This is the weight for an ordered floating point comparison.
static const uint32_t FPH_ORD_WEIGHT = 1024 * 1024 - 1;
/// This is the weight for an unordered floating point comparison, it means
/// one or two of the operands are NaN. Usually it is used to test for an
/// exceptional case, so the result is unlikely.
static const uint32_t FPH_UNO_WEIGHT = 1;

/// Invoke-terminating normal branch taken weight
///
/// This is the weight for branching to the normal destination of an invoke
/// instruction. We expect this to happen most of the time. Set the weight to an
/// absurdly high value so that nested loops subsume it.
static const uint32_t IH_TAKEN_WEIGHT = 1024 * 1024 - 1;

/// Invoke-terminating normal branch not-taken weight.
///
/// This is the weight for branching to the unwind destination of an invoke
/// instruction. This is essentially never taken.
static const uint32_t IH_NONTAKEN_WEIGHT = 1;
150 
151 BranchProbabilityInfo::SccInfo::SccInfo(const Function &F) {
152   // Record SCC numbers of blocks in the CFG to identify irreducible loops.
153   // FIXME: We could only calculate this if the CFG is known to be irreducible
154   // (perhaps cache this info in LoopInfo if we can easily calculate it there?).
155   int SccNum = 0;
156   for (scc_iterator<const Function *> It = scc_begin(&F); !It.isAtEnd();
157        ++It, ++SccNum) {
158     // Ignore single-block SCCs since they either aren't loops or LoopInfo will
159     // catch them.
160     const std::vector<const BasicBlock *> &Scc = *It;
161     if (Scc.size() == 1)
162       continue;
163 
164     LLVM_DEBUG(dbgs() << "BPI: SCC " << SccNum << ":");
165     for (const auto *BB : Scc) {
166       LLVM_DEBUG(dbgs() << " " << BB->getName());
167       SccNums[BB] = SccNum;
168       calculateSccBlockType(BB, SccNum);
169     }
170     LLVM_DEBUG(dbgs() << "\n");
171   }
172 }
173 
174 int BranchProbabilityInfo::SccInfo::getSCCNum(const BasicBlock *BB) const {
175   auto SccIt = SccNums.find(BB);
176   if (SccIt == SccNums.end())
177     return -1;
178   return SccIt->second;
179 }
180 
181 void BranchProbabilityInfo::SccInfo::getSccEnterBlocks(
182     int SccNum, SmallVectorImpl<BasicBlock *> &Enters) const {
183 
184   for (auto MapIt : SccBlocks[SccNum]) {
185     const auto *BB = MapIt.first;
186     if (isSCCHeader(BB, SccNum))
187       for (const auto *Pred : predecessors(BB))
188         if (getSCCNum(Pred) != SccNum)
189           Enters.push_back(const_cast<BasicBlock *>(BB));
190   }
191 }
192 
193 void BranchProbabilityInfo::SccInfo::getSccExitBlocks(
194     int SccNum, SmallVectorImpl<BasicBlock *> &Exits) const {
195   for (auto MapIt : SccBlocks[SccNum]) {
196     const auto *BB = MapIt.first;
197     if (isSCCExitingBlock(BB, SccNum))
198       for (const auto *Succ : successors(BB))
199         if (getSCCNum(Succ) != SccNum)
200           Exits.push_back(const_cast<BasicBlock *>(BB));
201   }
202 }
203 
204 uint32_t BranchProbabilityInfo::SccInfo::getSccBlockType(const BasicBlock *BB,
205                                                          int SccNum) const {
206   assert(getSCCNum(BB) == SccNum);
207 
208   assert(SccBlocks.size() > static_cast<unsigned>(SccNum) && "Unknown SCC");
209   const auto &SccBlockTypes = SccBlocks[SccNum];
210 
211   auto It = SccBlockTypes.find(BB);
212   if (It != SccBlockTypes.end()) {
213     return It->second;
214   }
215   return Inner;
216 }
217 
218 void BranchProbabilityInfo::SccInfo::calculateSccBlockType(const BasicBlock *BB,
219                                                            int SccNum) {
220   assert(getSCCNum(BB) == SccNum);
221   uint32_t BlockType = Inner;
222 
223   if (llvm::any_of(predecessors(BB), [&](const BasicBlock *Pred) {
224         // Consider any block that is an entry point to the SCC as
225         // a header.
226         return getSCCNum(Pred) != SccNum;
227       }))
228     BlockType |= Header;
229 
230   if (llvm::any_of(successors(BB), [&](const BasicBlock *Succ) {
231         return getSCCNum(Succ) != SccNum;
232       }))
233     BlockType |= Exiting;
234 
235   // Lazily compute the set of headers for a given SCC and cache the results
236   // in the SccHeaderMap.
237   if (SccBlocks.size() <= static_cast<unsigned>(SccNum))
238     SccBlocks.resize(SccNum + 1);
239   auto &SccBlockTypes = SccBlocks[SccNum];
240 
241   if (BlockType != Inner) {
242     bool IsInserted;
243     std::tie(std::ignore, IsInserted) =
244         SccBlockTypes.insert(std::make_pair(BB, BlockType));
245     assert(IsInserted && "Duplicated block in SCC");
246   }
247 }
248 
249 BranchProbabilityInfo::LoopBlock::LoopBlock(const BasicBlock *BB,
250                                             const LoopInfo &LI,
251                                             const SccInfo &SccI)
252     : BB(BB) {
253   LD.first = LI.getLoopFor(BB);
254   if (!LD.first) {
255     LD.second = SccI.getSCCNum(BB);
256   }
257 }
258 
259 bool BranchProbabilityInfo::isLoopEnteringEdge(const LoopEdge &Edge) const {
260   const auto &SrcBlock = Edge.first;
261   const auto &DstBlock = Edge.second;
262   return (DstBlock.getLoop() &&
263           !DstBlock.getLoop()->contains(SrcBlock.getLoop())) ||
264          // Assume that SCCs can't be nested.
265          (DstBlock.getSccNum() != -1 &&
266           SrcBlock.getSccNum() != DstBlock.getSccNum());
267 }
268 
269 bool BranchProbabilityInfo::isLoopExitingEdge(const LoopEdge &Edge) const {
270   return isLoopEnteringEdge({Edge.second, Edge.first});
271 }
272 
273 bool BranchProbabilityInfo::isLoopEnteringExitingEdge(
274     const LoopEdge &Edge) const {
275   return isLoopEnteringEdge(Edge) || isLoopExitingEdge(Edge);
276 }
277 
278 bool BranchProbabilityInfo::isLoopBackEdge(const LoopEdge &Edge) const {
279   const auto &SrcBlock = Edge.first;
280   const auto &DstBlock = Edge.second;
281   return SrcBlock.belongsToSameLoop(DstBlock) &&
282          ((DstBlock.getLoop() &&
283            DstBlock.getLoop()->getHeader() == DstBlock.getBlock()) ||
284           (DstBlock.getSccNum() != -1 &&
285            SccI->isSCCHeader(DstBlock.getBlock(), DstBlock.getSccNum())));
286 }
287 
288 void BranchProbabilityInfo::getLoopEnterBlocks(
289     const LoopBlock &LB, SmallVectorImpl<BasicBlock *> &Enters) const {
290   if (LB.getLoop()) {
291     auto *Header = LB.getLoop()->getHeader();
292     Enters.append(pred_begin(Header), pred_end(Header));
293   } else {
294     assert(LB.getSccNum() != -1 && "LB doesn't belong to any loop?");
295     SccI->getSccEnterBlocks(LB.getSccNum(), Enters);
296   }
297 }
298 
299 void BranchProbabilityInfo::getLoopExitBlocks(
300     const LoopBlock &LB, SmallVectorImpl<BasicBlock *> &Exits) const {
301   if (LB.getLoop()) {
302     LB.getLoop()->getExitBlocks(Exits);
303   } else {
304     assert(LB.getSccNum() != -1 && "LB doesn't belong to any loop?");
305     SccI->getSccExitBlocks(LB.getSccNum(), Exits);
306   }
307 }
308 
309 static void UpdatePDTWorklist(const BasicBlock *BB, PostDominatorTree *PDT,
310                               SmallVectorImpl<const BasicBlock *> &WorkList,
311                               SmallPtrSetImpl<const BasicBlock *> &TargetSet) {
312   SmallVector<BasicBlock *, 8> Descendants;
313   SmallPtrSet<const BasicBlock *, 16> NewItems;
314 
315   PDT->getDescendants(const_cast<BasicBlock *>(BB), Descendants);
316   for (auto *BB : Descendants)
317     if (TargetSet.insert(BB).second)
318       for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI)
319         if (!TargetSet.count(*PI))
320           NewItems.insert(*PI);
321   WorkList.insert(WorkList.end(), NewItems.begin(), NewItems.end());
322 }
323 
324 /// Compute a set of basic blocks that are post-dominated by unreachables.
325 void BranchProbabilityInfo::computePostDominatedByUnreachable(
326     const Function &F, PostDominatorTree *PDT) {
327   SmallVector<const BasicBlock *, 8> WorkList;
328   for (auto &BB : F) {
329     const Instruction *TI = BB.getTerminator();
330     if (TI->getNumSuccessors() == 0) {
331       if (isa<UnreachableInst>(TI) ||
332           // If this block is terminated by a call to
333           // @llvm.experimental.deoptimize then treat it like an unreachable
334           // since the @llvm.experimental.deoptimize call is expected to
335           // practically never execute.
336           BB.getTerminatingDeoptimizeCall())
337         UpdatePDTWorklist(&BB, PDT, WorkList, PostDominatedByUnreachable);
338     }
339   }
340 
341   while (!WorkList.empty()) {
342     const BasicBlock *BB = WorkList.pop_back_val();
343     if (PostDominatedByUnreachable.count(BB))
344       continue;
345     // If the terminator is an InvokeInst, check only the normal destination
346     // block as the unwind edge of InvokeInst is also very unlikely taken.
347     if (auto *II = dyn_cast<InvokeInst>(BB->getTerminator())) {
348       if (PostDominatedByUnreachable.count(II->getNormalDest()))
349         UpdatePDTWorklist(BB, PDT, WorkList, PostDominatedByUnreachable);
350     }
351     // If all the successors are unreachable, BB is unreachable as well.
352     else if (!successors(BB).empty() &&
353              llvm::all_of(successors(BB), [this](const BasicBlock *Succ) {
354                return PostDominatedByUnreachable.count(Succ);
355              }))
356       UpdatePDTWorklist(BB, PDT, WorkList, PostDominatedByUnreachable);
357   }
358 }
359 
360 /// compute a set of basic blocks that are post-dominated by ColdCalls.
361 void BranchProbabilityInfo::computePostDominatedByColdCall(
362     const Function &F, PostDominatorTree *PDT) {
363   SmallVector<const BasicBlock *, 8> WorkList;
364   for (auto &BB : F)
365     for (auto &I : BB)
366       if (const CallInst *CI = dyn_cast<CallInst>(&I))
367         if (CI->hasFnAttr(Attribute::Cold))
368           UpdatePDTWorklist(&BB, PDT, WorkList, PostDominatedByColdCall);
369 
370   while (!WorkList.empty()) {
371     const BasicBlock *BB = WorkList.pop_back_val();
372 
373     // If the terminator is an InvokeInst, check only the normal destination
374     // block as the unwind edge of InvokeInst is also very unlikely taken.
375     if (auto *II = dyn_cast<InvokeInst>(BB->getTerminator())) {
376       if (PostDominatedByColdCall.count(II->getNormalDest()))
377         UpdatePDTWorklist(BB, PDT, WorkList, PostDominatedByColdCall);
378     }
379     // If all of successor are post dominated then BB is also done.
380     else if (!successors(BB).empty() &&
381              llvm::all_of(successors(BB), [this](const BasicBlock *Succ) {
382                return PostDominatedByColdCall.count(Succ);
383              }))
384       UpdatePDTWorklist(BB, PDT, WorkList, PostDominatedByColdCall);
385   }
386 }
387 
388 /// Calculate edge weights for successors lead to unreachable.
389 ///
390 /// Predict that a successor which leads necessarily to an
391 /// unreachable-terminated block as extremely unlikely.
392 bool BranchProbabilityInfo::calcUnreachableHeuristics(const BasicBlock *BB) {
393   const Instruction *TI = BB->getTerminator();
394   (void) TI;
395   assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
396   assert(!isa<InvokeInst>(TI) &&
397          "Invokes should have already been handled by calcInvokeHeuristics");
398 
399   SmallVector<unsigned, 4> UnreachableEdges;
400   SmallVector<unsigned, 4> ReachableEdges;
401 
402   for (const_succ_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I)
403     if (PostDominatedByUnreachable.count(*I))
404       UnreachableEdges.push_back(I.getSuccessorIndex());
405     else
406       ReachableEdges.push_back(I.getSuccessorIndex());
407 
408   // Skip probabilities if all were reachable.
409   if (UnreachableEdges.empty())
410     return false;
411 
412   SmallVector<BranchProbability, 4> EdgeProbabilities(
413       BB->getTerminator()->getNumSuccessors(), BranchProbability::getUnknown());
414   if (ReachableEdges.empty()) {
415     BranchProbability Prob(1, UnreachableEdges.size());
416     for (unsigned SuccIdx : UnreachableEdges)
417       EdgeProbabilities[SuccIdx] = Prob;
418     setEdgeProbability(BB, EdgeProbabilities);
419     return true;
420   }
421 
422   auto UnreachableProb = UR_TAKEN_PROB;
423   auto ReachableProb =
424       (BranchProbability::getOne() - UR_TAKEN_PROB * UnreachableEdges.size()) /
425       ReachableEdges.size();
426 
427   for (unsigned SuccIdx : UnreachableEdges)
428     EdgeProbabilities[SuccIdx] = UnreachableProb;
429   for (unsigned SuccIdx : ReachableEdges)
430     EdgeProbabilities[SuccIdx] = ReachableProb;
431 
432   setEdgeProbability(BB, EdgeProbabilities);
433   return true;
434 }
435 
// Propagate existing explicit probabilities from either profile data or
// 'expect' intrinsic processing. Examine metadata against unreachable
// heuristic. The probability of the edge coming to unreachable block is
// set to min of metadata and unreachable heuristic.
//
// Returns false, without setting any probability, when the terminator is not a
// branch, switch, indirectbr or invoke, when it carries no MD_prof metadata,
// when the number of weights does not match the number of successors, or when
// a weight operand is not a constant integer. Returns true after the edge
// probabilities of BB have been set.
bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) {
  const Instruction *TI = BB->getTerminator();
  assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
  if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) || isa<IndirectBrInst>(TI) ||
        isa<InvokeInst>(TI)))
    return false;

  MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof);
  if (!WeightsNode)
    return false;

  // Check that the number of successors is manageable.
  assert(TI->getNumSuccessors() < UINT32_MAX && "Too many successors");

  // Ensure there are weights for all of the successors. Note that the first
  // operand to the metadata node is a name, not a weight.
  if (WeightsNode->getNumOperands() != TI->getNumSuccessors() + 1)
    return false;

  // Build up the final weights that will be used in a temporary buffer.
  // Compute the sum of all weights to later decide whether they need to
  // be scaled to fit in 32 bits.
  uint64_t WeightSum = 0;
  SmallVector<uint32_t, 2> Weights;
  SmallVector<unsigned, 2> UnreachableIdxs;
  SmallVector<unsigned, 2> ReachableIdxs;
  Weights.reserve(TI->getNumSuccessors());
  // Operand I of the metadata node holds the weight of successor I - 1.
  for (unsigned I = 1, E = WeightsNode->getNumOperands(); I != E; ++I) {
    ConstantInt *Weight =
        mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(I));
    if (!Weight)
      return false;
    assert(Weight->getValue().getActiveBits() <= 32 &&
           "Too many bits for uint32_t");
    Weights.push_back(Weight->getZExtValue());
    WeightSum += Weights.back();
    if (PostDominatedByUnreachable.count(TI->getSuccessor(I - 1)))
      UnreachableIdxs.push_back(I - 1);
    else
      ReachableIdxs.push_back(I - 1);
  }
  assert(Weights.size() == TI->getNumSuccessors() && "Checked above");

  // If the sum of weights does not fit in 32 bits, scale every weight down
  // accordingly.
  uint64_t ScalingFactor =
      (WeightSum > UINT32_MAX) ? WeightSum / UINT32_MAX + 1 : 1;

  if (ScalingFactor > 1) {
    // Recompute the sum from the scaled (integer-divided) weights.
    WeightSum = 0;
    for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I) {
      Weights[I] /= ScalingFactor;
      WeightSum += Weights[I];
    }
  }
  assert(WeightSum <= UINT32_MAX &&
         "Expected weights to scale down to 32 bits");

  // With an all-zero profile, or when every successor leads to unreachable,
  // fall back to equal weights for all successors.
  if (WeightSum == 0 || ReachableIdxs.size() == 0) {
    for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I)
      Weights[I] = 1;
    WeightSum = TI->getNumSuccessors();
  }

  // Set the probability.
  SmallVector<BranchProbability, 2> BP;
  for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I)
    BP.push_back({ Weights[I], static_cast<uint32_t>(WeightSum) });

  // Examine the metadata against unreachable heuristic.
  // If the unreachable heuristic is stronger then we use it for this edge.
  // When all edges are of the same kind there is nothing to adjust.
  if (UnreachableIdxs.size() == 0 || ReachableIdxs.size() == 0) {
    setEdgeProbability(BB, BP);
    return true;
  }

  // Cap the probability of every edge leading to unreachable at
  // UR_TAKEN_PROB.
  auto UnreachableProb = UR_TAKEN_PROB;
  for (auto I : UnreachableIdxs)
    if (UnreachableProb < BP[I]) {
      BP[I] = UnreachableProb;
    }

  // Sum of all edge probabilities must be 1.0. If we modified the probability
  // of some edges then we must distribute the introduced difference over the
  // reachable blocks.
  //
  // Proportional distribution: the relation between probabilities of the
  // reachable edges is kept unchanged. That is for any reachable edges i and j:
  //   newBP[i] / newBP[j] == oldBP[i] / oldBP[j] =>
  //   newBP[i] / oldBP[i] == newBP[j] / oldBP[j] == K
  // Where K is independent of i,j.
  //   newBP[i] == oldBP[i] * K
  // We need to find K.
  // Make sum of all reachables of the left and right parts:
  //   sum_of_reachable(newBP) == K * sum_of_reachable(oldBP)
  // Sum of newBP must be equal to 1.0:
  //   sum_of_reachable(newBP) + sum_of_unreachable(newBP) == 1.0 =>
  //   sum_of_reachable(newBP) = 1.0 - sum_of_unreachable(newBP)
  // Where sum_of_unreachable(newBP) is what has been just changed.
  // Finally:
  //   K == sum_of_reachable(newBP) / sum_of_reachable(oldBP) =>
  //   K == (1.0 - sum_of_unreachable(newBP)) / sum_of_reachable(oldBP)
  BranchProbability NewUnreachableSum = BranchProbability::getZero();
  for (auto I : UnreachableIdxs)
    NewUnreachableSum += BP[I];

  BranchProbability NewReachableSum =
      BranchProbability::getOne() - NewUnreachableSum;

  BranchProbability OldReachableSum = BranchProbability::getZero();
  for (auto I : ReachableIdxs)
    OldReachableSum += BP[I];

  if (OldReachableSum != NewReachableSum) { // Anything to distribute?
    if (OldReachableSum.isZero()) {
      // If all oldBP[i] are zeroes then the proportional distribution results
      // in all zero probabilities and the error stays big. In this case we
      // evenly spread NewReachableSum over the reachable edges.
      BranchProbability PerEdge = NewReachableSum / ReachableIdxs.size();
      for (auto I : ReachableIdxs)
        BP[I] = PerEdge;
    } else {
      for (auto I : ReachableIdxs) {
        // We use uint64_t to avoid double rounding error of the following
        // calculation: BP[i] = BP[i] * NewReachableSum / OldReachableSum
        // The formula is taken from the private constructor
        // BranchProbability(uint32_t Numerator, uint32_t Denominator)
        uint64_t Mul = static_cast<uint64_t>(NewReachableSum.getNumerator()) *
                       BP[I].getNumerator();
        uint32_t Div = static_cast<uint32_t>(
            divideNearest(Mul, OldReachableSum.getNumerator()));
        BP[I] = BranchProbability::getRaw(Div);
      }
    }
  }

  setEdgeProbability(BB, BP);

  return true;
}
580 
581 /// Calculate edge weights for edges leading to cold blocks.
582 ///
583 /// A cold block is one post-dominated by  a block with a call to a
584 /// cold function.  Those edges are unlikely to be taken, so we give
585 /// them relatively low weight.
586 ///
587 /// Return true if we could compute the weights for cold edges.
588 /// Return false, otherwise.
589 bool BranchProbabilityInfo::calcColdCallHeuristics(const BasicBlock *BB) {
590   const Instruction *TI = BB->getTerminator();
591   (void) TI;
592   assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
593   assert(!isa<InvokeInst>(TI) &&
594          "Invokes should have already been handled by calcInvokeHeuristics");
595 
596   // Determine which successors are post-dominated by a cold block.
597   SmallVector<unsigned, 4> ColdEdges;
598   SmallVector<unsigned, 4> NormalEdges;
599   for (const_succ_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I)
600     if (PostDominatedByColdCall.count(*I))
601       ColdEdges.push_back(I.getSuccessorIndex());
602     else
603       NormalEdges.push_back(I.getSuccessorIndex());
604 
605   // Skip probabilities if no cold edges.
606   if (ColdEdges.empty())
607     return false;
608 
609   SmallVector<BranchProbability, 4> EdgeProbabilities(
610       BB->getTerminator()->getNumSuccessors(), BranchProbability::getUnknown());
611   if (NormalEdges.empty()) {
612     BranchProbability Prob(1, ColdEdges.size());
613     for (unsigned SuccIdx : ColdEdges)
614       EdgeProbabilities[SuccIdx] = Prob;
615     setEdgeProbability(BB, EdgeProbabilities);
616     return true;
617   }
618 
619   auto ColdProb = BranchProbability::getBranchProbability(
620       CC_TAKEN_WEIGHT,
621       (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(ColdEdges.size()));
622   auto NormalProb = BranchProbability::getBranchProbability(
623       CC_NONTAKEN_WEIGHT,
624       (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(NormalEdges.size()));
625 
626   for (unsigned SuccIdx : ColdEdges)
627     EdgeProbabilities[SuccIdx] = ColdProb;
628   for (unsigned SuccIdx : NormalEdges)
629     EdgeProbabilities[SuccIdx] = NormalProb;
630 
631   setEdgeProbability(BB, EdgeProbabilities);
632   return true;
633 }
634 
635 // Calculate Edge Weights using "Pointer Heuristics". Predict a comparison
636 // between two pointer or pointer and NULL will fail.
637 bool BranchProbabilityInfo::calcPointerHeuristics(const BasicBlock *BB) {
638   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
639   if (!BI || !BI->isConditional())
640     return false;
641 
642   Value *Cond = BI->getCondition();
643   ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
644   if (!CI || !CI->isEquality())
645     return false;
646 
647   Value *LHS = CI->getOperand(0);
648 
649   if (!LHS->getType()->isPointerTy())
650     return false;
651 
652   assert(CI->getOperand(1)->getType()->isPointerTy());
653 
654   BranchProbability TakenProb(PH_TAKEN_WEIGHT,
655                               PH_TAKEN_WEIGHT + PH_NONTAKEN_WEIGHT);
656   BranchProbability UntakenProb(PH_NONTAKEN_WEIGHT,
657                                 PH_TAKEN_WEIGHT + PH_NONTAKEN_WEIGHT);
658 
659   // p != 0   ->   isProb = true
660   // p == 0   ->   isProb = false
661   // p != q   ->   isProb = true
662   // p == q   ->   isProb = false;
663   bool isProb = CI->getPredicate() == ICmpInst::ICMP_NE;
664   if (!isProb)
665     std::swap(TakenProb, UntakenProb);
666 
667   setEdgeProbability(
668       BB, SmallVector<BranchProbability, 2>({TakenProb, UntakenProb}));
669   return true;
670 }
671 
// Compute the unlikely successors to the block BB in the loop L, specifically
// those that are unlikely because this is a loop, and add them to the
// UnlikelyBlocks set.
//
// Nothing is added unless BB ends in a conditional branch on a comparison of
// an instruction with a constant, and that instruction is (or reduces through
// binary operators with constant right-hand sides to) a PHI node inside L.
static void
computeUnlikelySuccessors(const BasicBlock *BB, Loop *L,
                          SmallPtrSetImpl<const BasicBlock*> &UnlikelyBlocks) {
  // Sometimes in a loop we have a branch whose condition is made false by
  // taking it. This is typically something like
  //  int n = 0;
  //  while (...) {
  //    if (++n >= MAX) {
  //      n = 0;
  //    }
  //  }
  // In this sort of situation taking the branch means that at the very least it
  // won't be taken again in the next iteration of the loop, so we should
  // consider it less likely than a typical branch.
  //
  // We detect this by looking back through the graph of PHI nodes that sets the
  // value that the condition depends on, and seeing if we can reach a successor
  // block which can be determined to make the condition false.
  //
  // FIXME: We currently consider unlikely blocks to be half as likely as other
  // blocks, but if we consider the example above the likelihood is actually
  // 1/MAX. We could therefore be more precise in how unlikely we consider
  // blocks to be, but it would require more careful examination of the form
  // of the comparison expression.
  const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
  if (!BI || !BI->isConditional())
    return;

  // Check if the branch is based on an instruction compared with a constant
  CmpInst *CI = dyn_cast<CmpInst>(BI->getCondition());
  if (!CI || !isa<Instruction>(CI->getOperand(0)) ||
      !isa<Constant>(CI->getOperand(1)))
    return;

  // Either the instruction must be a PHI, or a chain of operations involving
  // constants that ends in a PHI which we can then collapse into a single value
  // if the PHI value is known.
  Instruction *CmpLHS = dyn_cast<Instruction>(CI->getOperand(0));
  PHINode *CmpPHI = dyn_cast<PHINode>(CmpLHS);
  Constant *CmpConst = dyn_cast<Constant>(CI->getOperand(1));
  // Collect the instructions until we hit a PHI
  SmallVector<BinaryOperator *, 1> InstChain;
  while (!CmpPHI && CmpLHS && isa<BinaryOperator>(CmpLHS) &&
         isa<Constant>(CmpLHS->getOperand(1))) {
    // Stop if the chain extends outside of the loop
    if (!L->contains(CmpLHS))
      return;
    InstChain.push_back(cast<BinaryOperator>(CmpLHS));
    CmpLHS = dyn_cast<Instruction>(CmpLHS->getOperand(0));
    if (CmpLHS)
      CmpPHI = dyn_cast<PHINode>(CmpLHS);
  }
  // Give up if the chain did not end in a PHI that lives inside the loop.
  if (!CmpPHI || !L->contains(CmpPHI))
    return;

  // Trace the phi node to find all values that come from successors of BB
  SmallPtrSet<PHINode*, 8> VisitedInsts;
  SmallVector<PHINode*, 8> WorkList;
  WorkList.push_back(CmpPHI);
  VisitedInsts.insert(CmpPHI);
  while (!WorkList.empty()) {
    PHINode *P = WorkList.back();
    WorkList.pop_back();
    for (BasicBlock *B : P->blocks()) {
      // Skip blocks that aren't part of the loop
      if (!L->contains(B))
        continue;
      Value *V = P->getIncomingValueForBlock(B);
      // If the source is a PHI add it to the work list if we haven't
      // already visited it.
      if (PHINode *PN = dyn_cast<PHINode>(V)) {
        if (VisitedInsts.insert(PN).second)
          WorkList.push_back(PN);
        continue;
      }
      // If this incoming value is a constant and B is a successor of BB, then
      // we can constant-evaluate the compare to see if it makes the branch be
      // taken or not.
      Constant *CmpLHSConst = dyn_cast<Constant>(V);
      if (!CmpLHSConst || !llvm::is_contained(successors(BB), B))
        continue;
      // First collapse InstChain. The chain was collected starting at the
      // compare operand and moving towards the PHI, so it is applied in
      // reverse order here.
      for (Instruction *I : llvm::reverse(InstChain)) {
        CmpLHSConst = ConstantExpr::get(I->getOpcode(), CmpLHSConst,
                                        cast<Constant>(I->getOperand(1)), true);
        if (!CmpLHSConst)
          break;
      }
      if (!CmpLHSConst)
        continue;
      // Now constant-evaluate the compare
      Constant *Result = ConstantExpr::getCompare(CI->getPredicate(),
                                                  CmpLHSConst, CmpConst, true);
      // If the result means we don't branch to the block then that block is
      // unlikely.
      if (Result &&
          ((Result->isZeroValue() && B == BI->getSuccessor(0)) ||
           (Result->isOneValue() && B == BI->getSuccessor(1))))
        UnlikelyBlocks.insert(B);
    }
  }
}
777 
778 // Calculate Edge Weights using "Loop Branch Heuristics". Predict backedges
779 // as taken, exiting edges as not-taken.
780 bool BranchProbabilityInfo::calcLoopBranchHeuristics(const BasicBlock *BB,
781                                                      const LoopInfo &LI) {
782   LoopBlock LB(BB, LI, *SccI.get());
783   if (!LB.belongsToLoop())
784     return false;
785 
786   SmallPtrSet<const BasicBlock*, 8> UnlikelyBlocks;
787   if (LB.getLoop())
788     computeUnlikelySuccessors(BB, LB.getLoop(), UnlikelyBlocks);
789 
790   SmallVector<unsigned, 8> BackEdges;
791   SmallVector<unsigned, 8> ExitingEdges;
792   SmallVector<unsigned, 8> InEdges; // Edges from header to the loop.
793   SmallVector<unsigned, 8> UnlikelyEdges;
794 
795   for (const_succ_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
796     LoopBlock SuccLB(*I, LI, *SccI.get());
797     LoopEdge Edge(LB, SuccLB);
798     bool IsUnlikelyEdge =
799         LB.getLoop() && (UnlikelyBlocks.find(*I) != UnlikelyBlocks.end());
800 
801     if (IsUnlikelyEdge)
802       UnlikelyEdges.push_back(I.getSuccessorIndex());
803     else if (isLoopExitingEdge(Edge))
804       ExitingEdges.push_back(I.getSuccessorIndex());
805     else if (isLoopBackEdge(Edge))
806       BackEdges.push_back(I.getSuccessorIndex());
807     else {
808       InEdges.push_back(I.getSuccessorIndex());
809     }
810   }
811 
812   if (BackEdges.empty() && ExitingEdges.empty() && UnlikelyEdges.empty())
813     return false;
814 
815   // Collect the sum of probabilities of back-edges/in-edges/exiting-edges, and
816   // normalize them so that they sum up to one.
817   unsigned Denom = (BackEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) +
818                    (InEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) +
819                    (UnlikelyEdges.empty() ? 0 : LBH_UNLIKELY_WEIGHT) +
820                    (ExitingEdges.empty() ? 0 : LBH_NONTAKEN_WEIGHT);
821 
822   SmallVector<BranchProbability, 4> EdgeProbabilities(
823       BB->getTerminator()->getNumSuccessors(), BranchProbability::getUnknown());
824   if (uint32_t numBackEdges = BackEdges.size()) {
825     BranchProbability TakenProb = BranchProbability(LBH_TAKEN_WEIGHT, Denom);
826     auto Prob = TakenProb / numBackEdges;
827     for (unsigned SuccIdx : BackEdges)
828       EdgeProbabilities[SuccIdx] = Prob;
829   }
830 
831   if (uint32_t numInEdges = InEdges.size()) {
832     BranchProbability TakenProb = BranchProbability(LBH_TAKEN_WEIGHT, Denom);
833     auto Prob = TakenProb / numInEdges;
834     for (unsigned SuccIdx : InEdges)
835       EdgeProbabilities[SuccIdx] = Prob;
836   }
837 
838   if (uint32_t numExitingEdges = ExitingEdges.size()) {
839     BranchProbability NotTakenProb = BranchProbability(LBH_NONTAKEN_WEIGHT,
840                                                        Denom);
841     auto Prob = NotTakenProb / numExitingEdges;
842     for (unsigned SuccIdx : ExitingEdges)
843       EdgeProbabilities[SuccIdx] = Prob;
844   }
845 
846   if (uint32_t numUnlikelyEdges = UnlikelyEdges.size()) {
847     BranchProbability UnlikelyProb = BranchProbability(LBH_UNLIKELY_WEIGHT,
848                                                        Denom);
849     auto Prob = UnlikelyProb / numUnlikelyEdges;
850     for (unsigned SuccIdx : UnlikelyEdges)
851       EdgeProbabilities[SuccIdx] = Prob;
852   }
853 
854   setEdgeProbability(BB, EdgeProbabilities);
855   return true;
856 }
857 
858 bool BranchProbabilityInfo::calcZeroHeuristics(const BasicBlock *BB,
859                                                const TargetLibraryInfo *TLI) {
860   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
861   if (!BI || !BI->isConditional())
862     return false;
863 
864   Value *Cond = BI->getCondition();
865   ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
866   if (!CI)
867     return false;
868 
869   auto GetConstantInt = [](Value *V) {
870     if (auto *I = dyn_cast<BitCastInst>(V))
871       return dyn_cast<ConstantInt>(I->getOperand(0));
872     return dyn_cast<ConstantInt>(V);
873   };
874 
875   Value *RHS = CI->getOperand(1);
876   ConstantInt *CV = GetConstantInt(RHS);
877   if (!CV)
878     return false;
879 
880   // If the LHS is the result of AND'ing a value with a single bit bitmask,
881   // we don't have information about probabilities.
882   if (Instruction *LHS = dyn_cast<Instruction>(CI->getOperand(0)))
883     if (LHS->getOpcode() == Instruction::And)
884       if (ConstantInt *AndRHS = GetConstantInt(LHS->getOperand(1)))
885         if (AndRHS->getValue().isPowerOf2())
886           return false;
887 
888   // Check if the LHS is the return value of a library function
889   LibFunc Func = NumLibFuncs;
890   if (TLI)
891     if (CallInst *Call = dyn_cast<CallInst>(CI->getOperand(0)))
892       if (Function *CalledFn = Call->getCalledFunction())
893         TLI->getLibFunc(*CalledFn, Func);
894 
895   bool isProb;
896   if (Func == LibFunc_strcasecmp ||
897       Func == LibFunc_strcmp ||
898       Func == LibFunc_strncasecmp ||
899       Func == LibFunc_strncmp ||
900       Func == LibFunc_memcmp ||
901       Func == LibFunc_bcmp) {
902     // strcmp and similar functions return zero, negative, or positive, if the
903     // first string is equal, less, or greater than the second. We consider it
904     // likely that the strings are not equal, so a comparison with zero is
905     // probably false, but also a comparison with any other number is also
906     // probably false given that what exactly is returned for nonzero values is
907     // not specified. Any kind of comparison other than equality we know
908     // nothing about.
909     switch (CI->getPredicate()) {
910     case CmpInst::ICMP_EQ:
911       isProb = false;
912       break;
913     case CmpInst::ICMP_NE:
914       isProb = true;
915       break;
916     default:
917       return false;
918     }
919   } else if (CV->isZero()) {
920     switch (CI->getPredicate()) {
921     case CmpInst::ICMP_EQ:
922       // X == 0   ->  Unlikely
923       isProb = false;
924       break;
925     case CmpInst::ICMP_NE:
926       // X != 0   ->  Likely
927       isProb = true;
928       break;
929     case CmpInst::ICMP_SLT:
930       // X < 0   ->  Unlikely
931       isProb = false;
932       break;
933     case CmpInst::ICMP_SGT:
934       // X > 0   ->  Likely
935       isProb = true;
936       break;
937     default:
938       return false;
939     }
940   } else if (CV->isOne() && CI->getPredicate() == CmpInst::ICMP_SLT) {
941     // InstCombine canonicalizes X <= 0 into X < 1.
942     // X <= 0   ->  Unlikely
943     isProb = false;
944   } else if (CV->isMinusOne()) {
945     switch (CI->getPredicate()) {
946     case CmpInst::ICMP_EQ:
947       // X == -1  ->  Unlikely
948       isProb = false;
949       break;
950     case CmpInst::ICMP_NE:
951       // X != -1  ->  Likely
952       isProb = true;
953       break;
954     case CmpInst::ICMP_SGT:
955       // InstCombine canonicalizes X >= 0 into X > -1.
956       // X >= 0   ->  Likely
957       isProb = true;
958       break;
959     default:
960       return false;
961     }
962   } else {
963     return false;
964   }
965 
966   BranchProbability TakenProb(ZH_TAKEN_WEIGHT,
967                               ZH_TAKEN_WEIGHT + ZH_NONTAKEN_WEIGHT);
968   BranchProbability UntakenProb(ZH_NONTAKEN_WEIGHT,
969                                 ZH_TAKEN_WEIGHT + ZH_NONTAKEN_WEIGHT);
970   if (!isProb)
971     std::swap(TakenProb, UntakenProb);
972 
973   setEdgeProbability(
974       BB, SmallVector<BranchProbability, 2>({TakenProb, UntakenProb}));
975   return true;
976 }
977 
978 bool BranchProbabilityInfo::calcFloatingPointHeuristics(const BasicBlock *BB) {
979   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
980   if (!BI || !BI->isConditional())
981     return false;
982 
983   Value *Cond = BI->getCondition();
984   FCmpInst *FCmp = dyn_cast<FCmpInst>(Cond);
985   if (!FCmp)
986     return false;
987 
988   uint32_t TakenWeight = FPH_TAKEN_WEIGHT;
989   uint32_t NontakenWeight = FPH_NONTAKEN_WEIGHT;
990   bool isProb;
991   if (FCmp->isEquality()) {
992     // f1 == f2 -> Unlikely
993     // f1 != f2 -> Likely
994     isProb = !FCmp->isTrueWhenEqual();
995   } else if (FCmp->getPredicate() == FCmpInst::FCMP_ORD) {
996     // !isnan -> Likely
997     isProb = true;
998     TakenWeight = FPH_ORD_WEIGHT;
999     NontakenWeight = FPH_UNO_WEIGHT;
1000   } else if (FCmp->getPredicate() == FCmpInst::FCMP_UNO) {
1001     // isnan -> Unlikely
1002     isProb = false;
1003     TakenWeight = FPH_ORD_WEIGHT;
1004     NontakenWeight = FPH_UNO_WEIGHT;
1005   } else {
1006     return false;
1007   }
1008 
1009   BranchProbability TakenProb(TakenWeight, TakenWeight + NontakenWeight);
1010   BranchProbability UntakenProb(NontakenWeight, TakenWeight + NontakenWeight);
1011   if (!isProb)
1012     std::swap(TakenProb, UntakenProb);
1013 
1014   setEdgeProbability(
1015       BB, SmallVector<BranchProbability, 2>({TakenProb, UntakenProb}));
1016   return true;
1017 }
1018 
1019 bool BranchProbabilityInfo::calcInvokeHeuristics(const BasicBlock *BB) {
1020   const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator());
1021   if (!II)
1022     return false;
1023 
1024   BranchProbability TakenProb(IH_TAKEN_WEIGHT,
1025                               IH_TAKEN_WEIGHT + IH_NONTAKEN_WEIGHT);
1026   setEdgeProbability(
1027       BB, SmallVector<BranchProbability, 2>({TakenProb, TakenProb.getCompl()}));
1028   return true;
1029 }
1030 
1031 void BranchProbabilityInfo::releaseMemory() {
1032   Probs.clear();
1033   Handles.clear();
1034 }
1035 
1036 bool BranchProbabilityInfo::invalidate(Function &, const PreservedAnalyses &PA,
1037                                        FunctionAnalysisManager::Invalidator &) {
1038   // Check whether the analysis, all analyses on functions, or the function's
1039   // CFG have been preserved.
1040   auto PAC = PA.getChecker<BranchProbabilityAnalysis>();
1041   return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>() ||
1042            PAC.preservedSet<CFGAnalyses>());
1043 }
1044 
1045 void BranchProbabilityInfo::print(raw_ostream &OS) const {
1046   OS << "---- Branch Probabilities ----\n";
1047   // We print the probabilities from the last function the analysis ran over,
1048   // or the function it is currently running over.
1049   assert(LastF && "Cannot print prior to running over a function");
1050   for (const auto &BI : *LastF) {
1051     for (const_succ_iterator SI = succ_begin(&BI), SE = succ_end(&BI); SI != SE;
1052          ++SI) {
1053       printEdgeProbability(OS << "  ", &BI, *SI);
1054     }
1055   }
1056 }
1057 
1058 bool BranchProbabilityInfo::
1059 isEdgeHot(const BasicBlock *Src, const BasicBlock *Dst) const {
1060   // Hot probability is at least 4/5 = 80%
1061   // FIXME: Compare against a static "hot" BranchProbability.
1062   return getEdgeProbability(Src, Dst) > BranchProbability(4, 5);
1063 }
1064 
1065 const BasicBlock *
1066 BranchProbabilityInfo::getHotSucc(const BasicBlock *BB) const {
1067   auto MaxProb = BranchProbability::getZero();
1068   const BasicBlock *MaxSucc = nullptr;
1069 
1070   for (const auto *Succ : successors(BB)) {
1071     auto Prob = getEdgeProbability(BB, Succ);
1072     if (Prob > MaxProb) {
1073       MaxProb = Prob;
1074       MaxSucc = Succ;
1075     }
1076   }
1077 
1078   // Hot probability is at least 4/5 = 80%
1079   if (MaxProb > BranchProbability(4, 5))
1080     return MaxSucc;
1081 
1082   return nullptr;
1083 }
1084 
1085 /// Get the raw edge probability for the edge. If can't find it, return a
1086 /// default probability 1/N where N is the number of successors. Here an edge is
1087 /// specified using PredBlock and an
1088 /// index to the successors.
1089 BranchProbability
1090 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
1091                                           unsigned IndexInSuccessors) const {
1092   auto I = Probs.find(std::make_pair(Src, IndexInSuccessors));
1093   assert((Probs.end() == Probs.find(std::make_pair(Src, 0))) ==
1094              (Probs.end() == I) &&
1095          "Probability for I-th successor must always be defined along with the "
1096          "probability for the first successor");
1097 
1098   if (I != Probs.end())
1099     return I->second;
1100 
1101   return {1, static_cast<uint32_t>(succ_size(Src))};
1102 }
1103 
1104 BranchProbability
1105 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
1106                                           const_succ_iterator Dst) const {
1107   return getEdgeProbability(Src, Dst.getSuccessorIndex());
1108 }
1109 
1110 /// Get the raw edge probability calculated for the block pair. This returns the
1111 /// sum of all raw edge probabilities from Src to Dst.
1112 BranchProbability
1113 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
1114                                           const BasicBlock *Dst) const {
1115   if (!Probs.count(std::make_pair(Src, 0)))
1116     return BranchProbability(llvm::count(successors(Src), Dst), succ_size(Src));
1117 
1118   auto Prob = BranchProbability::getZero();
1119   for (const_succ_iterator I = succ_begin(Src), E = succ_end(Src); I != E; ++I)
1120     if (*I == Dst)
1121       Prob += Probs.find(std::make_pair(Src, I.getSuccessorIndex()))->second;
1122 
1123   return Prob;
1124 }
1125 
1126 /// Set the edge probability for all edges at once.
1127 void BranchProbabilityInfo::setEdgeProbability(
1128     const BasicBlock *Src, const SmallVectorImpl<BranchProbability> &Probs) {
1129   assert(Src->getTerminator()->getNumSuccessors() == Probs.size());
1130   eraseBlock(Src); // Erase stale data if any.
1131   if (Probs.size() == 0)
1132     return; // Nothing to set.
1133 
1134   Handles.insert(BasicBlockCallbackVH(Src, this));
1135   uint64_t TotalNumerator = 0;
1136   for (unsigned SuccIdx = 0; SuccIdx < Probs.size(); ++SuccIdx) {
1137     this->Probs[std::make_pair(Src, SuccIdx)] = Probs[SuccIdx];
1138     LLVM_DEBUG(dbgs() << "set edge " << Src->getName() << " -> " << SuccIdx
1139                       << " successor probability to " << Probs[SuccIdx]
1140                       << "\n");
1141     TotalNumerator += Probs[SuccIdx].getNumerator();
1142   }
1143 
1144   // Because of rounding errors the total probability cannot be checked to be
1145   // 1.0 exactly. That is TotalNumerator == BranchProbability::getDenominator.
1146   // Instead, every single probability in Probs must be as accurate as possible.
1147   // This results in error 1/denominator at most, thus the total absolute error
1148   // should be within Probs.size / BranchProbability::getDenominator.
1149   assert(TotalNumerator <= BranchProbability::getDenominator() + Probs.size());
1150   assert(TotalNumerator >= BranchProbability::getDenominator() - Probs.size());
1151 }
1152 
1153 void BranchProbabilityInfo::copyEdgeProbabilities(BasicBlock *Src,
1154                                                   BasicBlock *Dst) {
1155   eraseBlock(Dst); // Erase stale data if any.
1156   unsigned NumSuccessors = Src->getTerminator()->getNumSuccessors();
1157   assert(NumSuccessors == Dst->getTerminator()->getNumSuccessors());
1158   if (NumSuccessors == 0)
1159     return; // Nothing to set.
1160   if (this->Probs.find(std::make_pair(Src, 0)) == this->Probs.end())
1161     return; // No probability is set for edges from Src. Keep the same for Dst.
1162 
1163   Handles.insert(BasicBlockCallbackVH(Dst, this));
1164   for (unsigned SuccIdx = 0; SuccIdx < NumSuccessors; ++SuccIdx) {
1165     auto Prob = this->Probs[std::make_pair(Src, SuccIdx)];
1166     this->Probs[std::make_pair(Dst, SuccIdx)] = Prob;
1167     LLVM_DEBUG(dbgs() << "set edge " << Dst->getName() << " -> " << SuccIdx
1168                       << " successor probability to " << Prob << "\n");
1169   }
1170 }
1171 
1172 raw_ostream &
1173 BranchProbabilityInfo::printEdgeProbability(raw_ostream &OS,
1174                                             const BasicBlock *Src,
1175                                             const BasicBlock *Dst) const {
1176   const BranchProbability Prob = getEdgeProbability(Src, Dst);
1177   OS << "edge " << Src->getName() << " -> " << Dst->getName()
1178      << " probability is " << Prob
1179      << (isEdgeHot(Src, Dst) ? " [HOT edge]\n" : "\n");
1180 
1181   return OS;
1182 }
1183 
1184 void BranchProbabilityInfo::eraseBlock(const BasicBlock *BB) {
1185   LLVM_DEBUG(dbgs() << "eraseBlock " << BB->getName() << "\n");
1186 
1187   // Note that we cannot use successors of BB because the terminator of BB may
1188   // have changed when eraseBlock is called as a BasicBlockCallbackVH callback.
1189   // Instead we remove prob data for the block by iterating successors by their
1190   // indices from 0 till the last which exists. There could not be prob data for
1191   // a pair (BB, N) if there is no data for (BB, N-1) because the data is always
1192   // set for all successors from 0 to M at once by the method
1193   // setEdgeProbability().
1194   Handles.erase(BasicBlockCallbackVH(BB, this));
1195   for (unsigned I = 0;; ++I) {
1196     auto MapI = Probs.find(std::make_pair(BB, I));
1197     if (MapI == Probs.end()) {
1198       assert(Probs.count(std::make_pair(BB, I + 1)) == 0 &&
1199              "Must be no more successors");
1200       return;
1201     }
1202     Probs.erase(MapI);
1203   }
1204 }
1205 
1206 void BranchProbabilityInfo::calculate(const Function &F, const LoopInfo &LI,
1207                                       const TargetLibraryInfo *TLI,
1208                                       PostDominatorTree *PDT) {
1209   LLVM_DEBUG(dbgs() << "---- Branch Probability Info : " << F.getName()
1210                     << " ----\n\n");
1211   LastF = &F; // Store the last function we ran on for printing.
1212   assert(PostDominatedByUnreachable.empty());
1213   assert(PostDominatedByColdCall.empty());
1214 
1215   SccI = std::make_unique<SccInfo>(F);
1216 
1217   std::unique_ptr<PostDominatorTree> PDTPtr;
1218 
1219   if (!PDT) {
1220     PDTPtr = std::make_unique<PostDominatorTree>(const_cast<Function &>(F));
1221     PDT = PDTPtr.get();
1222   }
1223 
1224   computePostDominatedByUnreachable(F, PDT);
1225   computePostDominatedByColdCall(F, PDT);
1226 
1227   // Walk the basic blocks in post-order so that we can build up state about
1228   // the successors of a block iteratively.
1229   for (auto BB : post_order(&F.getEntryBlock())) {
1230     LLVM_DEBUG(dbgs() << "Computing probabilities for " << BB->getName()
1231                       << "\n");
1232     // If there is no at least two successors, no sense to set probability.
1233     if (BB->getTerminator()->getNumSuccessors() < 2)
1234       continue;
1235     if (calcMetadataWeights(BB))
1236       continue;
1237     if (calcInvokeHeuristics(BB))
1238       continue;
1239     if (calcUnreachableHeuristics(BB))
1240       continue;
1241     if (calcColdCallHeuristics(BB))
1242       continue;
1243     if (calcLoopBranchHeuristics(BB, LI))
1244       continue;
1245     if (calcPointerHeuristics(BB))
1246       continue;
1247     if (calcZeroHeuristics(BB, TLI))
1248       continue;
1249     if (calcFloatingPointHeuristics(BB))
1250       continue;
1251   }
1252 
1253   PostDominatedByUnreachable.clear();
1254   PostDominatedByColdCall.clear();
1255   SccI.reset();
1256 
1257   if (PrintBranchProb &&
1258       (PrintBranchProbFuncName.empty() ||
1259        F.getName().equals(PrintBranchProbFuncName))) {
1260     print(dbgs());
1261   }
1262 }
1263 
/// Declare the analyses this pass needs -- dominator tree, loop info, target
/// library info and post-dominator tree -- and state that it preserves all
/// analyses.
void BranchProbabilityInfoWrapperPass::getAnalysisUsage(
    AnalysisUsage &AU) const {
  // We require DT so it's available when LI is available. The LI updating code
  // asserts that DT is also present so if we don't make sure that we have DT
  // here, that assert will trigger.
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.addRequired<LoopInfoWrapperPass>();
  AU.addRequired<TargetLibraryInfoWrapperPass>();
  AU.addRequired<PostDominatorTreeWrapperPass>();
  AU.setPreservesAll();
}
1275 
1276 bool BranchProbabilityInfoWrapperPass::runOnFunction(Function &F) {
1277   const LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
1278   const TargetLibraryInfo &TLI =
1279       getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
1280   PostDominatorTree &PDT =
1281       getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
1282   BPI.calculate(F, LI, &TLI, &PDT);
1283   return false;
1284 }
1285 
1286 void BranchProbabilityInfoWrapperPass::releaseMemory() { BPI.releaseMemory(); }
1287 
1288 void BranchProbabilityInfoWrapperPass::print(raw_ostream &OS,
1289                                              const Module *) const {
1290   BPI.print(OS);
1291 }
1292 
1293 AnalysisKey BranchProbabilityAnalysis::Key;
1294 BranchProbabilityInfo
1295 BranchProbabilityAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
1296   BranchProbabilityInfo BPI;
1297   BPI.calculate(F, AM.getResult<LoopAnalysis>(F),
1298                 &AM.getResult<TargetLibraryAnalysis>(F),
1299                 &AM.getResult<PostDominatorTreeAnalysis>(F));
1300   return BPI;
1301 }
1302 
1303 PreservedAnalyses
1304 BranchProbabilityPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
1305   OS << "Printing analysis results of BPI for function "
1306      << "'" << F.getName() << "':"
1307      << "\n";
1308   AM.getResult<BranchProbabilityAnalysis>(F).print(OS);
1309   return PreservedAnalyses::all();
1310 }
1311