1 //===- bolt/Passes/ValidateInternalCalls.cpp ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the ValidateInternalCalls class.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "bolt/Passes/ValidateInternalCalls.h"
14 #include "bolt/Core/BinaryBasicBlock.h"
15 #include "bolt/Passes/DataflowInfoManager.h"
16 #include "bolt/Passes/FrameAnalysis.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/MC/MCInstPrinter.h"
19 #include <queue>
20
21 #define DEBUG_TYPE "bolt-internalcalls"
22
23 namespace llvm {
24 namespace bolt {
25
26 namespace {
27
28 // Helper used to extract the target basic block used in an internal call.
29 // Return nullptr if this is not an internal call target.
getInternalCallTarget(BinaryFunction & Function,const MCInst & Inst)30 BinaryBasicBlock *getInternalCallTarget(BinaryFunction &Function,
31 const MCInst &Inst) {
32 const BinaryContext &BC = Function.getBinaryContext();
33 if (!BC.MIB->isCall(Inst) || MCPlus::getNumPrimeOperands(Inst) != 1 ||
34 !Inst.getOperand(0).isExpr())
35 return nullptr;
36
37 return Function.getBasicBlockForLabel(BC.MIB->getTargetSymbol(Inst));
38 }
39
// A special StackPointerTracking that considers internal calls.
//
// An internal call pushes a return address like any call, but control stays
// inside the same function, so blocks reachable via the call see the stack
// pointer shifted by one slot. This subclass models that shift explicitly:
// internal calls decrement the tracked SP offset and returns increment it.
class StackPointerTrackingForInternalCalls
    : public StackPointerTrackingBase<StackPointerTrackingForInternalCalls> {
  friend class DataflowAnalysis<StackPointerTrackingForInternalCalls,
                                std::pair<int, int>>;

  // NOTE(review): appears unused within this class — possibly a leftover or
  // kept for a base-class convention; confirm before removing.
  Optional<unsigned> AnnotationIndex;

protected:
  // We change the starting state to only consider the first block as an
  // entry point, otherwise the analysis won't converge (there will be two valid
  // stack offsets, one for an external call and another for an internal call).
  std::pair<int, int> getStartingStateAtBB(const BinaryBasicBlock &BB) {
    // -8 at the entry block: the call that reached this function has already
    // pushed an 8-byte return address (pass only runs on x86 — see
    // runOnFunctions).
    if (&BB == &*Func.begin())
      return std::make_pair(-8, getEmpty());
    return std::make_pair(getEmpty(), getEmpty());
  }

  // Here we decrement SP for internal calls too, in addition to the regular
  // StackPointerTracking processing.
  std::pair<int, int> computeNext(const MCInst &Point,
                                  const std::pair<int, int> &Cur) {
    std::pair<int, int> Res = StackPointerTrackingBase<
        StackPointerTrackingForInternalCalls>::computeNext(Point, Cur);
    // Nothing to refine if the base analysis already lost track of SP.
    if (Res.first == StackPointerTracking::SUPERPOSITION ||
        Res.first == StackPointerTracking::EMPTY)
      return Res;

    // A return pops the 8-byte return address slot.
    if (BC.MIB->isReturn(Point)) {
      Res.first += 8;
      return Res;
    }

    BinaryBasicBlock *Target = getInternalCallTarget(Func, Point);
    if (!Target)
      return Res;

    // Internal call: account for the return address it pushes.
    Res.first -= 8;
    return Res;
  }

  // Name used by the dataflow framework to key this analysis' annotations.
  StringRef getAnnotationName() const {
    return StringRef("StackPointerTrackingForInternalCalls");
  }

public:
  StackPointerTrackingForInternalCalls(BinaryFunction &BF)
      : StackPointerTrackingBase<StackPointerTrackingForInternalCalls>(BF) {}

  // Run the underlying dataflow analysis to a fixed point.
  void run() {
    StackPointerTrackingBase<StackPointerTrackingForInternalCalls>::run();
  }
};
93
94 } // end anonymous namespace
95
fixCFGForPIC(BinaryFunction & Function) const96 void ValidateInternalCalls::fixCFGForPIC(BinaryFunction &Function) const {
97 std::queue<BinaryBasicBlock *> Work;
98 for (BinaryBasicBlock &BB : Function)
99 Work.emplace(&BB);
100
101 while (!Work.empty()) {
102 BinaryBasicBlock &BB = *Work.front();
103 Work.pop();
104
105 // Search for the next internal call.
106 const BinaryBasicBlock::iterator InternalCall =
107 llvm::find_if(BB, [&](const MCInst &Inst) {
108 return getInternalCallTarget(Function, Inst) != nullptr;
109 });
110
111 // No internal call? Done with this block.
112 if (InternalCall == BB.end())
113 continue;
114
115 BinaryBasicBlock *Target = getInternalCallTarget(Function, *InternalCall);
116 InstructionListType MovedInsts = BB.splitInstructions(&*InternalCall);
117 if (!MovedInsts.empty()) {
118 // Split this block at the call instruction.
119 std::unique_ptr<BinaryBasicBlock> NewBB = Function.createBasicBlock();
120 NewBB->setOffset(0);
121 NewBB->addInstructions(MovedInsts.begin(), MovedInsts.end());
122 BB.moveAllSuccessorsTo(NewBB.get());
123
124 Work.emplace(NewBB.get());
125 std::vector<std::unique_ptr<BinaryBasicBlock>> NewBBs;
126 NewBBs.emplace_back(std::move(NewBB));
127 Function.insertBasicBlocks(&BB, std::move(NewBBs));
128 }
129 // Update successors
130 BB.removeAllSuccessors();
131 BB.addSuccessor(Target, BB.getExecutionCount(), 0ULL);
132 }
133 }
134
// Pair internal calls with the returns they reach and add the corresponding
// return edges to the CFG. Returns true if any edge was added (callers loop
// until this reaches a fixed point).
bool ValidateInternalCalls::fixCFGForIC(BinaryFunction &Function) const {
  const BinaryContext &BC = Function.getBinaryContext();
  // Track SP value
  StackPointerTrackingForInternalCalls SPTIC(Function);
  SPTIC.run();

  // Track instructions reaching a given point of the CFG to answer
  // "There is a path from entry to point A that contains instruction B"
  ReachingInsns<false> RI(Function);
  RI.run();

  // We use the InsnToBB map that DataflowInfoManager provides us
  DataflowInfoManager Info(Function, nullptr, nullptr);

  bool Updated = false;

  // For a return in \p BB, walk every internal call that reaches it and, when
  // the SP offsets are consistent with a call/ret pair, connect \p BB to the
  // caller's fall-through (layout-next) block.
  auto processReturns = [&](BinaryBasicBlock &BB, MCInst &Return) {
    // Check all reaching internal calls
    for (auto I = RI.expr_begin(Return), E = RI.expr_end(); I != E; ++I) {
      MCInst &ReachingInst = **I;
      // Skip instructions that are not internal calls, and calls already
      // paired on a previous fixed-point iteration (tagged below).
      if (!getInternalCallTarget(Function, ReachingInst) ||
          BC.MIB->hasAnnotation(ReachingInst, getProcessedICTag()))
        continue;

      // Stack pointer matching
      int SPAtCall = SPTIC.getStateAt(ReachingInst)->first;
      int SPAtRet = SPTIC.getStateAt(Return)->first;
      // Reject only when BOTH offsets are known and they demonstrably do not
      // form a call/ret pair (ret must pop exactly the slot the call pushed).
      // Unknown (superposition) offsets fall through and are conservatively
      // connected.
      if (SPAtCall != StackPointerTracking::SUPERPOSITION &&
          SPAtRet != StackPointerTracking::SUPERPOSITION &&
          SPAtCall != SPAtRet - 8)
        continue;

      Updated = true;

      // Mark this call as processed, so we don't try to analyze it as a
      // PIC-computation internal call.
      BC.MIB->addAnnotation(ReachingInst, getProcessedICTag(), 0U);

      // Connect this block with the returning block of the caller
      BinaryBasicBlock *CallerBlock = Info.getInsnToBBMap()[&ReachingInst];
      BinaryBasicBlock *ReturnDestBlock =
          Function.getLayout().getBasicBlockAfter(CallerBlock);
      BB.addSuccessor(ReturnDestBlock, BB.getExecutionCount(), 0);
    }
  };

  // This will connect blocks terminated with RETs to their respective
  // internal caller return block. A note here: this is overly conservative
  // because in nested calls, or unrelated calls, it will create edges
  // connecting RETs to potentially unrelated internal calls. This is safe
  // and if this causes a problem to recover the stack offsets properly, we
  // will fail later.
  for (BinaryBasicBlock &BB : Function) {
    for (MCInst &Inst : BB) {
      if (!BC.MIB->isReturn(Inst))
        continue;

      processReturns(BB, Inst);
    }
  }
  return Updated;
}
197
hasTailCallsInRange(BinaryFunction & Function) const198 bool ValidateInternalCalls::hasTailCallsInRange(
199 BinaryFunction &Function) const {
200 const BinaryContext &BC = Function.getBinaryContext();
201 for (BinaryBasicBlock &BB : Function)
202 for (MCInst &Inst : BB)
203 if (BC.MIB->isTailCall(Inst))
204 return true;
205 return false;
206 }
207
// Validate that every remaining (unpaired) internal call in \p Function is a
// benign PIC trick: the call target immediately loads the pushed return
// address and every use of that value stays within the function's byte range.
// Returns false when any internal call cannot be proven safe.
bool ValidateInternalCalls::analyzeFunction(BinaryFunction &Function) const {
  // Treat internal calls as jumps, then iterate fixCFGForIC until no more
  // call/ret pairs can be connected (each pass can enable the next).
  fixCFGForPIC(Function);
  while (fixCFGForIC(Function)) {
  }

  BinaryContext &BC = Function.getBinaryContext();
  RegAnalysis RA = RegAnalysis(BC, nullptr, nullptr);
  RA.setConservativeStrategy(RegAnalysis::ConservativeStrategy::CLOBBERS_NONE);
  bool HasTailCalls = hasTailCallsInRange(Function);

  for (BinaryBasicBlock &BB : Function) {
    for (MCInst &Inst : BB) {
      BinaryBasicBlock *Target = getInternalCallTarget(Function, Inst);
      // Only internal calls not already matched to a return by fixCFGForIC
      // still need validation as PIC computations.
      if (!Target || BC.MIB->hasAnnotation(Inst, getProcessedICTag()))
        continue;

      // Mixing tail calls with internal calls is not supported — give up.
      if (HasTailCalls) {
        LLVM_DEBUG(dbgs() << Function
                          << " has tail calls and internal calls.\n");
        return false;
      }

      FrameIndexEntry FIE;
      int32_t SrcImm = 0;
      MCPhysReg Reg = 0;
      int64_t StackOffset = 0;
      bool IsIndexed = false;
      // The first instruction at the call target must be a plain load of the
      // return address the internal call just pushed (top of stack).
      MCInst *TargetInst = ProgramPoint::getFirstPointAt(*Target).getInst();
      if (!BC.MIB->isStackAccess(*TargetInst, FIE.IsLoad, FIE.IsStore,
                                 FIE.IsStoreFromReg, Reg, SrcImm,
                                 FIE.StackPtrReg, StackOffset, FIE.Size,
                                 FIE.IsSimple, IsIndexed)) {
        LLVM_DEBUG({
          dbgs() << "Frame analysis failed - not simple: " << Function << "\n";
          Function.dump();
        });
        return false;
      }
      // Must be a load from [SP + 0]; anything else is not the PIC idiom.
      if (!FIE.IsLoad || FIE.StackPtrReg != BC.MIB->getStackPointer() ||
          StackOffset != 0) {
        LLVM_DEBUG({
          dbgs() << "Target instruction does not fetch return address - not "
                    "simple: "
                 << Function << "\n";
          Function.dump();
        });
        return false;
      }
      // Now track how the return address is used by tracking uses of Reg
      ReachingDefOrUse</*Def=*/false> RU =
          ReachingDefOrUse<false>(RA, Function, Reg);
      RU.run();

      // The loaded value equals the address of the call target block; its
      // input offset is the base for range-checking every derived address.
      int64_t Offset = static_cast<int64_t>(Target->getInputOffset());
      bool UseDetected = false;
      // Walk all instructions that can observe the value loaded into Reg.
      for (auto I = RU.expr_begin(*RU.getStateBefore(*TargetInst)),
                E = RU.expr_end();
           I != E; ++I) {
        MCInst &Use = **I;
        BitVector UsedRegs = BitVector(BC.MRI->getNumRegs(), false);
        BC.MIB->getTouchedRegs(Use, UsedRegs);
        // Ignore reaching uses that never touch Reg.
        if (!UsedRegs[Reg])
          continue;
        UseDetected = true;
        int64_t Output;
        // Evaluate the use assuming Reg holds offset 0 (relative addressing);
        // Output then is the displacement the use adds to the base address.
        std::pair<MCPhysReg, int64_t> Input1 = std::make_pair(Reg, 0);
        std::pair<MCPhysReg, int64_t> Input2 = std::make_pair(0, 0);
        if (!BC.MIB->evaluateStackOffsetExpr(Use, Output, Input1, Input2)) {
          LLVM_DEBUG(dbgs() << "Evaluate stack offset expr failed.\n");
          return false;
        }
        // Reject references that escape the function's own byte range.
        if (Offset + Output < 0 ||
            Offset + Output > static_cast<int64_t>(Function.getSize())) {
          LLVM_DEBUG({
            dbgs() << "Detected out-of-range PIC reference in " << Function
                   << "\nReturn address load: ";
            BC.InstPrinter->printInst(TargetInst, 0, "", *BC.STI, dbgs());
            dbgs() << "\nUse: ";
            BC.InstPrinter->printInst(&Use, 0, "", *BC.STI, dbgs());
            dbgs() << "\n";
            Function.dump();
          });
          return false;
        }
        LLVM_DEBUG({
          dbgs() << "Validated access: ";
          BC.InstPrinter->printInst(&Use, 0, "", *BC.STI, dbgs());
          dbgs() << "\n";
        });
      }
      // A return-address load whose value is never consumed is suspicious —
      // treat it as unvalidated.
      if (!UseDetected) {
        LLVM_DEBUG(dbgs() << "No use detected.\n");
        return false;
      }
    }
  }
  return true;
}
306
runOnFunctions(BinaryContext & BC)307 void ValidateInternalCalls::runOnFunctions(BinaryContext &BC) {
308 if (!BC.isX86())
309 return;
310
311 // Look for functions that need validation. This should be pretty rare.
312 std::set<BinaryFunction *> NeedsValidation;
313 for (auto &BFI : BC.getBinaryFunctions()) {
314 BinaryFunction &Function = BFI.second;
315 for (BinaryBasicBlock &BB : Function) {
316 for (MCInst &Inst : BB) {
317 if (getInternalCallTarget(Function, Inst)) {
318 NeedsValidation.insert(&Function);
319 Function.setSimple(false);
320 break;
321 }
322 }
323 }
324 }
325
326 // Skip validation for non-relocation mode
327 if (!BC.HasRelocations)
328 return;
329
330 // Since few functions need validation, we can work with our most expensive
331 // algorithms here. Fix the CFG treating internal calls as unconditional
332 // jumps. This optimistically assumes this call is a PIC trick to get the PC
333 // value, so it is not really a call, but a jump. If we find that it's not the
334 // case, we mark this function as non-simple and stop processing it.
335 std::set<BinaryFunction *> Invalid;
336 for (BinaryFunction *Function : NeedsValidation) {
337 LLVM_DEBUG(dbgs() << "Validating " << *Function << "\n");
338 if (!analyzeFunction(*Function))
339 Invalid.insert(Function);
340 clearAnnotations(*Function);
341 }
342
343 if (!Invalid.empty()) {
344 errs() << "BOLT-WARNING: will skip the following function(s) as unsupported"
345 " internal calls were detected:\n";
346 for (BinaryFunction *Function : Invalid) {
347 errs() << " " << *Function << "\n";
348 Function->setIgnored();
349 }
350 }
351 }
352
353 } // namespace bolt
354 } // namespace llvm
355