1 //===- BranchProbabilityInfo.cpp - Branch Probability Analysis ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Loops should be simplified before this analysis.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "llvm/Analysis/BranchProbabilityInfo.h"
14 #include "llvm/ADT/PostOrderIterator.h"
15 #include "llvm/ADT/SCCIterator.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/Analysis/LoopInfo.h"
19 #include "llvm/Analysis/TargetLibraryInfo.h"
20 #include "llvm/IR/Attributes.h"
21 #include "llvm/IR/BasicBlock.h"
22 #include "llvm/IR/CFG.h"
23 #include "llvm/IR/Constants.h"
24 #include "llvm/IR/Dominators.h"
25 #include "llvm/IR/Function.h"
26 #include "llvm/IR/InstrTypes.h"
27 #include "llvm/IR/Instruction.h"
28 #include "llvm/IR/Instructions.h"
29 #include "llvm/IR/LLVMContext.h"
30 #include "llvm/IR/Metadata.h"
31 #include "llvm/IR/PassManager.h"
32 #include "llvm/IR/Type.h"
33 #include "llvm/IR/Value.h"
34 #include "llvm/InitializePasses.h"
35 #include "llvm/Pass.h"
36 #include "llvm/Support/BranchProbability.h"
37 #include "llvm/Support/Casting.h"
38 #include "llvm/Support/Debug.h"
39 #include "llvm/Support/raw_ostream.h"
40 #include <cassert>
41 #include <cstdint>
42 #include <iterator>
43 #include <utility>
44 
45 using namespace llvm;
46 
47 #define DEBUG_TYPE "branch-prob"
48 
// Hidden boolean command-line flag "print-bpi"; it defaults to false.
static cl::opt<bool> PrintBranchProb(
    "print-bpi", cl::init(false), cl::Hidden,
    cl::desc("Print the branch probability info."));

// Hidden string command-line option "print-bpi-func-name" carrying a function
// name. NOTE(review): presumably it narrows the "print-bpi" output to that
// single function; the code consuming it is not in this part of the file.
cl::opt<std::string> PrintBranchProbFuncName(
    "print-bpi-func-name", cl::Hidden,
    cl::desc("The option to specify the name of the function "
             "whose branch probability info is printed."));
57 
// Pass registration for the wrapper pass under the command-line name
// "branch-prob", listing LoopInfoWrapperPass and TargetLibraryInfoWrapperPass
// as the passes it depends on.
INITIALIZE_PASS_BEGIN(BranchProbabilityInfoWrapperPass, "branch-prob",
                      "Branch Probability Analysis", false, true)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(BranchProbabilityInfoWrapperPass, "branch-prob",
                    "Branch Probability Analysis", false, true)
64 
65 BranchProbabilityInfoWrapperPass::BranchProbabilityInfoWrapperPass()
66     : FunctionPass(ID) {
67   initializeBranchProbabilityInfoWrapperPassPass(
68       *PassRegistry::getPassRegistry());
69 }
70 
// Pass identification member; the wrapper's constructor hands it to the
// FunctionPass base class.
char BranchProbabilityInfoWrapperPass::ID = 0;
72 
73 // Weights are for internal use only. They are used by heuristics to help to
74 // estimate edges' probability. Example:
75 //
76 // Using "Loop Branch Heuristics" we predict weights of edges for the
77 // block BB2.
78 //         ...
79 //          |
80 //          V
81 //         BB1<-+
82 //          |   |
83 //          |   | (Weight = 124)
84 //          V   |
85 //         BB2--+
86 //          |
87 //          | (Weight = 4)
88 //          V
89 //         BB3
90 //
91 // Probability of the edge BB2->BB1 = 124 / (124 + 4) = 0.96875
92 // Probability of the edge BB2->BB3 = 4 / (124 + 4) = 0.03125
static const uint32_t LBH_TAKEN_WEIGHT = 124;
static const uint32_t LBH_NONTAKEN_WEIGHT = 4;
// Unlikely edges within a loop are half as likely as other edges
static const uint32_t LBH_UNLIKELY_WEIGHT = 62;

/// Unreachable-terminating branch taken probability.
///
/// This is the probability for a branch being taken to a block that terminates
/// (eventually) in unreachable. These are predicted as unlikely as possible.
/// All reachable probability will equally share the remaining part.
static const BranchProbability UR_TAKEN_PROB = BranchProbability::getRaw(1);

/// Weight for a branch taken going into a cold block.
///
/// This is the weight for a branch taken toward a block marked
/// cold.  A block is marked cold if it's postdominated by a
/// block containing a call to a cold function.  Cold functions
/// are those marked with attribute 'cold'.
static const uint32_t CC_TAKEN_WEIGHT = 4;

/// Weight for a branch not-taken into a cold block.
///
/// This is the weight for a branch not taken toward a block marked
/// cold.
static const uint32_t CC_NONTAKEN_WEIGHT = 64;

/// Weights used by the pointer heuristics (calcPointerHeuristics): the edge
/// predicted as taken gets 20 / (20 + 12) = 0.625, the other edge the rest.
static const uint32_t PH_TAKEN_WEIGHT = 20;
static const uint32_t PH_NONTAKEN_WEIGHT = 12;

/// Weights used by the zero heuristics (calcZeroHeuristics); same 20:12 split
/// as the pointer heuristics.
static const uint32_t ZH_TAKEN_WEIGHT = 20;
static const uint32_t ZH_NONTAKEN_WEIGHT = 12;

/// Default weights used by the floating point heuristics
/// (calcFloatingPointHeuristics) for equality comparisons.
static const uint32_t FPH_TAKEN_WEIGHT = 20;
static const uint32_t FPH_NONTAKEN_WEIGHT = 12;

/// This is the weight for the outcome of a floating point comparison in which
/// the operands are ordered (neither one is NaN).
static const uint32_t FPH_ORD_WEIGHT = 1024 * 1024 - 1;
/// This is the weight for the unordered outcome of a floating point
/// comparison, meaning one or both of the operands are NaN. Usually it is used
/// to test for an exceptional case, so the result is unlikely.
static const uint32_t FPH_UNO_WEIGHT = 1;

/// Invoke-terminating normal branch taken weight
///
/// This is the weight for branching to the normal destination of an invoke
/// instruction. We expect this to happen most of the time. Set the weight to an
/// absurdly high value so that nested loops subsume it.
static const uint32_t IH_TAKEN_WEIGHT = 1024 * 1024 - 1;

/// Invoke-terminating normal branch not-taken weight.
///
/// This is the weight for branching to the unwind destination of an invoke
/// instruction. This is essentially never taken.
static const uint32_t IH_NONTAKEN_WEIGHT = 1;
147 
148 /// Add \p BB to PostDominatedByUnreachable set if applicable.
149 void
150 BranchProbabilityInfo::updatePostDominatedByUnreachable(const BasicBlock *BB) {
151   const Instruction *TI = BB->getTerminator();
152   if (TI->getNumSuccessors() == 0) {
153     if (isa<UnreachableInst>(TI) ||
154         // If this block is terminated by a call to
155         // @llvm.experimental.deoptimize then treat it like an unreachable since
156         // the @llvm.experimental.deoptimize call is expected to practically
157         // never execute.
158         BB->getTerminatingDeoptimizeCall())
159       PostDominatedByUnreachable.insert(BB);
160     return;
161   }
162 
163   // If the terminator is an InvokeInst, check only the normal destination block
164   // as the unwind edge of InvokeInst is also very unlikely taken.
165   if (auto *II = dyn_cast<InvokeInst>(TI)) {
166     if (PostDominatedByUnreachable.count(II->getNormalDest()))
167       PostDominatedByUnreachable.insert(BB);
168     return;
169   }
170 
171   for (auto *I : successors(BB))
172     // If any of successor is not post dominated then BB is also not.
173     if (!PostDominatedByUnreachable.count(I))
174       return;
175 
176   PostDominatedByUnreachable.insert(BB);
177 }
178 
179 /// Add \p BB to PostDominatedByColdCall set if applicable.
180 void
181 BranchProbabilityInfo::updatePostDominatedByColdCall(const BasicBlock *BB) {
182   assert(!PostDominatedByColdCall.count(BB));
183   const Instruction *TI = BB->getTerminator();
184   if (TI->getNumSuccessors() == 0)
185     return;
186 
187   // If all of successor are post dominated then BB is also done.
188   if (llvm::all_of(successors(BB), [&](const BasicBlock *SuccBB) {
189         return PostDominatedByColdCall.count(SuccBB);
190       })) {
191     PostDominatedByColdCall.insert(BB);
192     return;
193   }
194 
195   // If the terminator is an InvokeInst, check only the normal destination
196   // block as the unwind edge of InvokeInst is also very unlikely taken.
197   if (auto *II = dyn_cast<InvokeInst>(TI))
198     if (PostDominatedByColdCall.count(II->getNormalDest())) {
199       PostDominatedByColdCall.insert(BB);
200       return;
201     }
202 
203   // Otherwise, if the block itself contains a cold function, add it to the
204   // set of blocks post-dominated by a cold call.
205   for (auto &I : *BB)
206     if (const CallInst *CI = dyn_cast<CallInst>(&I))
207       if (CI->hasFnAttr(Attribute::Cold)) {
208         PostDominatedByColdCall.insert(BB);
209         return;
210       }
211 }
212 
213 /// Calculate edge weights for successors lead to unreachable.
214 ///
215 /// Predict that a successor which leads necessarily to an
216 /// unreachable-terminated block as extremely unlikely.
217 bool BranchProbabilityInfo::calcUnreachableHeuristics(const BasicBlock *BB) {
218   const Instruction *TI = BB->getTerminator();
219   (void) TI;
220   assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
221   assert(!isa<InvokeInst>(TI) &&
222          "Invokes should have already been handled by calcInvokeHeuristics");
223 
224   SmallVector<unsigned, 4> UnreachableEdges;
225   SmallVector<unsigned, 4> ReachableEdges;
226 
227   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I)
228     if (PostDominatedByUnreachable.count(*I))
229       UnreachableEdges.push_back(I.getSuccessorIndex());
230     else
231       ReachableEdges.push_back(I.getSuccessorIndex());
232 
233   // Skip probabilities if all were reachable.
234   if (UnreachableEdges.empty())
235     return false;
236 
237   if (ReachableEdges.empty()) {
238     BranchProbability Prob(1, UnreachableEdges.size());
239     for (unsigned SuccIdx : UnreachableEdges)
240       setEdgeProbability(BB, SuccIdx, Prob);
241     return true;
242   }
243 
244   auto UnreachableProb = UR_TAKEN_PROB;
245   auto ReachableProb =
246       (BranchProbability::getOne() - UR_TAKEN_PROB * UnreachableEdges.size()) /
247       ReachableEdges.size();
248 
249   for (unsigned SuccIdx : UnreachableEdges)
250     setEdgeProbability(BB, SuccIdx, UnreachableProb);
251   for (unsigned SuccIdx : ReachableEdges)
252     setEdgeProbability(BB, SuccIdx, ReachableProb);
253 
254   return true;
255 }
256 
257 // Propagate existing explicit probabilities from either profile data or
258 // 'expect' intrinsic processing. Examine metadata against unreachable
259 // heuristic. The probability of the edge coming to unreachable block is
260 // set to min of metadata and unreachable heuristic.
261 bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) {
262   const Instruction *TI = BB->getTerminator();
263   assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
264   if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) || isa<IndirectBrInst>(TI)))
265     return false;
266 
267   MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof);
268   if (!WeightsNode)
269     return false;
270 
271   // Check that the number of successors is manageable.
272   assert(TI->getNumSuccessors() < UINT32_MAX && "Too many successors");
273 
274   // Ensure there are weights for all of the successors. Note that the first
275   // operand to the metadata node is a name, not a weight.
276   if (WeightsNode->getNumOperands() != TI->getNumSuccessors() + 1)
277     return false;
278 
279   // Build up the final weights that will be used in a temporary buffer.
280   // Compute the sum of all weights to later decide whether they need to
281   // be scaled to fit in 32 bits.
282   uint64_t WeightSum = 0;
283   SmallVector<uint32_t, 2> Weights;
284   SmallVector<unsigned, 2> UnreachableIdxs;
285   SmallVector<unsigned, 2> ReachableIdxs;
286   Weights.reserve(TI->getNumSuccessors());
287   for (unsigned i = 1, e = WeightsNode->getNumOperands(); i != e; ++i) {
288     ConstantInt *Weight =
289         mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(i));
290     if (!Weight)
291       return false;
292     assert(Weight->getValue().getActiveBits() <= 32 &&
293            "Too many bits for uint32_t");
294     Weights.push_back(Weight->getZExtValue());
295     WeightSum += Weights.back();
296     if (PostDominatedByUnreachable.count(TI->getSuccessor(i - 1)))
297       UnreachableIdxs.push_back(i - 1);
298     else
299       ReachableIdxs.push_back(i - 1);
300   }
301   assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
302 
303   // If the sum of weights does not fit in 32 bits, scale every weight down
304   // accordingly.
305   uint64_t ScalingFactor =
306       (WeightSum > UINT32_MAX) ? WeightSum / UINT32_MAX + 1 : 1;
307 
308   if (ScalingFactor > 1) {
309     WeightSum = 0;
310     for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) {
311       Weights[i] /= ScalingFactor;
312       WeightSum += Weights[i];
313     }
314   }
315   assert(WeightSum <= UINT32_MAX &&
316          "Expected weights to scale down to 32 bits");
317 
318   if (WeightSum == 0 || ReachableIdxs.size() == 0) {
319     for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
320       Weights[i] = 1;
321     WeightSum = TI->getNumSuccessors();
322   }
323 
324   // Set the probability.
325   SmallVector<BranchProbability, 2> BP;
326   for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
327     BP.push_back({ Weights[i], static_cast<uint32_t>(WeightSum) });
328 
329   // Examine the metadata against unreachable heuristic.
330   // If the unreachable heuristic is more strong then we use it for this edge.
331   if (UnreachableIdxs.size() > 0 && ReachableIdxs.size() > 0) {
332     auto ToDistribute = BranchProbability::getZero();
333     auto UnreachableProb = UR_TAKEN_PROB;
334     for (auto i : UnreachableIdxs)
335       if (UnreachableProb < BP[i]) {
336         ToDistribute += BP[i] - UnreachableProb;
337         BP[i] = UnreachableProb;
338       }
339 
340     // If we modified the probability of some edges then we must distribute
341     // the difference between reachable blocks.
342     if (ToDistribute > BranchProbability::getZero()) {
343       BranchProbability PerEdge = ToDistribute / ReachableIdxs.size();
344       for (auto i : ReachableIdxs)
345         BP[i] += PerEdge;
346     }
347   }
348 
349   for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
350     setEdgeProbability(BB, i, BP[i]);
351 
352   return true;
353 }
354 
355 /// Calculate edge weights for edges leading to cold blocks.
356 ///
357 /// A cold block is one post-dominated by  a block with a call to a
358 /// cold function.  Those edges are unlikely to be taken, so we give
359 /// them relatively low weight.
360 ///
361 /// Return true if we could compute the weights for cold edges.
362 /// Return false, otherwise.
363 bool BranchProbabilityInfo::calcColdCallHeuristics(const BasicBlock *BB) {
364   const Instruction *TI = BB->getTerminator();
365   (void) TI;
366   assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
367   assert(!isa<InvokeInst>(TI) &&
368          "Invokes should have already been handled by calcInvokeHeuristics");
369 
370   // Determine which successors are post-dominated by a cold block.
371   SmallVector<unsigned, 4> ColdEdges;
372   SmallVector<unsigned, 4> NormalEdges;
373   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I)
374     if (PostDominatedByColdCall.count(*I))
375       ColdEdges.push_back(I.getSuccessorIndex());
376     else
377       NormalEdges.push_back(I.getSuccessorIndex());
378 
379   // Skip probabilities if no cold edges.
380   if (ColdEdges.empty())
381     return false;
382 
383   if (NormalEdges.empty()) {
384     BranchProbability Prob(1, ColdEdges.size());
385     for (unsigned SuccIdx : ColdEdges)
386       setEdgeProbability(BB, SuccIdx, Prob);
387     return true;
388   }
389 
390   auto ColdProb = BranchProbability::getBranchProbability(
391       CC_TAKEN_WEIGHT,
392       (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(ColdEdges.size()));
393   auto NormalProb = BranchProbability::getBranchProbability(
394       CC_NONTAKEN_WEIGHT,
395       (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(NormalEdges.size()));
396 
397   for (unsigned SuccIdx : ColdEdges)
398     setEdgeProbability(BB, SuccIdx, ColdProb);
399   for (unsigned SuccIdx : NormalEdges)
400     setEdgeProbability(BB, SuccIdx, NormalProb);
401 
402   return true;
403 }
404 
405 // Calculate Edge Weights using "Pointer Heuristics". Predict a comparison
406 // between two pointer or pointer and NULL will fail.
407 bool BranchProbabilityInfo::calcPointerHeuristics(const BasicBlock *BB) {
408   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
409   if (!BI || !BI->isConditional())
410     return false;
411 
412   Value *Cond = BI->getCondition();
413   ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
414   if (!CI || !CI->isEquality())
415     return false;
416 
417   Value *LHS = CI->getOperand(0);
418 
419   if (!LHS->getType()->isPointerTy())
420     return false;
421 
422   assert(CI->getOperand(1)->getType()->isPointerTy());
423 
424   // p != 0   ->   isProb = true
425   // p == 0   ->   isProb = false
426   // p != q   ->   isProb = true
427   // p == q   ->   isProb = false;
428   unsigned TakenIdx = 0, NonTakenIdx = 1;
429   bool isProb = CI->getPredicate() == ICmpInst::ICMP_NE;
430   if (!isProb)
431     std::swap(TakenIdx, NonTakenIdx);
432 
433   BranchProbability TakenProb(PH_TAKEN_WEIGHT,
434                               PH_TAKEN_WEIGHT + PH_NONTAKEN_WEIGHT);
435   setEdgeProbability(BB, TakenIdx, TakenProb);
436   setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
437   return true;
438 }
439 
440 static int getSCCNum(const BasicBlock *BB,
441                      const BranchProbabilityInfo::SccInfo &SccI) {
442   auto SccIt = SccI.SccNums.find(BB);
443   if (SccIt == SccI.SccNums.end())
444     return -1;
445   return SccIt->second;
446 }
447 
448 // Consider any block that is an entry point to the SCC as a header.
449 static bool isSCCHeader(const BasicBlock *BB, int SccNum,
450                         BranchProbabilityInfo::SccInfo &SccI) {
451   assert(getSCCNum(BB, SccI) == SccNum);
452 
453   // Lazily compute the set of headers for a given SCC and cache the results
454   // in the SccHeaderMap.
455   if (SccI.SccHeaders.size() <= static_cast<unsigned>(SccNum))
456     SccI.SccHeaders.resize(SccNum + 1);
457   auto &HeaderMap = SccI.SccHeaders[SccNum];
458   bool Inserted;
459   BranchProbabilityInfo::SccHeaderMap::iterator HeaderMapIt;
460   std::tie(HeaderMapIt, Inserted) = HeaderMap.insert(std::make_pair(BB, false));
461   if (Inserted) {
462     bool IsHeader = llvm::any_of(make_range(pred_begin(BB), pred_end(BB)),
463                                  [&](const BasicBlock *Pred) {
464                                    return getSCCNum(Pred, SccI) != SccNum;
465                                  });
466     HeaderMapIt->second = IsHeader;
467     return IsHeader;
468   } else
469     return HeaderMapIt->second;
470 }
471 
472 // Compute the unlikely successors to the block BB in the loop L, specifically
473 // those that are unlikely because this is a loop, and add them to the
474 // UnlikelyBlocks set.
475 static void
476 computeUnlikelySuccessors(const BasicBlock *BB, Loop *L,
477                           SmallPtrSetImpl<const BasicBlock*> &UnlikelyBlocks) {
478   // Sometimes in a loop we have a branch whose condition is made false by
479   // taking it. This is typically something like
480   //  int n = 0;
481   //  while (...) {
482   //    if (++n >= MAX) {
483   //      n = 0;
484   //    }
485   //  }
486   // In this sort of situation taking the branch means that at the very least it
487   // won't be taken again in the next iteration of the loop, so we should
488   // consider it less likely than a typical branch.
489   //
490   // We detect this by looking back through the graph of PHI nodes that sets the
491   // value that the condition depends on, and seeing if we can reach a successor
492   // block which can be determined to make the condition false.
493   //
494   // FIXME: We currently consider unlikely blocks to be half as likely as other
495   // blocks, but if we consider the example above the likelyhood is actually
496   // 1/MAX. We could therefore be more precise in how unlikely we consider
497   // blocks to be, but it would require more careful examination of the form
498   // of the comparison expression.
499   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
500   if (!BI || !BI->isConditional())
501     return;
502 
503   // Check if the branch is based on an instruction compared with a constant
504   CmpInst *CI = dyn_cast<CmpInst>(BI->getCondition());
505   if (!CI || !isa<Instruction>(CI->getOperand(0)) ||
506       !isa<Constant>(CI->getOperand(1)))
507     return;
508 
509   // Either the instruction must be a PHI, or a chain of operations involving
510   // constants that ends in a PHI which we can then collapse into a single value
511   // if the PHI value is known.
512   Instruction *CmpLHS = dyn_cast<Instruction>(CI->getOperand(0));
513   PHINode *CmpPHI = dyn_cast<PHINode>(CmpLHS);
514   Constant *CmpConst = dyn_cast<Constant>(CI->getOperand(1));
515   // Collect the instructions until we hit a PHI
516   SmallVector<BinaryOperator *, 1> InstChain;
517   while (!CmpPHI && CmpLHS && isa<BinaryOperator>(CmpLHS) &&
518          isa<Constant>(CmpLHS->getOperand(1))) {
519     // Stop if the chain extends outside of the loop
520     if (!L->contains(CmpLHS))
521       return;
522     InstChain.push_back(cast<BinaryOperator>(CmpLHS));
523     CmpLHS = dyn_cast<Instruction>(CmpLHS->getOperand(0));
524     if (CmpLHS)
525       CmpPHI = dyn_cast<PHINode>(CmpLHS);
526   }
527   if (!CmpPHI || !L->contains(CmpPHI))
528     return;
529 
530   // Trace the phi node to find all values that come from successors of BB
531   SmallPtrSet<PHINode*, 8> VisitedInsts;
532   SmallVector<PHINode*, 8> WorkList;
533   WorkList.push_back(CmpPHI);
534   VisitedInsts.insert(CmpPHI);
535   while (!WorkList.empty()) {
536     PHINode *P = WorkList.back();
537     WorkList.pop_back();
538     for (BasicBlock *B : P->blocks()) {
539       // Skip blocks that aren't part of the loop
540       if (!L->contains(B))
541         continue;
542       Value *V = P->getIncomingValueForBlock(B);
543       // If the source is a PHI add it to the work list if we haven't
544       // already visited it.
545       if (PHINode *PN = dyn_cast<PHINode>(V)) {
546         if (VisitedInsts.insert(PN).second)
547           WorkList.push_back(PN);
548         continue;
549       }
550       // If this incoming value is a constant and B is a successor of BB, then
551       // we can constant-evaluate the compare to see if it makes the branch be
552       // taken or not.
553       Constant *CmpLHSConst = dyn_cast<Constant>(V);
554       if (!CmpLHSConst ||
555           std::find(succ_begin(BB), succ_end(BB), B) == succ_end(BB))
556         continue;
557       // First collapse InstChain
558       for (Instruction *I : llvm::reverse(InstChain)) {
559         CmpLHSConst = ConstantExpr::get(I->getOpcode(), CmpLHSConst,
560                                         cast<Constant>(I->getOperand(1)), true);
561         if (!CmpLHSConst)
562           break;
563       }
564       if (!CmpLHSConst)
565         continue;
566       // Now constant-evaluate the compare
567       Constant *Result = ConstantExpr::getCompare(CI->getPredicate(),
568                                                   CmpLHSConst, CmpConst, true);
569       // If the result means we don't branch to the block then that block is
570       // unlikely.
571       if (Result &&
572           ((Result->isZeroValue() && B == BI->getSuccessor(0)) ||
573            (Result->isOneValue() && B == BI->getSuccessor(1))))
574         UnlikelyBlocks.insert(B);
575     }
576   }
577 }
578 
579 // Calculate Edge Weights using "Loop Branch Heuristics". Predict backedges
580 // as taken, exiting edges as not-taken.
581 bool BranchProbabilityInfo::calcLoopBranchHeuristics(const BasicBlock *BB,
582                                                      const LoopInfo &LI,
583                                                      SccInfo &SccI) {
584   int SccNum;
585   Loop *L = LI.getLoopFor(BB);
586   if (!L) {
587     SccNum = getSCCNum(BB, SccI);
588     if (SccNum < 0)
589       return false;
590   }
591 
592   SmallPtrSet<const BasicBlock*, 8> UnlikelyBlocks;
593   if (L)
594     computeUnlikelySuccessors(BB, L, UnlikelyBlocks);
595 
596   SmallVector<unsigned, 8> BackEdges;
597   SmallVector<unsigned, 8> ExitingEdges;
598   SmallVector<unsigned, 8> InEdges; // Edges from header to the loop.
599   SmallVector<unsigned, 8> UnlikelyEdges;
600 
601   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
602     // Use LoopInfo if we have it, otherwise fall-back to SCC info to catch
603     // irreducible loops.
604     if (L) {
605       if (UnlikelyBlocks.count(*I) != 0)
606         UnlikelyEdges.push_back(I.getSuccessorIndex());
607       else if (!L->contains(*I))
608         ExitingEdges.push_back(I.getSuccessorIndex());
609       else if (L->getHeader() == *I)
610         BackEdges.push_back(I.getSuccessorIndex());
611       else
612         InEdges.push_back(I.getSuccessorIndex());
613     } else {
614       if (getSCCNum(*I, SccI) != SccNum)
615         ExitingEdges.push_back(I.getSuccessorIndex());
616       else if (isSCCHeader(*I, SccNum, SccI))
617         BackEdges.push_back(I.getSuccessorIndex());
618       else
619         InEdges.push_back(I.getSuccessorIndex());
620     }
621   }
622 
623   if (BackEdges.empty() && ExitingEdges.empty() && UnlikelyEdges.empty())
624     return false;
625 
626   // Collect the sum of probabilities of back-edges/in-edges/exiting-edges, and
627   // normalize them so that they sum up to one.
628   unsigned Denom = (BackEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) +
629                    (InEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) +
630                    (UnlikelyEdges.empty() ? 0 : LBH_UNLIKELY_WEIGHT) +
631                    (ExitingEdges.empty() ? 0 : LBH_NONTAKEN_WEIGHT);
632 
633   if (uint32_t numBackEdges = BackEdges.size()) {
634     BranchProbability TakenProb = BranchProbability(LBH_TAKEN_WEIGHT, Denom);
635     auto Prob = TakenProb / numBackEdges;
636     for (unsigned SuccIdx : BackEdges)
637       setEdgeProbability(BB, SuccIdx, Prob);
638   }
639 
640   if (uint32_t numInEdges = InEdges.size()) {
641     BranchProbability TakenProb = BranchProbability(LBH_TAKEN_WEIGHT, Denom);
642     auto Prob = TakenProb / numInEdges;
643     for (unsigned SuccIdx : InEdges)
644       setEdgeProbability(BB, SuccIdx, Prob);
645   }
646 
647   if (uint32_t numExitingEdges = ExitingEdges.size()) {
648     BranchProbability NotTakenProb = BranchProbability(LBH_NONTAKEN_WEIGHT,
649                                                        Denom);
650     auto Prob = NotTakenProb / numExitingEdges;
651     for (unsigned SuccIdx : ExitingEdges)
652       setEdgeProbability(BB, SuccIdx, Prob);
653   }
654 
655   if (uint32_t numUnlikelyEdges = UnlikelyEdges.size()) {
656     BranchProbability UnlikelyProb = BranchProbability(LBH_UNLIKELY_WEIGHT,
657                                                        Denom);
658     auto Prob = UnlikelyProb / numUnlikelyEdges;
659     for (unsigned SuccIdx : UnlikelyEdges)
660       setEdgeProbability(BB, SuccIdx, Prob);
661   }
662 
663   return true;
664 }
665 
666 bool BranchProbabilityInfo::calcZeroHeuristics(const BasicBlock *BB,
667                                                const TargetLibraryInfo *TLI) {
668   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
669   if (!BI || !BI->isConditional())
670     return false;
671 
672   Value *Cond = BI->getCondition();
673   ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
674   if (!CI)
675     return false;
676 
677   auto GetConstantInt = [](Value *V) {
678     if (auto *I = dyn_cast<BitCastInst>(V))
679       return dyn_cast<ConstantInt>(I->getOperand(0));
680     return dyn_cast<ConstantInt>(V);
681   };
682 
683   Value *RHS = CI->getOperand(1);
684   ConstantInt *CV = GetConstantInt(RHS);
685   if (!CV)
686     return false;
687 
688   // If the LHS is the result of AND'ing a value with a single bit bitmask,
689   // we don't have information about probabilities.
690   if (Instruction *LHS = dyn_cast<Instruction>(CI->getOperand(0)))
691     if (LHS->getOpcode() == Instruction::And)
692       if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(LHS->getOperand(1)))
693         if (AndRHS->getValue().isPowerOf2())
694           return false;
695 
696   // Check if the LHS is the return value of a library function
697   LibFunc Func = NumLibFuncs;
698   if (TLI)
699     if (CallInst *Call = dyn_cast<CallInst>(CI->getOperand(0)))
700       if (Function *CalledFn = Call->getCalledFunction())
701         TLI->getLibFunc(*CalledFn, Func);
702 
703   bool isProb;
704   if (Func == LibFunc_strcasecmp ||
705       Func == LibFunc_strcmp ||
706       Func == LibFunc_strncasecmp ||
707       Func == LibFunc_strncmp ||
708       Func == LibFunc_memcmp) {
709     // strcmp and similar functions return zero, negative, or positive, if the
710     // first string is equal, less, or greater than the second. We consider it
711     // likely that the strings are not equal, so a comparison with zero is
712     // probably false, but also a comparison with any other number is also
713     // probably false given that what exactly is returned for nonzero values is
714     // not specified. Any kind of comparison other than equality we know
715     // nothing about.
716     switch (CI->getPredicate()) {
717     case CmpInst::ICMP_EQ:
718       isProb = false;
719       break;
720     case CmpInst::ICMP_NE:
721       isProb = true;
722       break;
723     default:
724       return false;
725     }
726   } else if (CV->isZero()) {
727     switch (CI->getPredicate()) {
728     case CmpInst::ICMP_EQ:
729       // X == 0   ->  Unlikely
730       isProb = false;
731       break;
732     case CmpInst::ICMP_NE:
733       // X != 0   ->  Likely
734       isProb = true;
735       break;
736     case CmpInst::ICMP_SLT:
737       // X < 0   ->  Unlikely
738       isProb = false;
739       break;
740     case CmpInst::ICMP_SGT:
741       // X > 0   ->  Likely
742       isProb = true;
743       break;
744     default:
745       return false;
746     }
747   } else if (CV->isOne() && CI->getPredicate() == CmpInst::ICMP_SLT) {
748     // InstCombine canonicalizes X <= 0 into X < 1.
749     // X <= 0   ->  Unlikely
750     isProb = false;
751   } else if (CV->isMinusOne()) {
752     switch (CI->getPredicate()) {
753     case CmpInst::ICMP_EQ:
754       // X == -1  ->  Unlikely
755       isProb = false;
756       break;
757     case CmpInst::ICMP_NE:
758       // X != -1  ->  Likely
759       isProb = true;
760       break;
761     case CmpInst::ICMP_SGT:
762       // InstCombine canonicalizes X >= 0 into X > -1.
763       // X >= 0   ->  Likely
764       isProb = true;
765       break;
766     default:
767       return false;
768     }
769   } else {
770     return false;
771   }
772 
773   unsigned TakenIdx = 0, NonTakenIdx = 1;
774 
775   if (!isProb)
776     std::swap(TakenIdx, NonTakenIdx);
777 
778   BranchProbability TakenProb(ZH_TAKEN_WEIGHT,
779                               ZH_TAKEN_WEIGHT + ZH_NONTAKEN_WEIGHT);
780   setEdgeProbability(BB, TakenIdx, TakenProb);
781   setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
782   return true;
783 }
784 
785 bool BranchProbabilityInfo::calcFloatingPointHeuristics(const BasicBlock *BB) {
786   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
787   if (!BI || !BI->isConditional())
788     return false;
789 
790   Value *Cond = BI->getCondition();
791   FCmpInst *FCmp = dyn_cast<FCmpInst>(Cond);
792   if (!FCmp)
793     return false;
794 
795   uint32_t TakenWeight = FPH_TAKEN_WEIGHT;
796   uint32_t NontakenWeight = FPH_NONTAKEN_WEIGHT;
797   bool isProb;
798   if (FCmp->isEquality()) {
799     // f1 == f2 -> Unlikely
800     // f1 != f2 -> Likely
801     isProb = !FCmp->isTrueWhenEqual();
802   } else if (FCmp->getPredicate() == FCmpInst::FCMP_ORD) {
803     // !isnan -> Likely
804     isProb = true;
805     TakenWeight = FPH_ORD_WEIGHT;
806     NontakenWeight = FPH_UNO_WEIGHT;
807   } else if (FCmp->getPredicate() == FCmpInst::FCMP_UNO) {
808     // isnan -> Unlikely
809     isProb = false;
810     TakenWeight = FPH_ORD_WEIGHT;
811     NontakenWeight = FPH_UNO_WEIGHT;
812   } else {
813     return false;
814   }
815 
816   unsigned TakenIdx = 0, NonTakenIdx = 1;
817 
818   if (!isProb)
819     std::swap(TakenIdx, NonTakenIdx);
820 
821   BranchProbability TakenProb(TakenWeight, TakenWeight + NontakenWeight);
822   setEdgeProbability(BB, TakenIdx, TakenProb);
823   setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
824   return true;
825 }
826 
827 bool BranchProbabilityInfo::calcInvokeHeuristics(const BasicBlock *BB) {
828   const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator());
829   if (!II)
830     return false;
831 
832   BranchProbability TakenProb(IH_TAKEN_WEIGHT,
833                               IH_TAKEN_WEIGHT + IH_NONTAKEN_WEIGHT);
834   setEdgeProbability(BB, 0 /*Index for Normal*/, TakenProb);
835   setEdgeProbability(BB, 1 /*Index for Unwind*/, TakenProb.getCompl());
836   return true;
837 }
838 
839 void BranchProbabilityInfo::releaseMemory() {
840   Probs.clear();
841 }
842 
843 void BranchProbabilityInfo::print(raw_ostream &OS) const {
844   OS << "---- Branch Probabilities ----\n";
845   // We print the probabilities from the last function the analysis ran over,
846   // or the function it is currently running over.
847   assert(LastF && "Cannot print prior to running over a function");
848   for (const auto &BI : *LastF) {
849     for (succ_const_iterator SI = succ_begin(&BI), SE = succ_end(&BI); SI != SE;
850          ++SI) {
851       printEdgeProbability(OS << "  ", &BI, *SI);
852     }
853   }
854 }
855 
856 bool BranchProbabilityInfo::
857 isEdgeHot(const BasicBlock *Src, const BasicBlock *Dst) const {
858   // Hot probability is at least 4/5 = 80%
859   // FIXME: Compare against a static "hot" BranchProbability.
860   return getEdgeProbability(Src, Dst) > BranchProbability(4, 5);
861 }
862 
863 const BasicBlock *
864 BranchProbabilityInfo::getHotSucc(const BasicBlock *BB) const {
865   auto MaxProb = BranchProbability::getZero();
866   const BasicBlock *MaxSucc = nullptr;
867 
868   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
869     const BasicBlock *Succ = *I;
870     auto Prob = getEdgeProbability(BB, Succ);
871     if (Prob > MaxProb) {
872       MaxProb = Prob;
873       MaxSucc = Succ;
874     }
875   }
876 
877   // Hot probability is at least 4/5 = 80%
878   if (MaxProb > BranchProbability(4, 5))
879     return MaxSucc;
880 
881   return nullptr;
882 }
883 
884 /// Get the raw edge probability for the edge. If can't find it, return a
885 /// default probability 1/N where N is the number of successors. Here an edge is
886 /// specified using PredBlock and an
887 /// index to the successors.
888 BranchProbability
889 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
890                                           unsigned IndexInSuccessors) const {
891   auto I = Probs.find(std::make_pair(Src, IndexInSuccessors));
892 
893   if (I != Probs.end())
894     return I->second;
895 
896   return {1, static_cast<uint32_t>(succ_size(Src))};
897 }
898 
899 BranchProbability
900 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
901                                           succ_const_iterator Dst) const {
902   return getEdgeProbability(Src, Dst.getSuccessorIndex());
903 }
904 
905 /// Get the raw edge probability calculated for the block pair. This returns the
906 /// sum of all raw edge probabilities from Src to Dst.
907 BranchProbability
908 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
909                                           const BasicBlock *Dst) const {
910   auto Prob = BranchProbability::getZero();
911   bool FoundProb = false;
912   for (succ_const_iterator I = succ_begin(Src), E = succ_end(Src); I != E; ++I)
913     if (*I == Dst) {
914       auto MapI = Probs.find(std::make_pair(Src, I.getSuccessorIndex()));
915       if (MapI != Probs.end()) {
916         FoundProb = true;
917         Prob += MapI->second;
918       }
919     }
920   uint32_t succ_num = std::distance(succ_begin(Src), succ_end(Src));
921   return FoundProb ? Prob : BranchProbability(1, succ_num);
922 }
923 
924 /// Set the edge probability for a given edge specified by PredBlock and an
925 /// index to the successors.
926 void BranchProbabilityInfo::setEdgeProbability(const BasicBlock *Src,
927                                                unsigned IndexInSuccessors,
928                                                BranchProbability Prob) {
929   Probs[std::make_pair(Src, IndexInSuccessors)] = Prob;
930   Handles.insert(BasicBlockCallbackVH(Src, this));
931   LLVM_DEBUG(dbgs() << "set edge " << Src->getName() << " -> "
932                     << IndexInSuccessors << " successor probability to " << Prob
933                     << "\n");
934 }
935 
936 raw_ostream &
937 BranchProbabilityInfo::printEdgeProbability(raw_ostream &OS,
938                                             const BasicBlock *Src,
939                                             const BasicBlock *Dst) const {
940   const BranchProbability Prob = getEdgeProbability(Src, Dst);
941   OS << "edge " << Src->getName() << " -> " << Dst->getName()
942      << " probability is " << Prob
943      << (isEdgeHot(Src, Dst) ? " [HOT edge]\n" : "\n");
944 
945   return OS;
946 }
947 
948 void BranchProbabilityInfo::eraseBlock(const BasicBlock *BB) {
949   for (auto I = Probs.begin(), E = Probs.end(); I != E; ++I) {
950     auto Key = I->first;
951     if (Key.first == BB)
952       Probs.erase(Key);
953   }
954 }
955 
956 void BranchProbabilityInfo::calculate(const Function &F, const LoopInfo &LI,
957                                       const TargetLibraryInfo *TLI) {
958   LLVM_DEBUG(dbgs() << "---- Branch Probability Info : " << F.getName()
959                     << " ----\n\n");
960   LastF = &F; // Store the last function we ran on for printing.
961   assert(PostDominatedByUnreachable.empty());
962   assert(PostDominatedByColdCall.empty());
963 
964   // Record SCC numbers of blocks in the CFG to identify irreducible loops.
965   // FIXME: We could only calculate this if the CFG is known to be irreducible
966   // (perhaps cache this info in LoopInfo if we can easily calculate it there?).
967   int SccNum = 0;
968   SccInfo SccI;
969   for (scc_iterator<const Function *> It = scc_begin(&F); !It.isAtEnd();
970        ++It, ++SccNum) {
971     // Ignore single-block SCCs since they either aren't loops or LoopInfo will
972     // catch them.
973     const std::vector<const BasicBlock *> &Scc = *It;
974     if (Scc.size() == 1)
975       continue;
976 
977     LLVM_DEBUG(dbgs() << "BPI: SCC " << SccNum << ":");
978     for (auto *BB : Scc) {
979       LLVM_DEBUG(dbgs() << " " << BB->getName());
980       SccI.SccNums[BB] = SccNum;
981     }
982     LLVM_DEBUG(dbgs() << "\n");
983   }
984 
985   // Walk the basic blocks in post-order so that we can build up state about
986   // the successors of a block iteratively.
987   for (auto BB : post_order(&F.getEntryBlock())) {
988     LLVM_DEBUG(dbgs() << "Computing probabilities for " << BB->getName()
989                       << "\n");
990     updatePostDominatedByUnreachable(BB);
991     updatePostDominatedByColdCall(BB);
992     // If there is no at least two successors, no sense to set probability.
993     if (BB->getTerminator()->getNumSuccessors() < 2)
994       continue;
995     if (calcMetadataWeights(BB))
996       continue;
997     if (calcInvokeHeuristics(BB))
998       continue;
999     if (calcUnreachableHeuristics(BB))
1000       continue;
1001     if (calcColdCallHeuristics(BB))
1002       continue;
1003     if (calcLoopBranchHeuristics(BB, LI, SccI))
1004       continue;
1005     if (calcPointerHeuristics(BB))
1006       continue;
1007     if (calcZeroHeuristics(BB, TLI))
1008       continue;
1009     if (calcFloatingPointHeuristics(BB))
1010       continue;
1011   }
1012 
1013   PostDominatedByUnreachable.clear();
1014   PostDominatedByColdCall.clear();
1015 
1016   if (PrintBranchProb &&
1017       (PrintBranchProbFuncName.empty() ||
1018        F.getName().equals(PrintBranchProbFuncName))) {
1019     print(dbgs());
1020   }
1021 }
1022 
/// Declare the analyses this pass depends on (dominator tree, loop info and
/// target library info) and state that it preserves all other analyses.
void BranchProbabilityInfoWrapperPass::getAnalysisUsage(
    AnalysisUsage &AU) const {
  // We require DT so it's available when LI is available. The LI updating code
  // asserts that DT is also present so if we don't make sure that we have DT
  // here, that assert will trigger.
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.addRequired<LoopInfoWrapperPass>();
  AU.addRequired<TargetLibraryInfoWrapperPass>();
  AU.setPreservesAll();
}
1033 
1034 bool BranchProbabilityInfoWrapperPass::runOnFunction(Function &F) {
1035   const LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
1036   const TargetLibraryInfo &TLI =
1037       getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
1038   BPI.calculate(F, LI, &TLI);
1039   return false;
1040 }
1041 
/// Release the results held by the wrapped BPI object by forwarding to its
/// releaseMemory().
void BranchProbabilityInfoWrapperPass::releaseMemory() { BPI.releaseMemory(); }
1043 
/// Print the results held by the wrapped BPI object to \p OS. The Module
/// argument is unused.
void BranchProbabilityInfoWrapperPass::print(raw_ostream &OS,
                                             const Module *) const {
  BPI.print(OS);
}
1048 
// Out-of-line definition of the static AnalysisKey member of
// BranchProbabilityAnalysis.
AnalysisKey BranchProbabilityAnalysis::Key;
1050 BranchProbabilityInfo
1051 BranchProbabilityAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
1052   BranchProbabilityInfo BPI;
1053   BPI.calculate(F, AM.getResult<LoopAnalysis>(F), &AM.getResult<TargetLibraryAnalysis>(F));
1054   return BPI;
1055 }
1056 
1057 PreservedAnalyses
1058 BranchProbabilityPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
1059   OS << "Printing analysis results of BPI for function "
1060      << "'" << F.getName() << "':"
1061      << "\n";
1062   AM.getResult<BranchProbabilityAnalysis>(F).print(OS);
1063   return PreservedAnalyses::all();
1064 }
1065