1 //===- LowerExpectIntrinsic.cpp - Lower expect intrinsic ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass lowers the 'expect' intrinsic to LLVM metadata.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "llvm/Transforms/Scalar/LowerExpectIntrinsic.h"
14 #include "llvm/ADT/SmallVector.h"
15 #include "llvm/ADT/Statistic.h"
16 #include "llvm/ADT/iterator_range.h"
17 #include "llvm/IR/BasicBlock.h"
18 #include "llvm/IR/Constants.h"
19 #include "llvm/IR/Function.h"
20 #include "llvm/IR/Instructions.h"
21 #include "llvm/IR/Intrinsics.h"
22 #include "llvm/IR/LLVMContext.h"
23 #include "llvm/IR/MDBuilder.h"
24 #include "llvm/InitializePasses.h"
25 #include "llvm/Pass.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Transforms/Scalar.h"
28 
29 using namespace llvm;
30 
31 #define DEBUG_TYPE "lower-expect-intrinsic"
32 
33 STATISTIC(ExpectIntrinsicsHandled,
34           "Number of 'expect' intrinsic instructions handled");
35 
36 // These default values are chosen to represent an extremely skewed outcome for
37 // a condition, but they leave some room for interpretation by later passes.
38 //
39 // If the documentation for __builtin_expect() was made explicit that it should
40 // only be used in extreme cases, we could make this ratio higher. As it stands,
41 // programmers may be using __builtin_expect() / llvm.expect to annotate that a
42 // branch is likely or unlikely to be taken.
43 
44 // WARNING: these values are internal implementation detail of the pass.
45 // They should not be exposed to the outside of the pass, front-end codegen
46 // should emit @llvm.expect intrinsics instead of using these weights directly.
47 // Transforms should use TargetTransformInfo's getPredictableBranchThreshold().
48 static cl::opt<uint32_t> LikelyBranchWeight(
49     "likely-branch-weight", cl::Hidden, cl::init(2000),
50     cl::desc("Weight of the branch likely to be taken (default = 2000)"));
51 static cl::opt<uint32_t> UnlikelyBranchWeight(
52     "unlikely-branch-weight", cl::Hidden, cl::init(1),
53     cl::desc("Weight of the branch unlikely to be taken (default = 1)"));
54 
55 static std::tuple<uint32_t, uint32_t>
56 getBranchWeight(Intrinsic::ID IntrinsicID, CallInst *CI, int BranchCount) {
57   if (IntrinsicID == Intrinsic::expect) {
58     // __builtin_expect
59     return std::make_tuple(LikelyBranchWeight.getValue(),
60                            UnlikelyBranchWeight.getValue());
61   } else {
62     // __builtin_expect_with_probability
63     assert(CI->getNumOperands() >= 3 &&
64            "expect with probability must have 3 arguments");
65     auto *Confidence = cast<ConstantFP>(CI->getArgOperand(2));
66     double TrueProb = Confidence->getValueAPF().convertToDouble();
67     assert((TrueProb >= 0.0 && TrueProb <= 1.0) &&
68            "probability value must be in the range [0.0, 1.0]");
69     double FalseProb = (1.0 - TrueProb) / (BranchCount - 1);
70     uint32_t LikelyBW = ceil((TrueProb * (double)(INT32_MAX - 1)) + 1.0);
71     uint32_t UnlikelyBW = ceil((FalseProb * (double)(INT32_MAX - 1)) + 1.0);
72     return std::make_tuple(LikelyBW, UnlikelyBW);
73   }
74 }
75 
76 static bool handleSwitchExpect(SwitchInst &SI) {
77   CallInst *CI = dyn_cast<CallInst>(SI.getCondition());
78   if (!CI)
79     return false;
80 
81   Function *Fn = CI->getCalledFunction();
82   if (!Fn || (Fn->getIntrinsicID() != Intrinsic::expect &&
83               Fn->getIntrinsicID() != Intrinsic::expect_with_probability))
84     return false;
85 
86   Value *ArgValue = CI->getArgOperand(0);
87   ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(CI->getArgOperand(1));
88   if (!ExpectedValue)
89     return false;
90 
91   SwitchInst::CaseHandle Case = *SI.findCaseValue(ExpectedValue);
92   unsigned n = SI.getNumCases(); // +1 for default case.
93   uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal;
94   std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) =
95       getBranchWeight(Fn->getIntrinsicID(), CI, n + 1);
96 
97   SmallVector<uint32_t, 16> Weights(n + 1, UnlikelyBranchWeightVal);
98 
99   uint64_t Index = (Case == *SI.case_default()) ? 0 : Case.getCaseIndex() + 1;
100   Weights[Index] = LikelyBranchWeightVal;
101 
102   SI.setCondition(ArgValue);
103 
104   SI.setMetadata(LLVMContext::MD_prof,
105                  MDBuilder(CI->getContext()).createBranchWeights(Weights));
106 
107   return true;
108 }
109 
110 /// Handler for PHINodes that define the value argument to an
111 /// @llvm.expect call.
112 ///
113 /// If the operand of the phi has a constant value and it 'contradicts'
114 /// with the expected value of phi def, then the corresponding incoming
115 /// edge of the phi is unlikely to be taken. Using that information,
116 /// the branch probability info for the originating branch can be inferred.
117 static void handlePhiDef(CallInst *Expect) {
118   Value &Arg = *Expect->getArgOperand(0);
119   ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(Expect->getArgOperand(1));
120   if (!ExpectedValue)
121     return;
122   const APInt &ExpectedPhiValue = ExpectedValue->getValue();
123 
124   // Walk up in backward a list of instructions that
125   // have 'copy' semantics by 'stripping' the copies
126   // until a PHI node or an instruction of unknown kind
127   // is reached. Negation via xor is also handled.
128   //
129   //       C = PHI(...);
130   //       B = C;
131   //       A = B;
132   //       D = __builtin_expect(A, 0);
133   //
134   Value *V = &Arg;
135   SmallVector<Instruction *, 4> Operations;
136   while (!isa<PHINode>(V)) {
137     if (ZExtInst *ZExt = dyn_cast<ZExtInst>(V)) {
138       V = ZExt->getOperand(0);
139       Operations.push_back(ZExt);
140       continue;
141     }
142 
143     if (SExtInst *SExt = dyn_cast<SExtInst>(V)) {
144       V = SExt->getOperand(0);
145       Operations.push_back(SExt);
146       continue;
147     }
148 
149     BinaryOperator *BinOp = dyn_cast<BinaryOperator>(V);
150     if (!BinOp || BinOp->getOpcode() != Instruction::Xor)
151       return;
152 
153     ConstantInt *CInt = dyn_cast<ConstantInt>(BinOp->getOperand(1));
154     if (!CInt)
155       return;
156 
157     V = BinOp->getOperand(0);
158     Operations.push_back(BinOp);
159   }
160 
161   // Executes the recorded operations on input 'Value'.
162   auto ApplyOperations = [&](const APInt &Value) {
163     APInt Result = Value;
164     for (auto Op : llvm::reverse(Operations)) {
165       switch (Op->getOpcode()) {
166       case Instruction::Xor:
167         Result ^= cast<ConstantInt>(Op->getOperand(1))->getValue();
168         break;
169       case Instruction::ZExt:
170         Result = Result.zext(Op->getType()->getIntegerBitWidth());
171         break;
172       case Instruction::SExt:
173         Result = Result.sext(Op->getType()->getIntegerBitWidth());
174         break;
175       default:
176         llvm_unreachable("Unexpected operation");
177       }
178     }
179     return Result;
180   };
181 
182   auto *PhiDef = cast<PHINode>(V);
183 
184   // Get the first dominating conditional branch of the operand
185   // i's incoming block.
186   auto GetDomConditional = [&](unsigned i) -> BranchInst * {
187     BasicBlock *BB = PhiDef->getIncomingBlock(i);
188     BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
189     if (BI && BI->isConditional())
190       return BI;
191     BB = BB->getSinglePredecessor();
192     if (!BB)
193       return nullptr;
194     BI = dyn_cast<BranchInst>(BB->getTerminator());
195     if (!BI || BI->isUnconditional())
196       return nullptr;
197     return BI;
198   };
199 
200   // Now walk through all Phi operands to find phi oprerands with values
201   // conflicting with the expected phi output value. Any such operand
202   // indicates the incoming edge to that operand is unlikely.
203   for (unsigned i = 0, e = PhiDef->getNumIncomingValues(); i != e; ++i) {
204 
205     Value *PhiOpnd = PhiDef->getIncomingValue(i);
206     ConstantInt *CI = dyn_cast<ConstantInt>(PhiOpnd);
207     if (!CI)
208       continue;
209 
210     // Not an interesting case when IsUnlikely is false -- we can not infer
211     // anything useful when the operand value matches the expected phi
212     // output.
213     if (ExpectedPhiValue == ApplyOperations(CI->getValue()))
214       continue;
215 
216     BranchInst *BI = GetDomConditional(i);
217     if (!BI)
218       continue;
219 
220     MDBuilder MDB(PhiDef->getContext());
221 
222     // There are two situations in which an operand of the PhiDef comes
223     // from a given successor of a branch instruction BI.
224     // 1) When the incoming block of the operand is the successor block;
225     // 2) When the incoming block is BI's enclosing block and the
226     // successor is the PhiDef's enclosing block.
227     //
228     // Returns true if the operand which comes from OpndIncomingBB
229     // comes from outgoing edge of BI that leads to Succ block.
230     auto *OpndIncomingBB = PhiDef->getIncomingBlock(i);
231     auto IsOpndComingFromSuccessor = [&](BasicBlock *Succ) {
232       if (OpndIncomingBB == Succ)
233         // If this successor is the incoming block for this
234         // Phi operand, then this successor does lead to the Phi.
235         return true;
236       if (OpndIncomingBB == BI->getParent() && Succ == PhiDef->getParent())
237         // Otherwise, if the edge is directly from the branch
238         // to the Phi, this successor is the one feeding this
239         // Phi operand.
240         return true;
241       return false;
242     };
243     uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal;
244     std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) = getBranchWeight(
245         Expect->getCalledFunction()->getIntrinsicID(), Expect, 2);
246 
247     if (IsOpndComingFromSuccessor(BI->getSuccessor(1)))
248       BI->setMetadata(LLVMContext::MD_prof,
249                       MDB.createBranchWeights(LikelyBranchWeightVal,
250                                               UnlikelyBranchWeightVal));
251     else if (IsOpndComingFromSuccessor(BI->getSuccessor(0)))
252       BI->setMetadata(LLVMContext::MD_prof,
253                       MDB.createBranchWeights(UnlikelyBranchWeightVal,
254                                               LikelyBranchWeightVal));
255   }
256 }
257 
258 // Handle both BranchInst and SelectInst.
259 template <class BrSelInst> static bool handleBrSelExpect(BrSelInst &BSI) {
260 
261   // Handle non-optimized IR code like:
262   //   %expval = call i64 @llvm.expect.i64(i64 %conv1, i64 1)
263   //   %tobool = icmp ne i64 %expval, 0
264   //   br i1 %tobool, label %if.then, label %if.end
265   //
266   // Or the following simpler case:
267   //   %expval = call i1 @llvm.expect.i1(i1 %cmp, i1 1)
268   //   br i1 %expval, label %if.then, label %if.end
269 
270   CallInst *CI;
271 
272   ICmpInst *CmpI = dyn_cast<ICmpInst>(BSI.getCondition());
273   CmpInst::Predicate Predicate;
274   ConstantInt *CmpConstOperand = nullptr;
275   if (!CmpI) {
276     CI = dyn_cast<CallInst>(BSI.getCondition());
277     Predicate = CmpInst::ICMP_NE;
278   } else {
279     Predicate = CmpI->getPredicate();
280     if (Predicate != CmpInst::ICMP_NE && Predicate != CmpInst::ICMP_EQ)
281       return false;
282 
283     CmpConstOperand = dyn_cast<ConstantInt>(CmpI->getOperand(1));
284     if (!CmpConstOperand)
285       return false;
286     CI = dyn_cast<CallInst>(CmpI->getOperand(0));
287   }
288 
289   if (!CI)
290     return false;
291 
292   uint64_t ValueComparedTo = 0;
293   if (CmpConstOperand) {
294     if (CmpConstOperand->getBitWidth() > 64)
295       return false;
296     ValueComparedTo = CmpConstOperand->getZExtValue();
297   }
298 
299   Function *Fn = CI->getCalledFunction();
300   if (!Fn || (Fn->getIntrinsicID() != Intrinsic::expect &&
301               Fn->getIntrinsicID() != Intrinsic::expect_with_probability))
302     return false;
303 
304   Value *ArgValue = CI->getArgOperand(0);
305   ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(CI->getArgOperand(1));
306   if (!ExpectedValue)
307     return false;
308 
309   MDBuilder MDB(CI->getContext());
310   MDNode *Node;
311 
312   uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal;
313   std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) =
314       getBranchWeight(Fn->getIntrinsicID(), CI, 2);
315 
316   if ((ExpectedValue->getZExtValue() == ValueComparedTo) ==
317       (Predicate == CmpInst::ICMP_EQ)) {
318     Node =
319         MDB.createBranchWeights(LikelyBranchWeightVal, UnlikelyBranchWeightVal);
320   } else {
321     Node =
322         MDB.createBranchWeights(UnlikelyBranchWeightVal, LikelyBranchWeightVal);
323   }
324 
325   if (CmpI)
326     CmpI->setOperand(0, ArgValue);
327   else
328     BSI.setCondition(ArgValue);
329 
330   BSI.setMetadata(LLVMContext::MD_prof, Node);
331 
332   return true;
333 }
334 
335 static bool handleBranchExpect(BranchInst &BI) {
336   if (BI.isUnconditional())
337     return false;
338 
339   return handleBrSelExpect<BranchInst>(BI);
340 }
341 
342 static bool lowerExpectIntrinsic(Function &F) {
343   bool Changed = false;
344 
345   for (BasicBlock &BB : F) {
346     // Create "block_weights" metadata.
347     if (BranchInst *BI = dyn_cast<BranchInst>(BB.getTerminator())) {
348       if (handleBranchExpect(*BI))
349         ExpectIntrinsicsHandled++;
350     } else if (SwitchInst *SI = dyn_cast<SwitchInst>(BB.getTerminator())) {
351       if (handleSwitchExpect(*SI))
352         ExpectIntrinsicsHandled++;
353     }
354 
355     // Remove llvm.expect intrinsics. Iterate backwards in order
356     // to process select instructions before the intrinsic gets
357     // removed.
358     for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(BB))) {
359       CallInst *CI = dyn_cast<CallInst>(&Inst);
360       if (!CI) {
361         if (SelectInst *SI = dyn_cast<SelectInst>(&Inst)) {
362           if (handleBrSelExpect(*SI))
363             ExpectIntrinsicsHandled++;
364         }
365         continue;
366       }
367 
368       Function *Fn = CI->getCalledFunction();
369       if (Fn && (Fn->getIntrinsicID() == Intrinsic::expect ||
370                  Fn->getIntrinsicID() == Intrinsic::expect_with_probability)) {
371         // Before erasing the llvm.expect, walk backward to find
372         // phi that define llvm.expect's first arg, and
373         // infer branch probability:
374         handlePhiDef(CI);
375         Value *Exp = CI->getArgOperand(0);
376         CI->replaceAllUsesWith(Exp);
377         CI->eraseFromParent();
378         Changed = true;
379       }
380     }
381   }
382 
383   return Changed;
384 }
385 
386 PreservedAnalyses LowerExpectIntrinsicPass::run(Function &F,
387                                                 FunctionAnalysisManager &) {
388   if (lowerExpectIntrinsic(F))
389     return PreservedAnalyses::none();
390 
391   return PreservedAnalyses::all();
392 }
393 
394 namespace {
395 /// Legacy pass for lowering expect intrinsics out of the IR.
396 ///
397 /// When this pass is run over a function it uses expect intrinsics which feed
398 /// branches and switches to provide branch weight metadata for those
399 /// terminators. It then removes the expect intrinsics from the IR so the rest
400 /// of the optimizer can ignore them.
401 class LowerExpectIntrinsic : public FunctionPass {
402 public:
403   static char ID;
404   LowerExpectIntrinsic() : FunctionPass(ID) {
405     initializeLowerExpectIntrinsicPass(*PassRegistry::getPassRegistry());
406   }
407 
408   bool runOnFunction(Function &F) override { return lowerExpectIntrinsic(F); }
409 };
410 }
411 
412 char LowerExpectIntrinsic::ID = 0;
413 INITIALIZE_PASS(LowerExpectIntrinsic, "lower-expect",
414                 "Lower 'expect' Intrinsics", false, false)
415 
416 FunctionPass *llvm::createLowerExpectIntrinsicPass() {
417   return new LowerExpectIntrinsic();
418 }
419