1 //===--- MisExpect.cpp - Check the use of llvm.expect with PGO data -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This contains code to emit warnings for potentially incorrect usage of the
10 // llvm.expect intrinsic. This utility extracts the threshold values from
11 // metadata associated with the instrumented Branch or Switch instruction. The
// threshold values are then used to determine if a warning should be emitted.
13 //
14 // MisExpect's implementation relies on two assumptions about how branch weights
15 // are managed in LLVM.
16 //
17 // 1) Frontend profiling weights are always in place before llvm.expect is
18 // lowered in LowerExpectIntrinsic.cpp. Frontend based instrumentation therefore
19 // needs to extract the branch weights and then compare them to the weights
20 // being added by the llvm.expect intrinsic lowering.
21 //
22 // 2) Sampling and IR based profiles will *only* have branch weight metadata
23 // before profiling data is consulted if they are from a lowered llvm.expect
24 // intrinsic. These profiles thus always extract the expected weights and then
25 // compare them to the weights collected during profiling to determine if a
26 // diagnostic message is warranted.
27 //
28 //===----------------------------------------------------------------------===//
29
30 #include "llvm/Transforms/Utils/MisExpect.h"
31 #include "llvm/ADT/Twine.h"
32 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
33 #include "llvm/IR/Constants.h"
34 #include "llvm/IR/DiagnosticInfo.h"
35 #include "llvm/IR/Instruction.h"
36 #include "llvm/IR/Instructions.h"
37 #include "llvm/IR/LLVMContext.h"
38 #include "llvm/Support/BranchProbability.h"
39 #include "llvm/Support/CommandLine.h"
40 #include "llvm/Support/Debug.h"
41 #include "llvm/Support/FormatVariadic.h"
42 #include <cstdint>
43 #include <functional>
44 #include <numeric>
45
46 #define DEBUG_TYPE "misexpect"
47
48 using namespace llvm;
49 using namespace misexpect;
50
51 namespace llvm {
52
53 // Command line option to enable/disable the warning when profile data suggests
54 // a mismatch with the use of the llvm.expect intrinsic
55 static cl::opt<bool> PGOWarnMisExpect(
56 "pgo-warn-misexpect", cl::init(false), cl::Hidden,
57 cl::desc("Use this option to turn on/off "
58 "warnings about incorrect usage of llvm.expect intrinsics."));
59
60 static cl::opt<unsigned> MisExpectTolerance(
61 "misexpect-tolerance", cl::init(0),
62 cl::desc("Prevents emiting diagnostics when profile counts are "
63 "within N% of the threshold.."));
64
65 } // namespace llvm
66
67 namespace {
68
isMisExpectDiagEnabled(LLVMContext & Ctx)69 bool isMisExpectDiagEnabled(LLVMContext &Ctx) {
70 return PGOWarnMisExpect || Ctx.getMisExpectWarningRequested();
71 }
72
getMisExpectTolerance(LLVMContext & Ctx)73 uint64_t getMisExpectTolerance(LLVMContext &Ctx) {
74 return std::max(static_cast<uint64_t>(MisExpectTolerance),
75 Ctx.getDiagnosticsMisExpectTolerance());
76 }
77
getInstCondition(Instruction * I)78 Instruction *getInstCondition(Instruction *I) {
79 assert(I != nullptr && "MisExpect target Instruction cannot be nullptr");
80 Instruction *Ret = nullptr;
81 if (auto *B = dyn_cast<BranchInst>(I)) {
82 Ret = dyn_cast<Instruction>(B->getCondition());
83 }
84 // TODO: Find a way to resolve condition location for switches
85 // Using the condition of the switch seems to often resolve to an earlier
86 // point in the program, i.e. the calculation of the switch condition, rather
87 // than the switch's location in the source code. Thus, we should use the
88 // instruction to get source code locations rather than the condition to
89 // improve diagnostic output, such as the caret. If the same problem exists
90 // for branch instructions, then we should remove this function and directly
91 // use the instruction
92 //
93 else if (auto *S = dyn_cast<SwitchInst>(I)) {
94 Ret = dyn_cast<Instruction>(S->getCondition());
95 }
96 return Ret ? Ret : I;
97 }
98
emitMisexpectDiagnostic(Instruction * I,LLVMContext & Ctx,uint64_t ProfCount,uint64_t TotalCount)99 void emitMisexpectDiagnostic(Instruction *I, LLVMContext &Ctx,
100 uint64_t ProfCount, uint64_t TotalCount) {
101 double PercentageCorrect = (double)ProfCount / TotalCount;
102 auto PerString =
103 formatv("{0:P} ({1} / {2})", PercentageCorrect, ProfCount, TotalCount);
104 auto RemStr = formatv(
105 "Potential performance regression from use of the llvm.expect intrinsic: "
106 "Annotation was correct on {0} of profiled executions.",
107 PerString);
108 Twine Msg(PerString);
109 Instruction *Cond = getInstCondition(I);
110 if (isMisExpectDiagEnabled(Ctx))
111 Ctx.diagnose(DiagnosticInfoMisExpect(Cond, Msg));
112 OptimizationRemarkEmitter ORE(I->getParent()->getParent());
113 ORE.emit(OptimizationRemark(DEBUG_TYPE, "misexpect", Cond) << RemStr.str());
114 }
115
116 } // namespace
117
118 namespace llvm {
119 namespace misexpect {
120
121 // Helper function to extract branch weights into a vector
extractWeights(Instruction * I,LLVMContext & Ctx)122 Optional<SmallVector<uint32_t, 4>> extractWeights(Instruction *I,
123 LLVMContext &Ctx) {
124 assert(I && "MisExpect::extractWeights given invalid pointer");
125
126 auto *ProfileData = I->getMetadata(LLVMContext::MD_prof);
127 if (!ProfileData)
128 return None;
129
130 unsigned NOps = ProfileData->getNumOperands();
131 if (NOps < 3)
132 return None;
133
134 auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
135 if (!ProfDataName || !ProfDataName->getString().equals("branch_weights"))
136 return None;
137
138 SmallVector<uint32_t, 4> Weights(NOps - 1);
139 for (unsigned Idx = 1; Idx < NOps; Idx++) {
140 ConstantInt *Value =
141 mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
142 uint32_t V = Value->getZExtValue();
143 Weights[Idx - 1] = V;
144 }
145
146 return Weights;
147 }
148
// TODO: when clang allows c++17, use std::clamp instead
/// Restricts \p value to the range [\p low, \p hi]. The upper bound is tested
/// first, so \p hi is returned whenever \p value exceeds it.
uint32_t clamp(uint64_t value, uint32_t low, uint32_t hi) {
  return value > hi ? hi
                    : (value < low ? low : static_cast<uint32_t>(value));
}
157
verifyMisExpect(Instruction & I,ArrayRef<uint32_t> RealWeights,ArrayRef<uint32_t> ExpectedWeights)158 void verifyMisExpect(Instruction &I, ArrayRef<uint32_t> RealWeights,
159 ArrayRef<uint32_t> ExpectedWeights) {
160 // To determine if we emit a diagnostic, we need to compare the branch weights
161 // from the profile to those added by the llvm.expect intrinsic.
162 // So first, we extract the "likely" and "unlikely" weights from
163 // ExpectedWeights And determine the correct weight in the profile to compare
164 // against.
165 uint64_t LikelyBranchWeight = 0,
166 UnlikelyBranchWeight = std::numeric_limits<uint32_t>::max();
167 size_t MaxIndex = 0;
168 for (size_t Idx = 0, End = ExpectedWeights.size(); Idx < End; Idx++) {
169 uint32_t V = ExpectedWeights[Idx];
170 if (LikelyBranchWeight < V) {
171 LikelyBranchWeight = V;
172 MaxIndex = Idx;
173 }
174 if (UnlikelyBranchWeight > V) {
175 UnlikelyBranchWeight = V;
176 }
177 }
178
179 const uint64_t ProfiledWeight = RealWeights[MaxIndex];
180 const uint64_t RealWeightsTotal =
181 std::accumulate(RealWeights.begin(), RealWeights.end(), (uint64_t)0,
182 std::plus<uint64_t>());
183 const uint64_t NumUnlikelyTargets = RealWeights.size() - 1;
184
185 uint64_t TotalBranchWeight =
186 LikelyBranchWeight + (UnlikelyBranchWeight * NumUnlikelyTargets);
187
188 // FIXME: When we've addressed sample profiling, restore the assertion
189 //
190 // We cannot calculate branch probability if either of these invariants aren't
191 // met. However, MisExpect diagnostics should not prevent code from compiling,
192 // so we simply forgo emitting diagnostics here, and return early.
193 if ((TotalBranchWeight == 0) || (TotalBranchWeight <= LikelyBranchWeight))
194 return;
195
196 // To determine our threshold value we need to obtain the branch probability
197 // for the weights added by llvm.expect and use that proportion to calculate
198 // our threshold based on the collected profile data.
199 auto LikelyProbablilty = BranchProbability::getBranchProbability(
200 LikelyBranchWeight, TotalBranchWeight);
201
202 uint64_t ScaledThreshold = LikelyProbablilty.scale(RealWeightsTotal);
203
204 // clamp tolerance range to [0, 100)
205 auto Tolerance = getMisExpectTolerance(I.getContext());
206 Tolerance = clamp(Tolerance, 0, 99);
207
208 // Allow users to relax checking by N% i.e., if they use a 5% tolerance,
209 // then we check against 0.95*ScaledThreshold
210 if (Tolerance > 0)
211 ScaledThreshold *= (1.0 - Tolerance / 100.0);
212
213 // When the profile weight is below the threshold, we emit the diagnostic
214 if (ProfiledWeight < ScaledThreshold)
215 emitMisexpectDiagnostic(&I, I.getContext(), ProfiledWeight,
216 RealWeightsTotal);
217 }
218
checkBackendInstrumentation(Instruction & I,const ArrayRef<uint32_t> RealWeights)219 void checkBackendInstrumentation(Instruction &I,
220 const ArrayRef<uint32_t> RealWeights) {
221 auto ExpectedWeightsOpt = extractWeights(&I, I.getContext());
222 if (!ExpectedWeightsOpt)
223 return;
224 auto ExpectedWeights = ExpectedWeightsOpt.value();
225 verifyMisExpect(I, RealWeights, ExpectedWeights);
226 }
227
checkFrontendInstrumentation(Instruction & I,const ArrayRef<uint32_t> ExpectedWeights)228 void checkFrontendInstrumentation(Instruction &I,
229 const ArrayRef<uint32_t> ExpectedWeights) {
230 auto RealWeightsOpt = extractWeights(&I, I.getContext());
231 if (!RealWeightsOpt)
232 return;
233 auto RealWeights = RealWeightsOpt.value();
234 verifyMisExpect(I, RealWeights, ExpectedWeights);
235 }
236
checkExpectAnnotations(Instruction & I,const ArrayRef<uint32_t> ExistingWeights,bool IsFrontendInstr)237 void checkExpectAnnotations(Instruction &I,
238 const ArrayRef<uint32_t> ExistingWeights,
239 bool IsFrontendInstr) {
240 if (IsFrontendInstr) {
241 checkFrontendInstrumentation(I, ExistingWeights);
242 } else {
243 checkBackendInstrumentation(I, ExistingWeights);
244 }
245 }
246
247 } // namespace misexpect
248 } // namespace llvm
249 #undef DEBUG_TYPE
250