1 //===- BranchProbabilityInfo.cpp - Branch Probability Analysis ------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Loops should be simplified before this analysis.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Analysis/BranchProbabilityInfo.h"
15 #include "llvm/ADT/PostOrderIterator.h"
16 #include "llvm/ADT/SCCIterator.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/Analysis/LoopInfo.h"
20 #include "llvm/Analysis/TargetLibraryInfo.h"
21 #include "llvm/IR/Attributes.h"
22 #include "llvm/IR/BasicBlock.h"
23 #include "llvm/IR/CFG.h"
24 #include "llvm/IR/Constants.h"
25 #include "llvm/IR/Function.h"
26 #include "llvm/IR/InstrTypes.h"
27 #include "llvm/IR/Instruction.h"
28 #include "llvm/IR/Instructions.h"
29 #include "llvm/IR/LLVMContext.h"
30 #include "llvm/IR/Metadata.h"
31 #include "llvm/IR/PassManager.h"
32 #include "llvm/IR/Type.h"
33 #include "llvm/IR/Value.h"
34 #include "llvm/Pass.h"
35 #include "llvm/Support/BranchProbability.h"
36 #include "llvm/Support/Casting.h"
37 #include "llvm/Support/Debug.h"
38 #include "llvm/Support/raw_ostream.h"
39 #include <cassert>
40 #include <cstdint>
41 #include <iterator>
42 #include <utility>
43 
44 using namespace llvm;
45 
46 #define DEBUG_TYPE "branch-prob"
47 
// -print-bpi: boolean option described as "Print the branch probability info."
// Off by default and hidden from -help.
static cl::opt<bool> PrintBranchProb(
    "print-bpi", cl::init(false), cl::Hidden,
    cl::desc("Print the branch probability info."));

// -print-bpi-func-name: names the function whose branch probability info is
// printed. NOTE(review): unlike PrintBranchProb this option is not `static`,
// so it has external linkage; presumably it is referenced from another
// translation unit -- confirm before making it file-local.
cl::opt<std::string> PrintBranchProbFuncName(
    "print-bpi-func-name", cl::Hidden,
    cl::desc("The option to specify the name of the function "
             "whose branch probability info is printed."));

// Register BranchProbabilityInfoWrapperPass under the command-line name
// "branch-prob", declaring LoopInfoWrapperPass and
// TargetLibraryInfoWrapperPass as dependencies. The trailing `false, true`
// arguments follow the INITIALIZE_PASS convention (CFG-only, is-analysis).
INITIALIZE_PASS_BEGIN(BranchProbabilityInfoWrapperPass, "branch-prob",
                      "Branch Probability Analysis", false, true)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(BranchProbabilityInfoWrapperPass, "branch-prob",
                    "Branch Probability Analysis", false, true)

// Pass identifier; its address, not its value, identifies the pass.
char BranchProbabilityInfoWrapperPass::ID = 0;
65 
// Weights are for internal use only. They are used by heuristics to help to
// estimate edges' probability. Example:
//
// Using "Loop Branch Heuristics" we predict weights of edges for the
// block BB2.
//         ...
//          |
//          V
//         BB1<-+
//          |   |
//          |   | (Weight = 124)
//          V   |
//         BB2--+
//          |
//          | (Weight = 4)
//          V
//         BB3
//
// Probability of the edge BB2->BB1 = 124 / (124 + 4) = 0.96875
// Probability of the edge BB2->BB3 = 4 / (124 + 4) = 0.03125
static const uint32_t LBH_TAKEN_WEIGHT = 124;
static const uint32_t LBH_NONTAKEN_WEIGHT = 4;
// Unlikely edges within a loop are half as likely as other edges
// (62 == LBH_TAKEN_WEIGHT / 2). Applied by calcLoopBranchHeuristics() to the
// successors reported by computeUnlikelySuccessors().
static const uint32_t LBH_UNLIKELY_WEIGHT = 62;

/// \brief Unreachable-terminating branch taken probability.
///
/// This is the probability for a branch being taken to a block that terminates
/// (eventually) in unreachable. These are predicted as unlikely as possible.
/// All reachable probability will equally share the remaining part.
/// getRaw(1) is the smallest non-zero raw BranchProbability value.
static const BranchProbability UR_TAKEN_PROB = BranchProbability::getRaw(1);

/// \brief Weight for a branch taken going into a cold block.
///
/// This is the weight for a branch taken toward a block marked
/// cold.  A block is marked cold if it's postdominated by a
/// block containing a call to a cold function.  Cold functions
/// are those marked with attribute 'cold'.
static const uint32_t CC_TAKEN_WEIGHT = 4;

/// \brief Weight for a branch not-taken into a cold block.
///
/// This is the weight for a branch not taken toward a block marked
/// cold.
static const uint32_t CC_NONTAKEN_WEIGHT = 64;

/// \brief Pointer-heuristic weights (calcPointerHeuristics).
///
/// The predicted-likely edge of a pointer equality test (the `!=` outcome)
/// gets PH_TAKEN_WEIGHT / (PH_TAKEN_WEIGHT + PH_NONTAKEN_WEIGHT) = 20/32.
static const uint32_t PH_TAKEN_WEIGHT = 20;
static const uint32_t PH_NONTAKEN_WEIGHT = 12;

/// \brief Zero-heuristic weights (calcZeroHeuristics).
///
/// Used for integer comparisons against 0, 1 or -1 and for equality tests on
/// strcmp-like library call results; the predicted-likely edge gets 20/32.
static const uint32_t ZH_TAKEN_WEIGHT = 20;
static const uint32_t ZH_NONTAKEN_WEIGHT = 12;

/// \brief Floating-point-heuristic weights (calcFloatingPointHeuristics).
///
/// Used for fcmp equality and ord/uno (NaN) tests; the predicted-likely edge
/// gets 20/32.
static const uint32_t FPH_TAKEN_WEIGHT = 20;
static const uint32_t FPH_NONTAKEN_WEIGHT = 12;

/// \brief Invoke-terminating normal branch taken weight
///
/// This is the weight for branching to the normal destination of an invoke
/// instruction. We expect this to happen most of the time. Set the weight to an
/// absurdly high value so that nested loops subsume it.
static const uint32_t IH_TAKEN_WEIGHT = 1024 * 1024 - 1;

/// \brief Invoke-terminating normal branch not-taken weight.
///
/// This is the weight for branching to the unwind destination of an invoke
/// instruction. This is essentially never taken.
static const uint32_t IH_NONTAKEN_WEIGHT = 1;
133 
134 /// \brief Add \p BB to PostDominatedByUnreachable set if applicable.
135 void
136 BranchProbabilityInfo::updatePostDominatedByUnreachable(const BasicBlock *BB) {
137   const TerminatorInst *TI = BB->getTerminator();
138   if (TI->getNumSuccessors() == 0) {
139     if (isa<UnreachableInst>(TI) ||
140         // If this block is terminated by a call to
141         // @llvm.experimental.deoptimize then treat it like an unreachable since
142         // the @llvm.experimental.deoptimize call is expected to practically
143         // never execute.
144         BB->getTerminatingDeoptimizeCall())
145       PostDominatedByUnreachable.insert(BB);
146     return;
147   }
148 
149   // If the terminator is an InvokeInst, check only the normal destination block
150   // as the unwind edge of InvokeInst is also very unlikely taken.
151   if (auto *II = dyn_cast<InvokeInst>(TI)) {
152     if (PostDominatedByUnreachable.count(II->getNormalDest()))
153       PostDominatedByUnreachable.insert(BB);
154     return;
155   }
156 
157   for (auto *I : successors(BB))
158     // If any of successor is not post dominated then BB is also not.
159     if (!PostDominatedByUnreachable.count(I))
160       return;
161 
162   PostDominatedByUnreachable.insert(BB);
163 }
164 
165 /// \brief Add \p BB to PostDominatedByColdCall set if applicable.
166 void
167 BranchProbabilityInfo::updatePostDominatedByColdCall(const BasicBlock *BB) {
168   assert(!PostDominatedByColdCall.count(BB));
169   const TerminatorInst *TI = BB->getTerminator();
170   if (TI->getNumSuccessors() == 0)
171     return;
172 
173   // If all of successor are post dominated then BB is also done.
174   if (llvm::all_of(successors(BB), [&](const BasicBlock *SuccBB) {
175         return PostDominatedByColdCall.count(SuccBB);
176       })) {
177     PostDominatedByColdCall.insert(BB);
178     return;
179   }
180 
181   // If the terminator is an InvokeInst, check only the normal destination
182   // block as the unwind edge of InvokeInst is also very unlikely taken.
183   if (auto *II = dyn_cast<InvokeInst>(TI))
184     if (PostDominatedByColdCall.count(II->getNormalDest())) {
185       PostDominatedByColdCall.insert(BB);
186       return;
187     }
188 
189   // Otherwise, if the block itself contains a cold function, add it to the
190   // set of blocks post-dominated by a cold call.
191   for (auto &I : *BB)
192     if (const CallInst *CI = dyn_cast<CallInst>(&I))
193       if (CI->hasFnAttr(Attribute::Cold)) {
194         PostDominatedByColdCall.insert(BB);
195         return;
196       }
197 }
198 
199 /// \brief Calculate edge weights for successors lead to unreachable.
200 ///
201 /// Predict that a successor which leads necessarily to an
202 /// unreachable-terminated block as extremely unlikely.
203 bool BranchProbabilityInfo::calcUnreachableHeuristics(const BasicBlock *BB) {
204   const TerminatorInst *TI = BB->getTerminator();
205   assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
206 
207   // Return false here so that edge weights for InvokeInst could be decided
208   // in calcInvokeHeuristics().
209   if (isa<InvokeInst>(TI))
210     return false;
211 
212   SmallVector<unsigned, 4> UnreachableEdges;
213   SmallVector<unsigned, 4> ReachableEdges;
214 
215   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I)
216     if (PostDominatedByUnreachable.count(*I))
217       UnreachableEdges.push_back(I.getSuccessorIndex());
218     else
219       ReachableEdges.push_back(I.getSuccessorIndex());
220 
221   // Skip probabilities if all were reachable.
222   if (UnreachableEdges.empty())
223     return false;
224 
225   if (ReachableEdges.empty()) {
226     BranchProbability Prob(1, UnreachableEdges.size());
227     for (unsigned SuccIdx : UnreachableEdges)
228       setEdgeProbability(BB, SuccIdx, Prob);
229     return true;
230   }
231 
232   auto UnreachableProb = UR_TAKEN_PROB;
233   auto ReachableProb =
234       (BranchProbability::getOne() - UR_TAKEN_PROB * UnreachableEdges.size()) /
235       ReachableEdges.size();
236 
237   for (unsigned SuccIdx : UnreachableEdges)
238     setEdgeProbability(BB, SuccIdx, UnreachableProb);
239   for (unsigned SuccIdx : ReachableEdges)
240     setEdgeProbability(BB, SuccIdx, ReachableProb);
241 
242   return true;
243 }
244 
// Propagate existing explicit probabilities from either profile data or
// 'expect' intrinsic processing. Examine metadata against unreachable
// heuristic. The probability of the edge coming to unreachable block is
// set to min of metadata and unreachable heuristic.
//
// Only br, switch and indirectbr terminators are considered. The !prof node
// must hold exactly one ConstantInt weight per successor (after its leading
// name operand); otherwise the metadata is ignored and false is returned.
//
// Returns true if a probability was set for every successor of BB.
bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) {
  const TerminatorInst *TI = BB->getTerminator();
  assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
  // Only these terminator kinds are read for !prof weights here.
  if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) || isa<IndirectBrInst>(TI)))
    return false;

  MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof);
  if (!WeightsNode)
    return false;

  // Check that the number of successors is manageable.
  assert(TI->getNumSuccessors() < UINT32_MAX && "Too many successors");

  // Ensure there are weights for all of the successors. Note that the first
  // operand to the metadata node is a name, not a weight.
  // NOTE(review): the name operand is never compared against "branch_weights";
  // confirm no other !prof kind can reach this point.
  if (WeightsNode->getNumOperands() != TI->getNumSuccessors() + 1)
    return false;

  // Build up the final weights that will be used in a temporary buffer.
  // Compute the sum of all weights to later decide whether they need to
  // be scaled to fit in 32 bits.
  uint64_t WeightSum = 0;
  SmallVector<uint32_t, 2> Weights;
  SmallVector<unsigned, 2> UnreachableIdxs;
  SmallVector<unsigned, 2> ReachableIdxs;
  Weights.reserve(TI->getNumSuccessors());
  // Operand i carries the weight of successor i - 1. While reading weights,
  // also partition successor indices by whether the successor is
  // post-dominated by an unreachable block.
  for (unsigned i = 1, e = WeightsNode->getNumOperands(); i != e; ++i) {
    ConstantInt *Weight =
        mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(i));
    // A single non-integer operand invalidates the whole annotation.
    if (!Weight)
      return false;
    assert(Weight->getValue().getActiveBits() <= 32 &&
           "Too many bits for uint32_t");
    Weights.push_back(Weight->getZExtValue());
    WeightSum += Weights.back();
    if (PostDominatedByUnreachable.count(TI->getSuccessor(i - 1)))
      UnreachableIdxs.push_back(i - 1);
    else
      ReachableIdxs.push_back(i - 1);
  }
  assert(Weights.size() == TI->getNumSuccessors() && "Checked above");

  // If the sum of weights does not fit in 32 bits, scale every weight down
  // accordingly. Integer division means weights much smaller than the
  // scaling factor can become 0.
  uint64_t ScalingFactor =
      (WeightSum > UINT32_MAX) ? WeightSum / UINT32_MAX + 1 : 1;

  if (ScalingFactor > 1) {
    WeightSum = 0;
    for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) {
      Weights[i] /= ScalingFactor;
      WeightSum += Weights[i];
    }
  }
  assert(WeightSum <= UINT32_MAX &&
         "Expected weights to scale down to 32 bits");

  // Degenerate cases: all weights are zero (no usable denominator), or no
  // successor is reachable. Fall back to a uniform distribution.
  if (WeightSum == 0 || ReachableIdxs.size() == 0) {
    for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
      Weights[i] = 1;
    WeightSum = TI->getNumSuccessors();
  }

  // Set the probability.
  SmallVector<BranchProbability, 2> BP;
  for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
    BP.push_back({ Weights[i], static_cast<uint32_t>(WeightSum) });

  // Examine the metadata against unreachable heuristic.
  // If the unreachable heuristic is more strong then we use it for this edge.
  // Each unreachable-bound edge is capped at UR_TAKEN_PROB; the probability
  // removed this way is redistributed evenly over the reachable edges.
  if (UnreachableIdxs.size() > 0 && ReachableIdxs.size() > 0) {
    auto ToDistribute = BranchProbability::getZero();
    auto UnreachableProb = UR_TAKEN_PROB;
    for (auto i : UnreachableIdxs)
      if (UnreachableProb < BP[i]) {
        ToDistribute += BP[i] - UnreachableProb;
        BP[i] = UnreachableProb;
      }

    // If we modified the probability of some edges then we must distribute
    // the difference between reachable blocks.
    if (ToDistribute > BranchProbability::getZero()) {
      BranchProbability PerEdge = ToDistribute / ReachableIdxs.size();
      for (auto i : ReachableIdxs)
        BP[i] += PerEdge;
    }
  }

  for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
    setEdgeProbability(BB, i, BP[i]);

  return true;
}
342 
343 /// \brief Calculate edge weights for edges leading to cold blocks.
344 ///
345 /// A cold block is one post-dominated by  a block with a call to a
346 /// cold function.  Those edges are unlikely to be taken, so we give
347 /// them relatively low weight.
348 ///
349 /// Return true if we could compute the weights for cold edges.
350 /// Return false, otherwise.
351 bool BranchProbabilityInfo::calcColdCallHeuristics(const BasicBlock *BB) {
352   const TerminatorInst *TI = BB->getTerminator();
353   assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
354 
355   // Return false here so that edge weights for InvokeInst could be decided
356   // in calcInvokeHeuristics().
357   if (isa<InvokeInst>(TI))
358     return false;
359 
360   // Determine which successors are post-dominated by a cold block.
361   SmallVector<unsigned, 4> ColdEdges;
362   SmallVector<unsigned, 4> NormalEdges;
363   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I)
364     if (PostDominatedByColdCall.count(*I))
365       ColdEdges.push_back(I.getSuccessorIndex());
366     else
367       NormalEdges.push_back(I.getSuccessorIndex());
368 
369   // Skip probabilities if no cold edges.
370   if (ColdEdges.empty())
371     return false;
372 
373   if (NormalEdges.empty()) {
374     BranchProbability Prob(1, ColdEdges.size());
375     for (unsigned SuccIdx : ColdEdges)
376       setEdgeProbability(BB, SuccIdx, Prob);
377     return true;
378   }
379 
380   auto ColdProb = BranchProbability::getBranchProbability(
381       CC_TAKEN_WEIGHT,
382       (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(ColdEdges.size()));
383   auto NormalProb = BranchProbability::getBranchProbability(
384       CC_NONTAKEN_WEIGHT,
385       (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(NormalEdges.size()));
386 
387   for (unsigned SuccIdx : ColdEdges)
388     setEdgeProbability(BB, SuccIdx, ColdProb);
389   for (unsigned SuccIdx : NormalEdges)
390     setEdgeProbability(BB, SuccIdx, NormalProb);
391 
392   return true;
393 }
394 
395 // Calculate Edge Weights using "Pointer Heuristics". Predict a comparison
396 // between two pointer or pointer and NULL will fail.
397 bool BranchProbabilityInfo::calcPointerHeuristics(const BasicBlock *BB) {
398   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
399   if (!BI || !BI->isConditional())
400     return false;
401 
402   Value *Cond = BI->getCondition();
403   ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
404   if (!CI || !CI->isEquality())
405     return false;
406 
407   Value *LHS = CI->getOperand(0);
408 
409   if (!LHS->getType()->isPointerTy())
410     return false;
411 
412   assert(CI->getOperand(1)->getType()->isPointerTy());
413 
414   // p != 0   ->   isProb = true
415   // p == 0   ->   isProb = false
416   // p != q   ->   isProb = true
417   // p == q   ->   isProb = false;
418   unsigned TakenIdx = 0, NonTakenIdx = 1;
419   bool isProb = CI->getPredicate() == ICmpInst::ICMP_NE;
420   if (!isProb)
421     std::swap(TakenIdx, NonTakenIdx);
422 
423   BranchProbability TakenProb(PH_TAKEN_WEIGHT,
424                               PH_TAKEN_WEIGHT + PH_NONTAKEN_WEIGHT);
425   setEdgeProbability(BB, TakenIdx, TakenProb);
426   setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
427   return true;
428 }
429 
430 static int getSCCNum(const BasicBlock *BB,
431                      const BranchProbabilityInfo::SccInfo &SccI) {
432   auto SccIt = SccI.SccNums.find(BB);
433   if (SccIt == SccI.SccNums.end())
434     return -1;
435   return SccIt->second;
436 }
437 
438 // Consider any block that is an entry point to the SCC as a header.
439 static bool isSCCHeader(const BasicBlock *BB, int SccNum,
440                         BranchProbabilityInfo::SccInfo &SccI) {
441   assert(getSCCNum(BB, SccI) == SccNum);
442 
443   // Lazily compute the set of headers for a given SCC and cache the results
444   // in the SccHeaderMap.
445   if (SccI.SccHeaders.size() <= static_cast<unsigned>(SccNum))
446     SccI.SccHeaders.resize(SccNum + 1);
447   auto &HeaderMap = SccI.SccHeaders[SccNum];
448   bool Inserted;
449   BranchProbabilityInfo::SccHeaderMap::iterator HeaderMapIt;
450   std::tie(HeaderMapIt, Inserted) = HeaderMap.insert(std::make_pair(BB, false));
451   if (Inserted) {
452     bool IsHeader = llvm::any_of(make_range(pred_begin(BB), pred_end(BB)),
453                                  [&](const BasicBlock *Pred) {
454                                    return getSCCNum(Pred, SccI) != SccNum;
455                                  });
456     HeaderMapIt->second = IsHeader;
457     return IsHeader;
458   } else
459     return HeaderMapIt->second;
460 }
461 
462 // Compute the unlikely successors to the block BB in the loop L, specifically
463 // those that are unlikely because this is a loop, and add them to the
464 // UnlikelyBlocks set.
465 static void
466 computeUnlikelySuccessors(const BasicBlock *BB, Loop *L,
467                           SmallPtrSetImpl<const BasicBlock*> &UnlikelyBlocks) {
468   // Sometimes in a loop we have a branch whose condition is made false by
469   // taking it. This is typically something like
470   //  int n = 0;
471   //  while (...) {
472   //    if (++n >= MAX) {
473   //      n = 0;
474   //    }
475   //  }
476   // In this sort of situation taking the branch means that at the very least it
477   // won't be taken again in the next iteration of the loop, so we should
478   // consider it less likely than a typical branch.
479   //
480   // We detect this by looking back through the graph of PHI nodes that sets the
481   // value that the condition depends on, and seeing if we can reach a successor
482   // block which can be determined to make the condition false.
483   //
484   // FIXME: We currently consider unlikely blocks to be half as likely as other
485   // blocks, but if we consider the example above the likelyhood is actually
486   // 1/MAX. We could therefore be more precise in how unlikely we consider
487   // blocks to be, but it would require more careful examination of the form
488   // of the comparison expression.
489   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
490   if (!BI || !BI->isConditional())
491     return;
492 
493   // Check if the branch is based on an instruction compared with a constant
494   CmpInst *CI = dyn_cast<CmpInst>(BI->getCondition());
495   if (!CI || !isa<Instruction>(CI->getOperand(0)) ||
496       !isa<Constant>(CI->getOperand(1)))
497     return;
498 
499   // Either the instruction must be a PHI, or a chain of operations involving
500   // constants that ends in a PHI which we can then collapse into a single value
501   // if the PHI value is known.
502   Instruction *CmpLHS = dyn_cast<Instruction>(CI->getOperand(0));
503   PHINode *CmpPHI = dyn_cast<PHINode>(CmpLHS);
504   Constant *CmpConst = dyn_cast<Constant>(CI->getOperand(1));
505   // Collect the instructions until we hit a PHI
506   std::list<BinaryOperator*> InstChain;
507   while (!CmpPHI && CmpLHS && isa<BinaryOperator>(CmpLHS) &&
508          isa<Constant>(CmpLHS->getOperand(1))) {
509     // Stop if the chain extends outside of the loop
510     if (!L->contains(CmpLHS))
511       return;
512     InstChain.push_front(dyn_cast<BinaryOperator>(CmpLHS));
513     CmpLHS = dyn_cast<Instruction>(CmpLHS->getOperand(0));
514     if (CmpLHS)
515       CmpPHI = dyn_cast<PHINode>(CmpLHS);
516   }
517   if (!CmpPHI || !L->contains(CmpPHI))
518     return;
519 
520   // Trace the phi node to find all values that come from successors of BB
521   SmallPtrSet<PHINode*, 8> VisitedInsts;
522   SmallVector<PHINode*, 8> WorkList;
523   WorkList.push_back(CmpPHI);
524   VisitedInsts.insert(CmpPHI);
525   while (!WorkList.empty()) {
526     PHINode *P = WorkList.back();
527     WorkList.pop_back();
528     for (BasicBlock *B : P->blocks()) {
529       // Skip blocks that aren't part of the loop
530       if (!L->contains(B))
531         continue;
532       Value *V = P->getIncomingValueForBlock(B);
533       // If the source is a PHI add it to the work list if we haven't
534       // already visited it.
535       if (PHINode *PN = dyn_cast<PHINode>(V)) {
536         if (VisitedInsts.insert(PN).second)
537           WorkList.push_back(PN);
538         continue;
539       }
540       // If this incoming value is a constant and B is a successor of BB, then
541       // we can constant-evaluate the compare to see if it makes the branch be
542       // taken or not.
543       Constant *CmpLHSConst = dyn_cast<Constant>(V);
544       if (!CmpLHSConst ||
545           std::find(succ_begin(BB), succ_end(BB), B) == succ_end(BB))
546         continue;
547       // First collapse InstChain
548       for (Instruction *I : InstChain) {
549         CmpLHSConst = ConstantExpr::get(I->getOpcode(), CmpLHSConst,
550                                         dyn_cast<Constant>(I->getOperand(1)),
551                                         true);
552         if (!CmpLHSConst)
553           break;
554       }
555       if (!CmpLHSConst)
556         continue;
557       // Now constant-evaluate the compare
558       Constant *Result = ConstantExpr::getCompare(CI->getPredicate(),
559                                                   CmpLHSConst, CmpConst, true);
560       // If the result means we don't branch to the block then that block is
561       // unlikely.
562       if (Result &&
563           ((Result->isZeroValue() && B == BI->getSuccessor(0)) ||
564            (Result->isOneValue() && B == BI->getSuccessor(1))))
565         UnlikelyBlocks.insert(B);
566     }
567   }
568 }
569 
570 // Calculate Edge Weights using "Loop Branch Heuristics". Predict backedges
571 // as taken, exiting edges as not-taken.
572 bool BranchProbabilityInfo::calcLoopBranchHeuristics(const BasicBlock *BB,
573                                                      const LoopInfo &LI,
574                                                      SccInfo &SccI) {
575   int SccNum;
576   Loop *L = LI.getLoopFor(BB);
577   if (!L) {
578     SccNum = getSCCNum(BB, SccI);
579     if (SccNum < 0)
580       return false;
581   }
582 
583   SmallPtrSet<const BasicBlock*, 8> UnlikelyBlocks;
584   if (L)
585     computeUnlikelySuccessors(BB, L, UnlikelyBlocks);
586 
587   SmallVector<unsigned, 8> BackEdges;
588   SmallVector<unsigned, 8> ExitingEdges;
589   SmallVector<unsigned, 8> InEdges; // Edges from header to the loop.
590   SmallVector<unsigned, 8> UnlikelyEdges;
591 
592   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
593     // Use LoopInfo if we have it, otherwise fall-back to SCC info to catch
594     // irreducible loops.
595     if (L) {
596       if (UnlikelyBlocks.count(*I) != 0)
597         UnlikelyEdges.push_back(I.getSuccessorIndex());
598       else if (!L->contains(*I))
599         ExitingEdges.push_back(I.getSuccessorIndex());
600       else if (L->getHeader() == *I)
601         BackEdges.push_back(I.getSuccessorIndex());
602       else
603         InEdges.push_back(I.getSuccessorIndex());
604     } else {
605       if (getSCCNum(*I, SccI) != SccNum)
606         ExitingEdges.push_back(I.getSuccessorIndex());
607       else if (isSCCHeader(*I, SccNum, SccI))
608         BackEdges.push_back(I.getSuccessorIndex());
609       else
610         InEdges.push_back(I.getSuccessorIndex());
611     }
612   }
613 
614   if (BackEdges.empty() && ExitingEdges.empty() && UnlikelyEdges.empty())
615     return false;
616 
617   // Collect the sum of probabilities of back-edges/in-edges/exiting-edges, and
618   // normalize them so that they sum up to one.
619   unsigned Denom = (BackEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) +
620                    (InEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) +
621                    (UnlikelyEdges.empty() ? 0 : LBH_UNLIKELY_WEIGHT) +
622                    (ExitingEdges.empty() ? 0 : LBH_NONTAKEN_WEIGHT);
623 
624   if (uint32_t numBackEdges = BackEdges.size()) {
625     BranchProbability TakenProb = BranchProbability(LBH_TAKEN_WEIGHT, Denom);
626     auto Prob = TakenProb / numBackEdges;
627     for (unsigned SuccIdx : BackEdges)
628       setEdgeProbability(BB, SuccIdx, Prob);
629   }
630 
631   if (uint32_t numInEdges = InEdges.size()) {
632     BranchProbability TakenProb = BranchProbability(LBH_TAKEN_WEIGHT, Denom);
633     auto Prob = TakenProb / numInEdges;
634     for (unsigned SuccIdx : InEdges)
635       setEdgeProbability(BB, SuccIdx, Prob);
636   }
637 
638   if (uint32_t numExitingEdges = ExitingEdges.size()) {
639     BranchProbability NotTakenProb = BranchProbability(LBH_NONTAKEN_WEIGHT,
640                                                        Denom);
641     auto Prob = NotTakenProb / numExitingEdges;
642     for (unsigned SuccIdx : ExitingEdges)
643       setEdgeProbability(BB, SuccIdx, Prob);
644   }
645 
646   if (uint32_t numUnlikelyEdges = UnlikelyEdges.size()) {
647     BranchProbability UnlikelyProb = BranchProbability(LBH_UNLIKELY_WEIGHT,
648                                                        Denom);
649     auto Prob = UnlikelyProb / numUnlikelyEdges;
650     for (unsigned SuccIdx : UnlikelyEdges)
651       setEdgeProbability(BB, SuccIdx, Prob);
652   }
653 
654   return true;
655 }
656 
657 bool BranchProbabilityInfo::calcZeroHeuristics(const BasicBlock *BB,
658                                                const TargetLibraryInfo *TLI) {
659   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
660   if (!BI || !BI->isConditional())
661     return false;
662 
663   Value *Cond = BI->getCondition();
664   ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
665   if (!CI)
666     return false;
667 
668   Value *RHS = CI->getOperand(1);
669   ConstantInt *CV = dyn_cast<ConstantInt>(RHS);
670   if (!CV)
671     return false;
672 
673   // If the LHS is the result of AND'ing a value with a single bit bitmask,
674   // we don't have information about probabilities.
675   if (Instruction *LHS = dyn_cast<Instruction>(CI->getOperand(0)))
676     if (LHS->getOpcode() == Instruction::And)
677       if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(LHS->getOperand(1)))
678         if (AndRHS->getValue().isPowerOf2())
679           return false;
680 
681   // Check if the LHS is the return value of a library function
682   LibFunc Func = NumLibFuncs;
683   if (TLI)
684     if (CallInst *Call = dyn_cast<CallInst>(CI->getOperand(0)))
685       if (Function *CalledFn = Call->getCalledFunction())
686         TLI->getLibFunc(*CalledFn, Func);
687 
688   bool isProb;
689   if (Func == LibFunc_strcasecmp ||
690       Func == LibFunc_strcmp ||
691       Func == LibFunc_strncasecmp ||
692       Func == LibFunc_strncmp ||
693       Func == LibFunc_memcmp) {
694     // strcmp and similar functions return zero, negative, or positive, if the
695     // first string is equal, less, or greater than the second. We consider it
696     // likely that the strings are not equal, so a comparison with zero is
697     // probably false, but also a comparison with any other number is also
698     // probably false given that what exactly is returned for nonzero values is
699     // not specified. Any kind of comparison other than equality we know
700     // nothing about.
701     switch (CI->getPredicate()) {
702     case CmpInst::ICMP_EQ:
703       isProb = false;
704       break;
705     case CmpInst::ICMP_NE:
706       isProb = true;
707       break;
708     default:
709       return false;
710     }
711   } else if (CV->isZero()) {
712     switch (CI->getPredicate()) {
713     case CmpInst::ICMP_EQ:
714       // X == 0   ->  Unlikely
715       isProb = false;
716       break;
717     case CmpInst::ICMP_NE:
718       // X != 0   ->  Likely
719       isProb = true;
720       break;
721     case CmpInst::ICMP_SLT:
722       // X < 0   ->  Unlikely
723       isProb = false;
724       break;
725     case CmpInst::ICMP_SGT:
726       // X > 0   ->  Likely
727       isProb = true;
728       break;
729     default:
730       return false;
731     }
732   } else if (CV->isOne() && CI->getPredicate() == CmpInst::ICMP_SLT) {
733     // InstCombine canonicalizes X <= 0 into X < 1.
734     // X <= 0   ->  Unlikely
735     isProb = false;
736   } else if (CV->isMinusOne()) {
737     switch (CI->getPredicate()) {
738     case CmpInst::ICMP_EQ:
739       // X == -1  ->  Unlikely
740       isProb = false;
741       break;
742     case CmpInst::ICMP_NE:
743       // X != -1  ->  Likely
744       isProb = true;
745       break;
746     case CmpInst::ICMP_SGT:
747       // InstCombine canonicalizes X >= 0 into X > -1.
748       // X >= 0   ->  Likely
749       isProb = true;
750       break;
751     default:
752       return false;
753     }
754   } else {
755     return false;
756   }
757 
758   unsigned TakenIdx = 0, NonTakenIdx = 1;
759 
760   if (!isProb)
761     std::swap(TakenIdx, NonTakenIdx);
762 
763   BranchProbability TakenProb(ZH_TAKEN_WEIGHT,
764                               ZH_TAKEN_WEIGHT + ZH_NONTAKEN_WEIGHT);
765   setEdgeProbability(BB, TakenIdx, TakenProb);
766   setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
767   return true;
768 }
769 
770 bool BranchProbabilityInfo::calcFloatingPointHeuristics(const BasicBlock *BB) {
771   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
772   if (!BI || !BI->isConditional())
773     return false;
774 
775   Value *Cond = BI->getCondition();
776   FCmpInst *FCmp = dyn_cast<FCmpInst>(Cond);
777   if (!FCmp)
778     return false;
779 
780   bool isProb;
781   if (FCmp->isEquality()) {
782     // f1 == f2 -> Unlikely
783     // f1 != f2 -> Likely
784     isProb = !FCmp->isTrueWhenEqual();
785   } else if (FCmp->getPredicate() == FCmpInst::FCMP_ORD) {
786     // !isnan -> Likely
787     isProb = true;
788   } else if (FCmp->getPredicate() == FCmpInst::FCMP_UNO) {
789     // isnan -> Unlikely
790     isProb = false;
791   } else {
792     return false;
793   }
794 
795   unsigned TakenIdx = 0, NonTakenIdx = 1;
796 
797   if (!isProb)
798     std::swap(TakenIdx, NonTakenIdx);
799 
800   BranchProbability TakenProb(FPH_TAKEN_WEIGHT,
801                               FPH_TAKEN_WEIGHT + FPH_NONTAKEN_WEIGHT);
802   setEdgeProbability(BB, TakenIdx, TakenProb);
803   setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
804   return true;
805 }
806 
807 bool BranchProbabilityInfo::calcInvokeHeuristics(const BasicBlock *BB) {
808   const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator());
809   if (!II)
810     return false;
811 
812   BranchProbability TakenProb(IH_TAKEN_WEIGHT,
813                               IH_TAKEN_WEIGHT + IH_NONTAKEN_WEIGHT);
814   setEdgeProbability(BB, 0 /*Index for Normal*/, TakenProb);
815   setEdgeProbability(BB, 1 /*Index for Unwind*/, TakenProb.getCompl());
816   return true;
817 }
818 
819 void BranchProbabilityInfo::releaseMemory() {
820   Probs.clear();
821 }
822 
823 void BranchProbabilityInfo::print(raw_ostream &OS) const {
824   OS << "---- Branch Probabilities ----\n";
825   // We print the probabilities from the last function the analysis ran over,
826   // or the function it is currently running over.
827   assert(LastF && "Cannot print prior to running over a function");
828   for (const auto &BI : *LastF) {
829     for (succ_const_iterator SI = succ_begin(&BI), SE = succ_end(&BI); SI != SE;
830          ++SI) {
831       printEdgeProbability(OS << "  ", &BI, *SI);
832     }
833   }
834 }
835 
836 bool BranchProbabilityInfo::
837 isEdgeHot(const BasicBlock *Src, const BasicBlock *Dst) const {
838   // Hot probability is at least 4/5 = 80%
839   // FIXME: Compare against a static "hot" BranchProbability.
840   return getEdgeProbability(Src, Dst) > BranchProbability(4, 5);
841 }
842 
843 const BasicBlock *
844 BranchProbabilityInfo::getHotSucc(const BasicBlock *BB) const {
845   auto MaxProb = BranchProbability::getZero();
846   const BasicBlock *MaxSucc = nullptr;
847 
848   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
849     const BasicBlock *Succ = *I;
850     auto Prob = getEdgeProbability(BB, Succ);
851     if (Prob > MaxProb) {
852       MaxProb = Prob;
853       MaxSucc = Succ;
854     }
855   }
856 
857   // Hot probability is at least 4/5 = 80%
858   if (MaxProb > BranchProbability(4, 5))
859     return MaxSucc;
860 
861   return nullptr;
862 }
863 
864 /// Get the raw edge probability for the edge. If can't find it, return a
865 /// default probability 1/N where N is the number of successors. Here an edge is
866 /// specified using PredBlock and an
867 /// index to the successors.
868 BranchProbability
869 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
870                                           unsigned IndexInSuccessors) const {
871   auto I = Probs.find(std::make_pair(Src, IndexInSuccessors));
872 
873   if (I != Probs.end())
874     return I->second;
875 
876   return {1,
877           static_cast<uint32_t>(std::distance(succ_begin(Src), succ_end(Src)))};
878 }
879 
880 BranchProbability
881 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
882                                           succ_const_iterator Dst) const {
883   return getEdgeProbability(Src, Dst.getSuccessorIndex());
884 }
885 
886 /// Get the raw edge probability calculated for the block pair. This returns the
887 /// sum of all raw edge probabilities from Src to Dst.
888 BranchProbability
889 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
890                                           const BasicBlock *Dst) const {
891   auto Prob = BranchProbability::getZero();
892   bool FoundProb = false;
893   for (succ_const_iterator I = succ_begin(Src), E = succ_end(Src); I != E; ++I)
894     if (*I == Dst) {
895       auto MapI = Probs.find(std::make_pair(Src, I.getSuccessorIndex()));
896       if (MapI != Probs.end()) {
897         FoundProb = true;
898         Prob += MapI->second;
899       }
900     }
901   uint32_t succ_num = std::distance(succ_begin(Src), succ_end(Src));
902   return FoundProb ? Prob : BranchProbability(1, succ_num);
903 }
904 
905 /// Set the edge probability for a given edge specified by PredBlock and an
906 /// index to the successors.
907 void BranchProbabilityInfo::setEdgeProbability(const BasicBlock *Src,
908                                                unsigned IndexInSuccessors,
909                                                BranchProbability Prob) {
910   Probs[std::make_pair(Src, IndexInSuccessors)] = Prob;
911   Handles.insert(BasicBlockCallbackVH(Src, this));
912   DEBUG(dbgs() << "set edge " << Src->getName() << " -> " << IndexInSuccessors
913                << " successor probability to " << Prob << "\n");
914 }
915 
916 raw_ostream &
917 BranchProbabilityInfo::printEdgeProbability(raw_ostream &OS,
918                                             const BasicBlock *Src,
919                                             const BasicBlock *Dst) const {
920   const BranchProbability Prob = getEdgeProbability(Src, Dst);
921   OS << "edge " << Src->getName() << " -> " << Dst->getName()
922      << " probability is " << Prob
923      << (isEdgeHot(Src, Dst) ? " [HOT edge]\n" : "\n");
924 
925   return OS;
926 }
927 
928 void BranchProbabilityInfo::eraseBlock(const BasicBlock *BB) {
929   for (auto I = Probs.begin(), E = Probs.end(); I != E; ++I) {
930     auto Key = I->first;
931     if (Key.first == BB)
932       Probs.erase(Key);
933   }
934 }
935 
936 void BranchProbabilityInfo::calculate(const Function &F, const LoopInfo &LI,
937                                       const TargetLibraryInfo *TLI) {
938   DEBUG(dbgs() << "---- Branch Probability Info : " << F.getName()
939                << " ----\n\n");
940   LastF = &F; // Store the last function we ran on for printing.
941   assert(PostDominatedByUnreachable.empty());
942   assert(PostDominatedByColdCall.empty());
943 
944   // Record SCC numbers of blocks in the CFG to identify irreducible loops.
945   // FIXME: We could only calculate this if the CFG is known to be irreducible
946   // (perhaps cache this info in LoopInfo if we can easily calculate it there?).
947   int SccNum = 0;
948   SccInfo SccI;
949   for (scc_iterator<const Function *> It = scc_begin(&F); !It.isAtEnd();
950        ++It, ++SccNum) {
951     // Ignore single-block SCCs since they either aren't loops or LoopInfo will
952     // catch them.
953     const std::vector<const BasicBlock *> &Scc = *It;
954     if (Scc.size() == 1)
955       continue;
956 
957     DEBUG(dbgs() << "BPI: SCC " << SccNum << ":");
958     for (auto *BB : Scc) {
959       DEBUG(dbgs() << " " << BB->getName());
960       SccI.SccNums[BB] = SccNum;
961     }
962     DEBUG(dbgs() << "\n");
963   }
964 
965   // Walk the basic blocks in post-order so that we can build up state about
966   // the successors of a block iteratively.
967   for (auto BB : post_order(&F.getEntryBlock())) {
968     DEBUG(dbgs() << "Computing probabilities for " << BB->getName() << "\n");
969     updatePostDominatedByUnreachable(BB);
970     updatePostDominatedByColdCall(BB);
971     // If there is no at least two successors, no sense to set probability.
972     if (BB->getTerminator()->getNumSuccessors() < 2)
973       continue;
974     if (calcMetadataWeights(BB))
975       continue;
976     if (calcUnreachableHeuristics(BB))
977       continue;
978     if (calcColdCallHeuristics(BB))
979       continue;
980     if (calcLoopBranchHeuristics(BB, LI, SccI))
981       continue;
982     if (calcPointerHeuristics(BB))
983       continue;
984     if (calcZeroHeuristics(BB, TLI))
985       continue;
986     if (calcFloatingPointHeuristics(BB))
987       continue;
988     calcInvokeHeuristics(BB);
989   }
990 
991   PostDominatedByUnreachable.clear();
992   PostDominatedByColdCall.clear();
993 
994   if (PrintBranchProb &&
995       (PrintBranchProbFuncName.empty() ||
996        F.getName().equals(PrintBranchProbFuncName))) {
997     print(dbgs());
998   }
999 }
1000 
1001 void BranchProbabilityInfoWrapperPass::getAnalysisUsage(
1002     AnalysisUsage &AU) const {
1003   AU.addRequired<LoopInfoWrapperPass>();
1004   AU.addRequired<TargetLibraryInfoWrapperPass>();
1005   AU.setPreservesAll();
1006 }
1007 
1008 bool BranchProbabilityInfoWrapperPass::runOnFunction(Function &F) {
1009   const LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
1010   const TargetLibraryInfo &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
1011   BPI.calculate(F, LI, &TLI);
1012   return false;
1013 }
1014 
1015 void BranchProbabilityInfoWrapperPass::releaseMemory() { BPI.releaseMemory(); }
1016 
1017 void BranchProbabilityInfoWrapperPass::print(raw_ostream &OS,
1018                                              const Module *) const {
1019   BPI.print(OS);
1020 }
1021 
1022 AnalysisKey BranchProbabilityAnalysis::Key;
1023 BranchProbabilityInfo
1024 BranchProbabilityAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
1025   BranchProbabilityInfo BPI;
1026   BPI.calculate(F, AM.getResult<LoopAnalysis>(F), &AM.getResult<TargetLibraryAnalysis>(F));
1027   return BPI;
1028 }
1029 
1030 PreservedAnalyses
1031 BranchProbabilityPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
1032   OS << "Printing analysis results of BPI for function "
1033      << "'" << F.getName() << "':"
1034      << "\n";
1035   AM.getResult<BranchProbabilityAnalysis>(F).print(OS);
1036   return PreservedAnalyses::all();
1037 }
1038