//===- bolt/Passes/ValidateInternalCalls.cpp ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the ValidateInternalCalls class.
//
//===----------------------------------------------------------------------===//
12 
13 #include "bolt/Passes/ValidateInternalCalls.h"
14 #include "bolt/Core/BinaryBasicBlock.h"
15 #include "bolt/Passes/DataflowInfoManager.h"
16 #include "bolt/Passes/FrameAnalysis.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/MC/MCInstPrinter.h"
19 #include <queue>
20 
21 #define DEBUG_TYPE "bolt-internalcalls"
22 
23 namespace llvm {
24 namespace bolt {
25 
26 namespace {
27 
28 // Helper used to extract the target basic block used in an internal call.
29 // Return nullptr if this is not an internal call target.
getInternalCallTarget(BinaryFunction & Function,const MCInst & Inst)30 BinaryBasicBlock *getInternalCallTarget(BinaryFunction &Function,
31                                         const MCInst &Inst) {
32   const BinaryContext &BC = Function.getBinaryContext();
33   if (!BC.MIB->isCall(Inst) || MCPlus::getNumPrimeOperands(Inst) != 1 ||
34       !Inst.getOperand(0).isExpr())
35     return nullptr;
36 
37   return Function.getBasicBlockForLabel(BC.MIB->getTargetSymbol(Inst));
38 }
39 
40 // A special StackPointerTracking that considers internal calls
41 class StackPointerTrackingForInternalCalls
42     : public StackPointerTrackingBase<StackPointerTrackingForInternalCalls> {
43   friend class DataflowAnalysis<StackPointerTrackingForInternalCalls,
44                                 std::pair<int, int>>;
45 
46   Optional<unsigned> AnnotationIndex;
47 
48 protected:
49   // We change the starting state to only consider the first block as an
50   // entry point, otherwise the analysis won't converge (there will be two valid
51   // stack offsets, one for an external call and another for an internal call).
getStartingStateAtBB(const BinaryBasicBlock & BB)52   std::pair<int, int> getStartingStateAtBB(const BinaryBasicBlock &BB) {
53     if (&BB == &*Func.begin())
54       return std::make_pair(-8, getEmpty());
55     return std::make_pair(getEmpty(), getEmpty());
56   }
57 
58   // Here we decrement SP for internal calls too, in addition to the regular
59   // StackPointerTracking processing.
computeNext(const MCInst & Point,const std::pair<int,int> & Cur)60   std::pair<int, int> computeNext(const MCInst &Point,
61                                   const std::pair<int, int> &Cur) {
62     std::pair<int, int> Res = StackPointerTrackingBase<
63         StackPointerTrackingForInternalCalls>::computeNext(Point, Cur);
64     if (Res.first == StackPointerTracking::SUPERPOSITION ||
65         Res.first == StackPointerTracking::EMPTY)
66       return Res;
67 
68     if (BC.MIB->isReturn(Point)) {
69       Res.first += 8;
70       return Res;
71     }
72 
73     BinaryBasicBlock *Target = getInternalCallTarget(Func, Point);
74     if (!Target)
75       return Res;
76 
77     Res.first -= 8;
78     return Res;
79   }
80 
getAnnotationName() const81   StringRef getAnnotationName() const {
82     return StringRef("StackPointerTrackingForInternalCalls");
83   }
84 
85 public:
StackPointerTrackingForInternalCalls(BinaryFunction & BF)86   StackPointerTrackingForInternalCalls(BinaryFunction &BF)
87       : StackPointerTrackingBase<StackPointerTrackingForInternalCalls>(BF) {}
88 
run()89   void run() {
90     StackPointerTrackingBase<StackPointerTrackingForInternalCalls>::run();
91   }
92 };
93 
94 } // end anonymous namespace
95 
fixCFGForPIC(BinaryFunction & Function) const96 void ValidateInternalCalls::fixCFGForPIC(BinaryFunction &Function) const {
97   std::queue<BinaryBasicBlock *> Work;
98   for (BinaryBasicBlock &BB : Function)
99     Work.emplace(&BB);
100 
101   while (!Work.empty()) {
102     BinaryBasicBlock &BB = *Work.front();
103     Work.pop();
104 
105     // Search for the next internal call.
106     const BinaryBasicBlock::iterator InternalCall =
107         llvm::find_if(BB, [&](const MCInst &Inst) {
108           return getInternalCallTarget(Function, Inst) != nullptr;
109         });
110 
111     // No internal call? Done with this block.
112     if (InternalCall == BB.end())
113       continue;
114 
115     BinaryBasicBlock *Target = getInternalCallTarget(Function, *InternalCall);
116     InstructionListType MovedInsts = BB.splitInstructions(&*InternalCall);
117     if (!MovedInsts.empty()) {
118       // Split this block at the call instruction.
119       std::unique_ptr<BinaryBasicBlock> NewBB = Function.createBasicBlock();
120       NewBB->setOffset(0);
121       NewBB->addInstructions(MovedInsts.begin(), MovedInsts.end());
122       BB.moveAllSuccessorsTo(NewBB.get());
123 
124       Work.emplace(NewBB.get());
125       std::vector<std::unique_ptr<BinaryBasicBlock>> NewBBs;
126       NewBBs.emplace_back(std::move(NewBB));
127       Function.insertBasicBlocks(&BB, std::move(NewBBs));
128     }
129     // Update successors
130     BB.removeAllSuccessors();
131     BB.addSuccessor(Target, BB.getExecutionCount(), 0ULL);
132   }
133 }
134 
fixCFGForIC(BinaryFunction & Function) const135 bool ValidateInternalCalls::fixCFGForIC(BinaryFunction &Function) const {
136   const BinaryContext &BC = Function.getBinaryContext();
137   // Track SP value
138   StackPointerTrackingForInternalCalls SPTIC(Function);
139   SPTIC.run();
140 
141   // Track instructions reaching a given point of the CFG to answer
142   // "There is a path from entry to point A that contains instruction B"
143   ReachingInsns<false> RI(Function);
144   RI.run();
145 
146   // We use the InsnToBB map that DataflowInfoManager provides us
147   DataflowInfoManager Info(Function, nullptr, nullptr);
148 
149   bool Updated = false;
150 
151   auto processReturns = [&](BinaryBasicBlock &BB, MCInst &Return) {
152     // Check all reaching internal calls
153     for (auto I = RI.expr_begin(Return), E = RI.expr_end(); I != E; ++I) {
154       MCInst &ReachingInst = **I;
155       if (!getInternalCallTarget(Function, ReachingInst) ||
156           BC.MIB->hasAnnotation(ReachingInst, getProcessedICTag()))
157         continue;
158 
159       // Stack pointer matching
160       int SPAtCall = SPTIC.getStateAt(ReachingInst)->first;
161       int SPAtRet = SPTIC.getStateAt(Return)->first;
162       if (SPAtCall != StackPointerTracking::SUPERPOSITION &&
163           SPAtRet != StackPointerTracking::SUPERPOSITION &&
164           SPAtCall != SPAtRet - 8)
165         continue;
166 
167       Updated = true;
168 
169       // Mark this call as processed, so we don't try to analyze it as a
170       // PIC-computation internal call.
171       BC.MIB->addAnnotation(ReachingInst, getProcessedICTag(), 0U);
172 
173       // Connect this block with the returning block of the caller
174       BinaryBasicBlock *CallerBlock = Info.getInsnToBBMap()[&ReachingInst];
175       BinaryBasicBlock *ReturnDestBlock =
176           Function.getLayout().getBasicBlockAfter(CallerBlock);
177       BB.addSuccessor(ReturnDestBlock, BB.getExecutionCount(), 0);
178     }
179   };
180 
181   // This will connect blocks terminated with RETs to their respective
182   // internal caller return block. A note here: this is overly conservative
183   // because in nested calls, or unrelated calls, it will create edges
184   // connecting RETs to potentially unrelated internal calls. This is safe
185   // and if this causes a problem to recover the stack offsets properly, we
186   // will fail later.
187   for (BinaryBasicBlock &BB : Function) {
188     for (MCInst &Inst : BB) {
189       if (!BC.MIB->isReturn(Inst))
190         continue;
191 
192       processReturns(BB, Inst);
193     }
194   }
195   return Updated;
196 }
197 
hasTailCallsInRange(BinaryFunction & Function) const198 bool ValidateInternalCalls::hasTailCallsInRange(
199     BinaryFunction &Function) const {
200   const BinaryContext &BC = Function.getBinaryContext();
201   for (BinaryBasicBlock &BB : Function)
202     for (MCInst &Inst : BB)
203       if (BC.MIB->isTailCall(Inst))
204         return true;
205   return false;
206 }
207 
analyzeFunction(BinaryFunction & Function) const208 bool ValidateInternalCalls::analyzeFunction(BinaryFunction &Function) const {
209   fixCFGForPIC(Function);
210   while (fixCFGForIC(Function)) {
211   }
212 
213   BinaryContext &BC = Function.getBinaryContext();
214   RegAnalysis RA = RegAnalysis(BC, nullptr, nullptr);
215   RA.setConservativeStrategy(RegAnalysis::ConservativeStrategy::CLOBBERS_NONE);
216   bool HasTailCalls = hasTailCallsInRange(Function);
217 
218   for (BinaryBasicBlock &BB : Function) {
219     for (MCInst &Inst : BB) {
220       BinaryBasicBlock *Target = getInternalCallTarget(Function, Inst);
221       if (!Target || BC.MIB->hasAnnotation(Inst, getProcessedICTag()))
222         continue;
223 
224       if (HasTailCalls) {
225         LLVM_DEBUG(dbgs() << Function
226                           << " has tail calls and internal calls.\n");
227         return false;
228       }
229 
230       FrameIndexEntry FIE;
231       int32_t SrcImm = 0;
232       MCPhysReg Reg = 0;
233       int64_t StackOffset = 0;
234       bool IsIndexed = false;
235       MCInst *TargetInst = ProgramPoint::getFirstPointAt(*Target).getInst();
236       if (!BC.MIB->isStackAccess(*TargetInst, FIE.IsLoad, FIE.IsStore,
237                                  FIE.IsStoreFromReg, Reg, SrcImm,
238                                  FIE.StackPtrReg, StackOffset, FIE.Size,
239                                  FIE.IsSimple, IsIndexed)) {
240         LLVM_DEBUG({
241           dbgs() << "Frame analysis failed - not simple: " << Function << "\n";
242           Function.dump();
243         });
244         return false;
245       }
246       if (!FIE.IsLoad || FIE.StackPtrReg != BC.MIB->getStackPointer() ||
247           StackOffset != 0) {
248         LLVM_DEBUG({
249           dbgs() << "Target instruction does not fetch return address - not "
250                     "simple: "
251                  << Function << "\n";
252           Function.dump();
253         });
254         return false;
255       }
256       // Now track how the return address is used by tracking uses of Reg
257       ReachingDefOrUse</*Def=*/false> RU =
258           ReachingDefOrUse<false>(RA, Function, Reg);
259       RU.run();
260 
261       int64_t Offset = static_cast<int64_t>(Target->getInputOffset());
262       bool UseDetected = false;
263       for (auto I = RU.expr_begin(*RU.getStateBefore(*TargetInst)),
264                 E = RU.expr_end();
265            I != E; ++I) {
266         MCInst &Use = **I;
267         BitVector UsedRegs = BitVector(BC.MRI->getNumRegs(), false);
268         BC.MIB->getTouchedRegs(Use, UsedRegs);
269         if (!UsedRegs[Reg])
270           continue;
271         UseDetected = true;
272         int64_t Output;
273         std::pair<MCPhysReg, int64_t> Input1 = std::make_pair(Reg, 0);
274         std::pair<MCPhysReg, int64_t> Input2 = std::make_pair(0, 0);
275         if (!BC.MIB->evaluateStackOffsetExpr(Use, Output, Input1, Input2)) {
276           LLVM_DEBUG(dbgs() << "Evaluate stack offset expr failed.\n");
277           return false;
278         }
279         if (Offset + Output < 0 ||
280             Offset + Output > static_cast<int64_t>(Function.getSize())) {
281           LLVM_DEBUG({
282             dbgs() << "Detected out-of-range PIC reference in " << Function
283                    << "\nReturn address load: ";
284             BC.InstPrinter->printInst(TargetInst, 0, "", *BC.STI, dbgs());
285             dbgs() << "\nUse: ";
286             BC.InstPrinter->printInst(&Use, 0, "", *BC.STI, dbgs());
287             dbgs() << "\n";
288             Function.dump();
289           });
290           return false;
291         }
292         LLVM_DEBUG({
293           dbgs() << "Validated access: ";
294           BC.InstPrinter->printInst(&Use, 0, "", *BC.STI, dbgs());
295           dbgs() << "\n";
296         });
297       }
298       if (!UseDetected) {
299         LLVM_DEBUG(dbgs() << "No use detected.\n");
300         return false;
301       }
302     }
303   }
304   return true;
305 }
306 
runOnFunctions(BinaryContext & BC)307 void ValidateInternalCalls::runOnFunctions(BinaryContext &BC) {
308   if (!BC.isX86())
309     return;
310 
311   // Look for functions that need validation. This should be pretty rare.
312   std::set<BinaryFunction *> NeedsValidation;
313   for (auto &BFI : BC.getBinaryFunctions()) {
314     BinaryFunction &Function = BFI.second;
315     for (BinaryBasicBlock &BB : Function) {
316       for (MCInst &Inst : BB) {
317         if (getInternalCallTarget(Function, Inst)) {
318           NeedsValidation.insert(&Function);
319           Function.setSimple(false);
320           break;
321         }
322       }
323     }
324   }
325 
326   // Skip validation for non-relocation mode
327   if (!BC.HasRelocations)
328     return;
329 
330   // Since few functions need validation, we can work with our most expensive
331   // algorithms here. Fix the CFG treating internal calls as unconditional
332   // jumps. This optimistically assumes this call is a PIC trick to get the PC
333   // value, so it is not really a call, but a jump. If we find that it's not the
334   // case, we mark this function as non-simple and stop processing it.
335   std::set<BinaryFunction *> Invalid;
336   for (BinaryFunction *Function : NeedsValidation) {
337     LLVM_DEBUG(dbgs() << "Validating " << *Function << "\n");
338     if (!analyzeFunction(*Function))
339       Invalid.insert(Function);
340     clearAnnotations(*Function);
341   }
342 
343   if (!Invalid.empty()) {
344     errs() << "BOLT-WARNING: will skip the following function(s) as unsupported"
345               " internal calls were detected:\n";
346     for (BinaryFunction *Function : Invalid) {
347       errs() << "              " << *Function << "\n";
348       Function->setIgnored();
349     }
350   }
351 }
352 
353 } // namespace bolt
354 } // namespace llvm
355