1 //===- bolt/Passes/ValidateInternalCalls.cpp ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the ValidateInternalCalls class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "bolt/Passes/ValidateInternalCalls.h"
14 #include "bolt/Passes/DataflowInfoManager.h"
15 #include "bolt/Passes/FrameAnalysis.h"
16 #include "llvm/MC/MCInstPrinter.h"
17 
18 #define DEBUG_TYPE "bolt-internalcalls"
19 
20 namespace llvm {
21 namespace bolt {
22 
23 namespace {
24 
25 // Helper used to extract the target basic block used in an internal call.
26 // Return nullptr if this is not an internal call target.
27 BinaryBasicBlock *getInternalCallTarget(BinaryFunction &Function,
28                                         const MCInst &Inst) {
29   const BinaryContext &BC = Function.getBinaryContext();
30   if (!BC.MIB->isCall(Inst) || MCPlus::getNumPrimeOperands(Inst) != 1 ||
31       !Inst.getOperand(0).isExpr())
32     return nullptr;
33 
34   return Function.getBasicBlockForLabel(BC.MIB->getTargetSymbol(Inst));
35 }
36 
37 // A special StackPointerTracking that considers internal calls
38 class StackPointerTrackingForInternalCalls
39     : public StackPointerTrackingBase<StackPointerTrackingForInternalCalls> {
40   friend class DataflowAnalysis<StackPointerTrackingForInternalCalls,
41                                 std::pair<int, int>>;
42 
43   Optional<unsigned> AnnotationIndex;
44 
45 protected:
46   // We change the starting state to only consider the first block as an
47   // entry point, otherwise the analysis won't converge (there will be two valid
48   // stack offsets, one for an external call and another for an internal call).
49   std::pair<int, int> getStartingStateAtBB(const BinaryBasicBlock &BB) {
50     if (&BB == &*Func.begin())
51       return std::make_pair(-8, getEmpty());
52     return std::make_pair(getEmpty(), getEmpty());
53   }
54 
55   // Here we decrement SP for internal calls too, in addition to the regular
56   // StackPointerTracking processing.
57   std::pair<int, int> computeNext(const MCInst &Point,
58                                   const std::pair<int, int> &Cur) {
59     std::pair<int, int> Res = StackPointerTrackingBase<
60         StackPointerTrackingForInternalCalls>::computeNext(Point, Cur);
61     if (Res.first == StackPointerTracking::SUPERPOSITION ||
62         Res.first == StackPointerTracking::EMPTY)
63       return Res;
64 
65     if (BC.MIB->isReturn(Point)) {
66       Res.first += 8;
67       return Res;
68     }
69 
70     BinaryBasicBlock *Target = getInternalCallTarget(Func, Point);
71     if (!Target)
72       return Res;
73 
74     Res.first -= 8;
75     return Res;
76   }
77 
78   StringRef getAnnotationName() const {
79     return StringRef("StackPointerTrackingForInternalCalls");
80   }
81 
82 public:
83   StackPointerTrackingForInternalCalls(BinaryFunction &BF)
84       : StackPointerTrackingBase<StackPointerTrackingForInternalCalls>(BF) {}
85 
86   void run() {
87     StackPointerTrackingBase<StackPointerTrackingForInternalCalls>::run();
88   }
89 };
90 
91 } // end anonymous namespace
92 
93 bool ValidateInternalCalls::fixCFGForPIC(BinaryFunction &Function) const {
94   const BinaryContext &BC = Function.getBinaryContext();
95   for (BinaryBasicBlock &BB : Function) {
96     for (auto II = BB.begin(); II != BB.end(); ++II) {
97       MCInst &Inst = *II;
98       BinaryBasicBlock *Target = getInternalCallTarget(Function, Inst);
99       if (!Target || BC.MIB->hasAnnotation(Inst, getProcessedICTag()))
100         continue;
101 
102       BC.MIB->addAnnotation(Inst, getProcessedICTag(), 0U);
103       InstructionListType MovedInsts = BB.splitInstructions(&Inst);
104       if (!MovedInsts.empty()) {
105         // Split this block at the call instruction. Create an unreachable
106         // block.
107         std::vector<std::unique_ptr<BinaryBasicBlock>> NewBBs;
108         NewBBs.emplace_back(Function.createBasicBlock(0));
109         NewBBs.back()->addInstructions(MovedInsts.begin(), MovedInsts.end());
110         BB.moveAllSuccessorsTo(NewBBs.back().get());
111         Function.insertBasicBlocks(&BB, std::move(NewBBs));
112       }
113       // Update successors
114       BB.removeAllSuccessors();
115       BB.addSuccessor(Target, BB.getExecutionCount(), 0ULL);
116       return true;
117     }
118   }
119   return false;
120 }
121 
122 bool ValidateInternalCalls::fixCFGForIC(BinaryFunction &Function) const {
123   const BinaryContext &BC = Function.getBinaryContext();
124   // Track SP value
125   StackPointerTrackingForInternalCalls SPTIC(Function);
126   SPTIC.run();
127 
128   // Track instructions reaching a given point of the CFG to answer
129   // "There is a path from entry to point A that contains instruction B"
130   ReachingInsns<false> RI(Function);
131   RI.run();
132 
133   // We use the InsnToBB map that DataflowInfoManager provides us
134   DataflowInfoManager Info(Function, nullptr, nullptr);
135 
136   bool Updated = false;
137 
138   auto processReturns = [&](BinaryBasicBlock &BB, MCInst &Return) {
139     // Check all reaching internal calls
140     for (auto I = RI.expr_begin(Return), E = RI.expr_end(); I != E; ++I) {
141       MCInst &ReachingInst = **I;
142       if (!getInternalCallTarget(Function, ReachingInst) ||
143           BC.MIB->hasAnnotation(ReachingInst, getProcessedICTag()))
144         continue;
145 
146       // Stack pointer matching
147       int SPAtCall = SPTIC.getStateAt(ReachingInst)->first;
148       int SPAtRet = SPTIC.getStateAt(Return)->first;
149       if (SPAtCall != StackPointerTracking::SUPERPOSITION &&
150           SPAtRet != StackPointerTracking::SUPERPOSITION &&
151           SPAtCall != SPAtRet - 8)
152         continue;
153 
154       Updated = true;
155 
156       // Mark this call as processed, so we don't try to analyze it as a
157       // PIC-computation internal call.
158       BC.MIB->addAnnotation(ReachingInst, getProcessedICTag(), 0U);
159 
160       // Connect this block with the returning block of the caller
161       BinaryBasicBlock *CallerBlock = Info.getInsnToBBMap()[&ReachingInst];
162       BinaryBasicBlock *ReturnDestBlock =
163           Function.getBasicBlockAfter(CallerBlock);
164       BB.addSuccessor(ReturnDestBlock, BB.getExecutionCount(), 0);
165     }
166   };
167 
168   // This will connect blocks terminated with RETs to their respective
169   // internal caller return block. A note here: this is overly conservative
170   // because in nested calls, or unrelated calls, it will create edges
171   // connecting RETs to potentially unrelated internal calls. This is safe
172   // and if this causes a problem to recover the stack offsets properly, we
173   // will fail later.
174   for (BinaryBasicBlock &BB : Function) {
175     for (MCInst &Inst : BB) {
176       if (!BC.MIB->isReturn(Inst))
177         continue;
178 
179       processReturns(BB, Inst);
180     }
181   }
182   return Updated;
183 }
184 
185 bool ValidateInternalCalls::hasTailCallsInRange(
186     BinaryFunction &Function) const {
187   const BinaryContext &BC = Function.getBinaryContext();
188   for (BinaryBasicBlock &BB : Function)
189     for (MCInst &Inst : BB)
190       if (BC.MIB->isTailCall(Inst))
191         return true;
192   return false;
193 }
194 
// Validate that every internal call in \p Function is a benign PIC idiom:
// the call target must immediately load the return address from the top of
// the stack, and every use of that address must stay within the function's
// boundaries. Returns false when the function cannot be proven safe.
bool ValidateInternalCalls::analyzeFunction(BinaryFunction &Function) const {
  // Rewrite internal calls as jumps one at a time; each rewrite invalidates
  // iterators, hence the restart loops.
  while (fixCFGForPIC(Function)) {
  }
  // Drop the "processed" tags set above so fixCFGForIC can reuse the tag.
  clearAnnotations(Function);
  // Connect returns back to the internal callers' fall-through blocks.
  while (fixCFGForIC(Function)) {
  }

  BinaryContext &BC = Function.getBinaryContext();
  // Register analysis with no call-graph info; CLOBBERS_NONE keeps calls
  // from conservatively killing the register uses tracked below.
  RegAnalysis RA = RegAnalysis(BC, nullptr, nullptr);
  RA.setConservativeStrategy(RegAnalysis::ConservativeStrategy::CLOBBERS_NONE);
  bool HasTailCalls = hasTailCallsInRange(Function);

  for (BinaryBasicBlock &BB : Function) {
    for (MCInst &Inst : BB) {
      BinaryBasicBlock *Target = getInternalCallTarget(Function, Inst);
      // Only inspect internal calls not already matched by fixCFGForIC.
      if (!Target || BC.MIB->hasAnnotation(Inst, getProcessedICTag()))
        continue;

      // Functions mixing internal calls and tail calls are unsupported.
      if (HasTailCalls) {
        LLVM_DEBUG(dbgs() << Function
                          << " has tail calls and internal calls.\n");
        return false;
      }

      FrameIndexEntry FIE;
      int32_t SrcImm = 0;
      MCPhysReg Reg = 0;
      int64_t StackOffset = 0;
      bool IsIndexed = false;
      // The first instruction at the call target must be a stack access
      // that fetches the return address pushed by the call.
      MCInst *TargetInst = ProgramPoint::getFirstPointAt(*Target).getInst();
      if (!BC.MIB->isStackAccess(*TargetInst, FIE.IsLoad, FIE.IsStore,
                                 FIE.IsStoreFromReg, Reg, SrcImm,
                                 FIE.StackPtrReg, StackOffset, FIE.Size,
                                 FIE.IsSimple, IsIndexed)) {
        LLVM_DEBUG({
          dbgs() << "Frame analysis failed - not simple: " << Function << "\n";
          Function.dump();
        });
        return false;
      }
      // It must be a load from [SP + 0], i.e. exactly the return address.
      if (!FIE.IsLoad || FIE.StackPtrReg != BC.MIB->getStackPointer() ||
          StackOffset != 0) {
        LLVM_DEBUG({
          dbgs() << "Target instruction does not fetch return address - not "
                    "simple: "
                 << Function << "\n";
          Function.dump();
        });
        return false;
      }
      // Now track how the return address is used by tracking uses of Reg
      ReachingDefOrUse</*Def=*/false> RU =
          ReachingDefOrUse<false>(RA, Function, Reg);
      RU.run();

      // Offset of the call target relative to the function start; used to
      // translate Reg-relative arithmetic into function-relative addresses.
      int64_t Offset = static_cast<int64_t>(Target->getInputOffset());
      bool UseDetected = false;
      for (auto I = RU.expr_begin(*RU.getStateBefore(*TargetInst)),
                E = RU.expr_end();
           I != E; ++I) {
        MCInst &Use = **I;
        BitVector UsedRegs = BitVector(BC.MRI->getNumRegs(), false);
        BC.MIB->getTouchedRegs(Use, UsedRegs);
        // Skip instructions that never touch the return-address register.
        if (!UsedRegs[Reg])
          continue;
        UseDetected = true;
        int64_t Output;
        // Evaluate the use as "Reg + constant"; any other form fails.
        std::pair<MCPhysReg, int64_t> Input1 = std::make_pair(Reg, 0);
        std::pair<MCPhysReg, int64_t> Input2 = std::make_pair(0, 0);
        if (!BC.MIB->evaluateStackOffsetExpr(Use, Output, Input1, Input2)) {
          LLVM_DEBUG(dbgs() << "Evaluate stack offset expr failed.\n");
          return false;
        }
        // The referenced address must land inside the function body.
        if (Offset + Output < 0 ||
            Offset + Output > static_cast<int64_t>(Function.getSize())) {
          LLVM_DEBUG({
            dbgs() << "Detected out-of-range PIC reference in " << Function
                   << "\nReturn address load: ";
            BC.InstPrinter->printInst(TargetInst, 0, "", *BC.STI, dbgs());
            dbgs() << "\nUse: ";
            BC.InstPrinter->printInst(&Use, 0, "", *BC.STI, dbgs());
            dbgs() << "\n";
            Function.dump();
          });
          return false;
        }
        LLVM_DEBUG({
          dbgs() << "Validated access: ";
          BC.InstPrinter->printInst(&Use, 0, "", *BC.STI, dbgs());
          dbgs() << "\n";
        });
      }
      // A return address that is never used means this is not the PIC
      // idiom we expect; bail out conservatively.
      if (!UseDetected) {
        LLVM_DEBUG(dbgs() << "No use detected.\n");
        return false;
      }
    }
  }
  return true;
}
295 
296 void ValidateInternalCalls::runOnFunctions(BinaryContext &BC) {
297   if (!BC.isX86())
298     return;
299 
300   // Look for functions that need validation. This should be pretty rare.
301   std::set<BinaryFunction *> NeedsValidation;
302   for (auto &BFI : BC.getBinaryFunctions()) {
303     BinaryFunction &Function = BFI.second;
304     for (BinaryBasicBlock &BB : Function) {
305       for (MCInst &Inst : BB) {
306         if (getInternalCallTarget(Function, Inst)) {
307           NeedsValidation.insert(&Function);
308           Function.setSimple(false);
309           break;
310         }
311       }
312     }
313   }
314 
315   // Skip validation for non-relocation mode
316   if (!BC.HasRelocations)
317     return;
318 
319   // Since few functions need validation, we can work with our most expensive
320   // algorithms here. Fix the CFG treating internal calls as unconditional
321   // jumps. This optimistically assumes this call is a PIC trick to get the PC
322   // value, so it is not really a call, but a jump. If we find that it's not the
323   // case, we mark this function as non-simple and stop processing it.
324   std::set<BinaryFunction *> Invalid;
325   for (BinaryFunction *Function : NeedsValidation) {
326     LLVM_DEBUG(dbgs() << "Validating " << *Function << "\n");
327     if (!analyzeFunction(*Function))
328       Invalid.insert(Function);
329     clearAnnotations(*Function);
330   }
331 
332   if (!Invalid.empty()) {
333     errs() << "BOLT-WARNING: will skip the following function(s) as unsupported"
334               " internal calls were detected:\n";
335     for (BinaryFunction *Function : Invalid) {
336       errs() << "              " << *Function << "\n";
337       Function->setIgnored();
338     }
339   }
340 }
341 
342 } // namespace bolt
343 } // namespace llvm
344