1bac6cd5bSPaul Kirth //===--- MisExpect.cpp - Check the use of llvm.expect with PGO data -------===//
2bac6cd5bSPaul Kirth //
3bac6cd5bSPaul Kirth // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4bac6cd5bSPaul Kirth // See https://llvm.org/LICENSE.txt for license information.
5bac6cd5bSPaul Kirth // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6bac6cd5bSPaul Kirth //
7bac6cd5bSPaul Kirth //===----------------------------------------------------------------------===//
8bac6cd5bSPaul Kirth //
9bac6cd5bSPaul Kirth // This contains code to emit warnings for potentially incorrect usage of the
10bac6cd5bSPaul Kirth // llvm.expect intrinsic. This utility extracts the threshold values from
11bac6cd5bSPaul Kirth // metadata associated with the instrumented Branch or Switch instruction. The
// threshold values are then used to determine if a warning should be emitted.
13bac6cd5bSPaul Kirth //
14bac6cd5bSPaul Kirth // MisExpect's implementation relies on two assumptions about how branch weights
15bac6cd5bSPaul Kirth // are managed in LLVM.
16bac6cd5bSPaul Kirth //
17bac6cd5bSPaul Kirth // 1) Frontend profiling weights are always in place before llvm.expect is
18bac6cd5bSPaul Kirth // lowered in LowerExpectIntrinsic.cpp. Frontend based instrumentation therefore
19bac6cd5bSPaul Kirth // needs to extract the branch weights and then compare them to the weights
20bac6cd5bSPaul Kirth // being added by the llvm.expect intrinsic lowering.
21bac6cd5bSPaul Kirth //
22bac6cd5bSPaul Kirth // 2) Sampling and IR based profiles will *only* have branch weight metadata
23bac6cd5bSPaul Kirth // before profiling data is consulted if they are from a lowered llvm.expect
24bac6cd5bSPaul Kirth // intrinsic. These profiles thus always extract the expected weights and then
25bac6cd5bSPaul Kirth // compare them to the weights collected during profiling to determine if a
26bac6cd5bSPaul Kirth // diagnostic message is warranted.
27bac6cd5bSPaul Kirth //
28bac6cd5bSPaul Kirth //===----------------------------------------------------------------------===//
29bac6cd5bSPaul Kirth 
30bac6cd5bSPaul Kirth #include "llvm/Transforms/Utils/MisExpect.h"
31bac6cd5bSPaul Kirth #include "llvm/ADT/Twine.h"
32bac6cd5bSPaul Kirth #include "llvm/Analysis/OptimizationRemarkEmitter.h"
33bac6cd5bSPaul Kirth #include "llvm/IR/Constants.h"
34bac6cd5bSPaul Kirth #include "llvm/IR/DiagnosticInfo.h"
35bac6cd5bSPaul Kirth #include "llvm/IR/Instruction.h"
36bac6cd5bSPaul Kirth #include "llvm/IR/Instructions.h"
37bac6cd5bSPaul Kirth #include "llvm/IR/LLVMContext.h"
38bac6cd5bSPaul Kirth #include "llvm/Support/BranchProbability.h"
39bac6cd5bSPaul Kirth #include "llvm/Support/CommandLine.h"
40bac6cd5bSPaul Kirth #include "llvm/Support/Debug.h"
41bac6cd5bSPaul Kirth #include "llvm/Support/FormatVariadic.h"
42bac6cd5bSPaul Kirth #include <cstdint>
43bac6cd5bSPaul Kirth #include <functional>
44bac6cd5bSPaul Kirth #include <numeric>
45bac6cd5bSPaul Kirth 
46bac6cd5bSPaul Kirth #define DEBUG_TYPE "misexpect"
47bac6cd5bSPaul Kirth 
48bac6cd5bSPaul Kirth using namespace llvm;
49bac6cd5bSPaul Kirth using namespace misexpect;
50bac6cd5bSPaul Kirth 
51bac6cd5bSPaul Kirth namespace llvm {
52bac6cd5bSPaul Kirth 
53bac6cd5bSPaul Kirth // Command line option to enable/disable the warning when profile data suggests
54bac6cd5bSPaul Kirth // a mismatch with the use of the llvm.expect intrinsic
55bac6cd5bSPaul Kirth static cl::opt<bool> PGOWarnMisExpect(
56bac6cd5bSPaul Kirth     "pgo-warn-misexpect", cl::init(false), cl::Hidden,
57bac6cd5bSPaul Kirth     cl::desc("Use this option to turn on/off "
58bac6cd5bSPaul Kirth              "warnings about incorrect usage of llvm.expect intrinsics."));
59bac6cd5bSPaul Kirth 
60bac6cd5bSPaul Kirth static cl::opt<unsigned> MisExpectTolerance(
61bac6cd5bSPaul Kirth     "misexpect-tolerance", cl::init(0),
62bac6cd5bSPaul Kirth     cl::desc("Prevents emiting diagnostics when profile counts are "
63bac6cd5bSPaul Kirth              "within N% of the threshold.."));
64bac6cd5bSPaul Kirth 
65bac6cd5bSPaul Kirth } // namespace llvm
66bac6cd5bSPaul Kirth 
67bac6cd5bSPaul Kirth namespace {
68bac6cd5bSPaul Kirth 
isMisExpectDiagEnabled(LLVMContext & Ctx)69bac6cd5bSPaul Kirth bool isMisExpectDiagEnabled(LLVMContext &Ctx) {
70bac6cd5bSPaul Kirth   return PGOWarnMisExpect || Ctx.getMisExpectWarningRequested();
71bac6cd5bSPaul Kirth }
72bac6cd5bSPaul Kirth 
getMisExpectTolerance(LLVMContext & Ctx)73bac6cd5bSPaul Kirth uint64_t getMisExpectTolerance(LLVMContext &Ctx) {
74bac6cd5bSPaul Kirth   return std::max(static_cast<uint64_t>(MisExpectTolerance),
75bac6cd5bSPaul Kirth                   Ctx.getDiagnosticsMisExpectTolerance());
76bac6cd5bSPaul Kirth }
77bac6cd5bSPaul Kirth 
getInstCondition(Instruction * I)78bac6cd5bSPaul Kirth Instruction *getInstCondition(Instruction *I) {
79bac6cd5bSPaul Kirth   assert(I != nullptr && "MisExpect target Instruction cannot be nullptr");
80bac6cd5bSPaul Kirth   Instruction *Ret = nullptr;
81bac6cd5bSPaul Kirth   if (auto *B = dyn_cast<BranchInst>(I)) {
82bac6cd5bSPaul Kirth     Ret = dyn_cast<Instruction>(B->getCondition());
83bac6cd5bSPaul Kirth   }
84bac6cd5bSPaul Kirth   // TODO: Find a way to resolve condition location for switches
85bac6cd5bSPaul Kirth   // Using the condition of the switch seems to often resolve to an earlier
86bac6cd5bSPaul Kirth   // point in the program, i.e. the calculation of the switch condition, rather
87bac6cd5bSPaul Kirth   // than the switch's location in the source code. Thus, we should use the
88bac6cd5bSPaul Kirth   // instruction to get source code locations rather than the condition to
89bac6cd5bSPaul Kirth   // improve diagnostic output, such as the caret. If the same problem exists
90bac6cd5bSPaul Kirth   // for branch instructions, then we should remove this function and directly
91bac6cd5bSPaul Kirth   // use the instruction
92bac6cd5bSPaul Kirth   //
93bac6cd5bSPaul Kirth   else if (auto *S = dyn_cast<SwitchInst>(I)) {
94bac6cd5bSPaul Kirth     Ret = dyn_cast<Instruction>(S->getCondition());
95bac6cd5bSPaul Kirth   }
96bac6cd5bSPaul Kirth   return Ret ? Ret : I;
97bac6cd5bSPaul Kirth }
98bac6cd5bSPaul Kirth 
emitMisexpectDiagnostic(Instruction * I,LLVMContext & Ctx,uint64_t ProfCount,uint64_t TotalCount)99bac6cd5bSPaul Kirth void emitMisexpectDiagnostic(Instruction *I, LLVMContext &Ctx,
100bac6cd5bSPaul Kirth                              uint64_t ProfCount, uint64_t TotalCount) {
101bac6cd5bSPaul Kirth   double PercentageCorrect = (double)ProfCount / TotalCount;
102bac6cd5bSPaul Kirth   auto PerString =
103bac6cd5bSPaul Kirth       formatv("{0:P} ({1} / {2})", PercentageCorrect, ProfCount, TotalCount);
104bac6cd5bSPaul Kirth   auto RemStr = formatv(
105bac6cd5bSPaul Kirth       "Potential performance regression from use of the llvm.expect intrinsic: "
106bac6cd5bSPaul Kirth       "Annotation was correct on {0} of profiled executions.",
107bac6cd5bSPaul Kirth       PerString);
108bac6cd5bSPaul Kirth   Twine Msg(PerString);
109bac6cd5bSPaul Kirth   Instruction *Cond = getInstCondition(I);
110bac6cd5bSPaul Kirth   if (isMisExpectDiagEnabled(Ctx))
111bac6cd5bSPaul Kirth     Ctx.diagnose(DiagnosticInfoMisExpect(Cond, Msg));
112bac6cd5bSPaul Kirth   OptimizationRemarkEmitter ORE(I->getParent()->getParent());
113bac6cd5bSPaul Kirth   ORE.emit(OptimizationRemark(DEBUG_TYPE, "misexpect", Cond) << RemStr.str());
114bac6cd5bSPaul Kirth }
115bac6cd5bSPaul Kirth 
116bac6cd5bSPaul Kirth } // namespace
117bac6cd5bSPaul Kirth 
118bac6cd5bSPaul Kirth namespace llvm {
119bac6cd5bSPaul Kirth namespace misexpect {
120bac6cd5bSPaul Kirth 
121bac6cd5bSPaul Kirth // Helper function to extract branch weights into a vector
extractWeights(Instruction * I,LLVMContext & Ctx)122bac6cd5bSPaul Kirth Optional<SmallVector<uint32_t, 4>> extractWeights(Instruction *I,
123bac6cd5bSPaul Kirth                                                   LLVMContext &Ctx) {
124bac6cd5bSPaul Kirth   assert(I && "MisExpect::extractWeights given invalid pointer");
125bac6cd5bSPaul Kirth 
126bac6cd5bSPaul Kirth   auto *ProfileData = I->getMetadata(LLVMContext::MD_prof);
127bac6cd5bSPaul Kirth   if (!ProfileData)
128bac6cd5bSPaul Kirth     return None;
129bac6cd5bSPaul Kirth 
130bac6cd5bSPaul Kirth   unsigned NOps = ProfileData->getNumOperands();
131bac6cd5bSPaul Kirth   if (NOps < 3)
132bac6cd5bSPaul Kirth     return None;
133bac6cd5bSPaul Kirth 
134bac6cd5bSPaul Kirth   auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
135bac6cd5bSPaul Kirth   if (!ProfDataName || !ProfDataName->getString().equals("branch_weights"))
136bac6cd5bSPaul Kirth     return None;
137bac6cd5bSPaul Kirth 
138bac6cd5bSPaul Kirth   SmallVector<uint32_t, 4> Weights(NOps - 1);
139bac6cd5bSPaul Kirth   for (unsigned Idx = 1; Idx < NOps; Idx++) {
140bac6cd5bSPaul Kirth     ConstantInt *Value =
141bac6cd5bSPaul Kirth         mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
142bac6cd5bSPaul Kirth     uint32_t V = Value->getZExtValue();
143bac6cd5bSPaul Kirth     Weights[Idx - 1] = V;
144bac6cd5bSPaul Kirth   }
145bac6cd5bSPaul Kirth 
146bac6cd5bSPaul Kirth   return Weights;
147bac6cd5bSPaul Kirth }
148bac6cd5bSPaul Kirth 
// TODO: when clang allows c++17, use std::clamp instead
//
// Restricts \p value to the range [\p low, \p hi]. The upper bound is checked
// before the lower bound.
uint32_t clamp(uint64_t value, uint32_t low, uint32_t hi) {
  const uint64_t Bounded = value > hi ? hi : (value < low ? low : value);
  // Bounded never exceeds hi or low here, so the narrowing cast is lossless.
  return static_cast<uint32_t>(Bounded);
}
157bac6cd5bSPaul Kirth 
verifyMisExpect(Instruction & I,ArrayRef<uint32_t> RealWeights,ArrayRef<uint32_t> ExpectedWeights)158bac6cd5bSPaul Kirth void verifyMisExpect(Instruction &I, ArrayRef<uint32_t> RealWeights,
159bac6cd5bSPaul Kirth                      ArrayRef<uint32_t> ExpectedWeights) {
160bac6cd5bSPaul Kirth   // To determine if we emit a diagnostic, we need to compare the branch weights
161bac6cd5bSPaul Kirth   // from the profile to those added by the llvm.expect intrinsic.
162bac6cd5bSPaul Kirth   // So first, we extract the "likely" and "unlikely" weights from
163bac6cd5bSPaul Kirth   // ExpectedWeights And determine the correct weight in the profile to compare
164bac6cd5bSPaul Kirth   // against.
165bac6cd5bSPaul Kirth   uint64_t LikelyBranchWeight = 0,
166bac6cd5bSPaul Kirth            UnlikelyBranchWeight = std::numeric_limits<uint32_t>::max();
167bac6cd5bSPaul Kirth   size_t MaxIndex = 0;
168bac6cd5bSPaul Kirth   for (size_t Idx = 0, End = ExpectedWeights.size(); Idx < End; Idx++) {
169bac6cd5bSPaul Kirth     uint32_t V = ExpectedWeights[Idx];
170bac6cd5bSPaul Kirth     if (LikelyBranchWeight < V) {
171bac6cd5bSPaul Kirth       LikelyBranchWeight = V;
172bac6cd5bSPaul Kirth       MaxIndex = Idx;
173bac6cd5bSPaul Kirth     }
174bac6cd5bSPaul Kirth     if (UnlikelyBranchWeight > V) {
175bac6cd5bSPaul Kirth       UnlikelyBranchWeight = V;
176bac6cd5bSPaul Kirth     }
177bac6cd5bSPaul Kirth   }
178bac6cd5bSPaul Kirth 
179bac6cd5bSPaul Kirth   const uint64_t ProfiledWeight = RealWeights[MaxIndex];
180bac6cd5bSPaul Kirth   const uint64_t RealWeightsTotal =
181bac6cd5bSPaul Kirth       std::accumulate(RealWeights.begin(), RealWeights.end(), (uint64_t)0,
182bac6cd5bSPaul Kirth                       std::plus<uint64_t>());
183bac6cd5bSPaul Kirth   const uint64_t NumUnlikelyTargets = RealWeights.size() - 1;
184bac6cd5bSPaul Kirth 
185bac6cd5bSPaul Kirth   uint64_t TotalBranchWeight =
186bac6cd5bSPaul Kirth       LikelyBranchWeight + (UnlikelyBranchWeight * NumUnlikelyTargets);
187bac6cd5bSPaul Kirth 
1884683a2efSPaul Kirth   // FIXME: When we've addressed sample profiling, restore the assertion
1894683a2efSPaul Kirth   //
1904683a2efSPaul Kirth   // We cannot calculate branch probability if either of these invariants aren't
1914683a2efSPaul Kirth   // met. However, MisExpect diagnostics should not prevent code from compiling,
1924683a2efSPaul Kirth   // so we simply forgo emitting diagnostics here, and return early.
1934683a2efSPaul Kirth   if ((TotalBranchWeight == 0) || (TotalBranchWeight <= LikelyBranchWeight))
1944683a2efSPaul Kirth     return;
195bac6cd5bSPaul Kirth 
196bac6cd5bSPaul Kirth   // To determine our threshold value we need to obtain the branch probability
197bac6cd5bSPaul Kirth   // for the weights added by llvm.expect and use that proportion to calculate
198bac6cd5bSPaul Kirth   // our threshold based on the collected profile data.
199bac6cd5bSPaul Kirth   auto LikelyProbablilty = BranchProbability::getBranchProbability(
200bac6cd5bSPaul Kirth       LikelyBranchWeight, TotalBranchWeight);
201bac6cd5bSPaul Kirth 
202bac6cd5bSPaul Kirth   uint64_t ScaledThreshold = LikelyProbablilty.scale(RealWeightsTotal);
203bac6cd5bSPaul Kirth 
204bac6cd5bSPaul Kirth   // clamp tolerance range to [0, 100)
205bac6cd5bSPaul Kirth   auto Tolerance = getMisExpectTolerance(I.getContext());
206bac6cd5bSPaul Kirth   Tolerance = clamp(Tolerance, 0, 99);
207bac6cd5bSPaul Kirth 
208bac6cd5bSPaul Kirth   // Allow users to relax checking by N%  i.e., if they use a 5% tolerance,
209bac6cd5bSPaul Kirth   // then we check against 0.95*ScaledThreshold
210bac6cd5bSPaul Kirth   if (Tolerance > 0)
211bac6cd5bSPaul Kirth     ScaledThreshold *= (1.0 - Tolerance / 100.0);
212bac6cd5bSPaul Kirth 
213bac6cd5bSPaul Kirth   // When the profile weight is below the threshold, we emit the diagnostic
214bac6cd5bSPaul Kirth   if (ProfiledWeight < ScaledThreshold)
215bac6cd5bSPaul Kirth     emitMisexpectDiagnostic(&I, I.getContext(), ProfiledWeight,
216bac6cd5bSPaul Kirth                             RealWeightsTotal);
217bac6cd5bSPaul Kirth }
218bac6cd5bSPaul Kirth 
checkBackendInstrumentation(Instruction & I,const ArrayRef<uint32_t> RealWeights)219bac6cd5bSPaul Kirth void checkBackendInstrumentation(Instruction &I,
220bac6cd5bSPaul Kirth                                  const ArrayRef<uint32_t> RealWeights) {
221bac6cd5bSPaul Kirth   auto ExpectedWeightsOpt = extractWeights(&I, I.getContext());
222a7938c74SKazu Hirata   if (!ExpectedWeightsOpt)
223bac6cd5bSPaul Kirth     return;
224*611ffcf4SKazu Hirata   auto ExpectedWeights = ExpectedWeightsOpt.value();
225bac6cd5bSPaul Kirth   verifyMisExpect(I, RealWeights, ExpectedWeights);
226bac6cd5bSPaul Kirth }
227bac6cd5bSPaul Kirth 
checkFrontendInstrumentation(Instruction & I,const ArrayRef<uint32_t> ExpectedWeights)228bac6cd5bSPaul Kirth void checkFrontendInstrumentation(Instruction &I,
229bac6cd5bSPaul Kirth                                   const ArrayRef<uint32_t> ExpectedWeights) {
230bac6cd5bSPaul Kirth   auto RealWeightsOpt = extractWeights(&I, I.getContext());
231a7938c74SKazu Hirata   if (!RealWeightsOpt)
232bac6cd5bSPaul Kirth     return;
233*611ffcf4SKazu Hirata   auto RealWeights = RealWeightsOpt.value();
234bac6cd5bSPaul Kirth   verifyMisExpect(I, RealWeights, ExpectedWeights);
235bac6cd5bSPaul Kirth }
236bac6cd5bSPaul Kirth 
checkExpectAnnotations(Instruction & I,const ArrayRef<uint32_t> ExistingWeights,bool IsFrontendInstr)237bac6cd5bSPaul Kirth void checkExpectAnnotations(Instruction &I,
238bac6cd5bSPaul Kirth                             const ArrayRef<uint32_t> ExistingWeights,
239bac6cd5bSPaul Kirth                             bool IsFrontendInstr) {
240bac6cd5bSPaul Kirth   if (IsFrontendInstr) {
241bac6cd5bSPaul Kirth     checkFrontendInstrumentation(I, ExistingWeights);
242bac6cd5bSPaul Kirth   } else {
243bac6cd5bSPaul Kirth     checkBackendInstrumentation(I, ExistingWeights);
244bac6cd5bSPaul Kirth   }
245bac6cd5bSPaul Kirth }
246bac6cd5bSPaul Kirth 
247bac6cd5bSPaul Kirth } // namespace misexpect
248bac6cd5bSPaul Kirth } // namespace llvm
249bac6cd5bSPaul Kirth #undef DEBUG_TYPE
250