1 //===- bolt/Passes/ValidateInternalCalls.cpp ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the ValidateInternalCalls class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "bolt/Passes/ValidateInternalCalls.h"
14 #include "bolt/Passes/DataflowInfoManager.h"
15 #include "bolt/Passes/FrameAnalysis.h"
16 #include "llvm/MC/MCInstPrinter.h"
17 
18 #define DEBUG_TYPE "bolt-internalcalls"
19 
20 namespace llvm {
21 namespace bolt {
22 
23 namespace {
24 
25 // Helper used to extract the target basic block used in an internal call.
26 // Return nullptr if this is not an internal call target.
27 BinaryBasicBlock *getInternalCallTarget(BinaryFunction &Function,
28                                         const MCInst &Inst) {
29   const BinaryContext &BC = Function.getBinaryContext();
30   if (!BC.MIB->isCall(Inst) || MCPlus::getNumPrimeOperands(Inst) != 1 ||
31       !Inst.getOperand(0).isExpr())
32     return nullptr;
33 
34   return Function.getBasicBlockForLabel(BC.MIB->getTargetSymbol(Inst));
35 }
36 
37 // A special StackPointerTracking that considers internal calls
38 class StackPointerTrackingForInternalCalls
39     : public StackPointerTrackingBase<StackPointerTrackingForInternalCalls> {
40   friend class DataflowAnalysis<StackPointerTrackingForInternalCalls,
41                                 std::pair<int, int>>;
42 
43   Optional<unsigned> AnnotationIndex;
44 
45 protected:
46   // We change the starting state to only consider the first block as an
47   // entry point, otherwise the analysis won't converge (there will be two valid
48   // stack offsets, one for an external call and another for an internal call).
49   std::pair<int, int> getStartingStateAtBB(const BinaryBasicBlock &BB) {
50     if (&BB == &*Func.begin())
51       return std::make_pair(-8, getEmpty());
52     return std::make_pair(getEmpty(), getEmpty());
53   }
54 
55   // Here we decrement SP for internal calls too, in addition to the regular
56   // StackPointerTracking processing.
57   std::pair<int, int> computeNext(const MCInst &Point,
58                                   const std::pair<int, int> &Cur) {
59     std::pair<int, int> Res = StackPointerTrackingBase<
60         StackPointerTrackingForInternalCalls>::computeNext(Point, Cur);
61     if (Res.first == StackPointerTracking::SUPERPOSITION ||
62         Res.first == StackPointerTracking::EMPTY)
63       return Res;
64 
65     if (BC.MIB->isReturn(Point)) {
66       Res.first += 8;
67       return Res;
68     }
69 
70     BinaryBasicBlock *Target = getInternalCallTarget(Func, Point);
71     if (!Target)
72       return Res;
73 
74     Res.first -= 8;
75     return Res;
76   }
77 
78   StringRef getAnnotationName() const {
79     return StringRef("StackPointerTrackingForInternalCalls");
80   }
81 
82 public:
83   StackPointerTrackingForInternalCalls(BinaryFunction &BF)
84       : StackPointerTrackingBase<StackPointerTrackingForInternalCalls>(BF) {}
85 
86   void run() {
87     StackPointerTrackingBase<StackPointerTrackingForInternalCalls>::run();
88   }
89 };
90 
91 } // end anonymous namespace
92 
93 bool ValidateInternalCalls::fixCFGForPIC(BinaryFunction &Function) const {
94   const BinaryContext &BC = Function.getBinaryContext();
95   for (BinaryBasicBlock &BB : Function) {
96     for (auto II = BB.begin(); II != BB.end(); ++II) {
97       MCInst &Inst = *II;
98       BinaryBasicBlock *Target = getInternalCallTarget(Function, Inst);
99       if (!Target || BC.MIB->hasAnnotation(Inst, getProcessedICTag()))
100         continue;
101 
102       BC.MIB->addAnnotation(Inst, getProcessedICTag(), 0U);
103       InstructionListType MovedInsts = BB.splitInstructions(&Inst);
104       if (!MovedInsts.empty()) {
105         // Split this block at the call instruction. Create an unreachable
106         // block.
107         std::vector<std::unique_ptr<BinaryBasicBlock>> NewBBs;
108         NewBBs.emplace_back(Function.createBasicBlock());
109         NewBBs.back()->setOffset(0);
110         NewBBs.back()->addInstructions(MovedInsts.begin(), MovedInsts.end());
111         BB.moveAllSuccessorsTo(NewBBs.back().get());
112         Function.insertBasicBlocks(&BB, std::move(NewBBs));
113       }
114       // Update successors
115       BB.removeAllSuccessors();
116       BB.addSuccessor(Target, BB.getExecutionCount(), 0ULL);
117       return true;
118     }
119   }
120   return false;
121 }
122 
123 bool ValidateInternalCalls::fixCFGForIC(BinaryFunction &Function) const {
124   const BinaryContext &BC = Function.getBinaryContext();
125   // Track SP value
126   StackPointerTrackingForInternalCalls SPTIC(Function);
127   SPTIC.run();
128 
129   // Track instructions reaching a given point of the CFG to answer
130   // "There is a path from entry to point A that contains instruction B"
131   ReachingInsns<false> RI(Function);
132   RI.run();
133 
134   // We use the InsnToBB map that DataflowInfoManager provides us
135   DataflowInfoManager Info(Function, nullptr, nullptr);
136 
137   bool Updated = false;
138 
139   auto processReturns = [&](BinaryBasicBlock &BB, MCInst &Return) {
140     // Check all reaching internal calls
141     for (auto I = RI.expr_begin(Return), E = RI.expr_end(); I != E; ++I) {
142       MCInst &ReachingInst = **I;
143       if (!getInternalCallTarget(Function, ReachingInst) ||
144           BC.MIB->hasAnnotation(ReachingInst, getProcessedICTag()))
145         continue;
146 
147       // Stack pointer matching
148       int SPAtCall = SPTIC.getStateAt(ReachingInst)->first;
149       int SPAtRet = SPTIC.getStateAt(Return)->first;
150       if (SPAtCall != StackPointerTracking::SUPERPOSITION &&
151           SPAtRet != StackPointerTracking::SUPERPOSITION &&
152           SPAtCall != SPAtRet - 8)
153         continue;
154 
155       Updated = true;
156 
157       // Mark this call as processed, so we don't try to analyze it as a
158       // PIC-computation internal call.
159       BC.MIB->addAnnotation(ReachingInst, getProcessedICTag(), 0U);
160 
161       // Connect this block with the returning block of the caller
162       BinaryBasicBlock *CallerBlock = Info.getInsnToBBMap()[&ReachingInst];
163       BinaryBasicBlock *ReturnDestBlock =
164           Function.getLayout().getBasicBlockAfter(CallerBlock);
165       BB.addSuccessor(ReturnDestBlock, BB.getExecutionCount(), 0);
166     }
167   };
168 
169   // This will connect blocks terminated with RETs to their respective
170   // internal caller return block. A note here: this is overly conservative
171   // because in nested calls, or unrelated calls, it will create edges
172   // connecting RETs to potentially unrelated internal calls. This is safe
173   // and if this causes a problem to recover the stack offsets properly, we
174   // will fail later.
175   for (BinaryBasicBlock &BB : Function) {
176     for (MCInst &Inst : BB) {
177       if (!BC.MIB->isReturn(Inst))
178         continue;
179 
180       processReturns(BB, Inst);
181     }
182   }
183   return Updated;
184 }
185 
186 bool ValidateInternalCalls::hasTailCallsInRange(
187     BinaryFunction &Function) const {
188   const BinaryContext &BC = Function.getBinaryContext();
189   for (BinaryBasicBlock &BB : Function)
190     for (MCInst &Inst : BB)
191       if (BC.MIB->isTailCall(Inst))
192         return true;
193   return false;
194 }
195 
196 bool ValidateInternalCalls::analyzeFunction(BinaryFunction &Function) const {
197   while (fixCFGForPIC(Function)) {
198   }
199   clearAnnotations(Function);
200   while (fixCFGForIC(Function)) {
201   }
202 
203   BinaryContext &BC = Function.getBinaryContext();
204   RegAnalysis RA = RegAnalysis(BC, nullptr, nullptr);
205   RA.setConservativeStrategy(RegAnalysis::ConservativeStrategy::CLOBBERS_NONE);
206   bool HasTailCalls = hasTailCallsInRange(Function);
207 
208   for (BinaryBasicBlock &BB : Function) {
209     for (MCInst &Inst : BB) {
210       BinaryBasicBlock *Target = getInternalCallTarget(Function, Inst);
211       if (!Target || BC.MIB->hasAnnotation(Inst, getProcessedICTag()))
212         continue;
213 
214       if (HasTailCalls) {
215         LLVM_DEBUG(dbgs() << Function
216                           << " has tail calls and internal calls.\n");
217         return false;
218       }
219 
220       FrameIndexEntry FIE;
221       int32_t SrcImm = 0;
222       MCPhysReg Reg = 0;
223       int64_t StackOffset = 0;
224       bool IsIndexed = false;
225       MCInst *TargetInst = ProgramPoint::getFirstPointAt(*Target).getInst();
226       if (!BC.MIB->isStackAccess(*TargetInst, FIE.IsLoad, FIE.IsStore,
227                                  FIE.IsStoreFromReg, Reg, SrcImm,
228                                  FIE.StackPtrReg, StackOffset, FIE.Size,
229                                  FIE.IsSimple, IsIndexed)) {
230         LLVM_DEBUG({
231           dbgs() << "Frame analysis failed - not simple: " << Function << "\n";
232           Function.dump();
233         });
234         return false;
235       }
236       if (!FIE.IsLoad || FIE.StackPtrReg != BC.MIB->getStackPointer() ||
237           StackOffset != 0) {
238         LLVM_DEBUG({
239           dbgs() << "Target instruction does not fetch return address - not "
240                     "simple: "
241                  << Function << "\n";
242           Function.dump();
243         });
244         return false;
245       }
246       // Now track how the return address is used by tracking uses of Reg
247       ReachingDefOrUse</*Def=*/false> RU =
248           ReachingDefOrUse<false>(RA, Function, Reg);
249       RU.run();
250 
251       int64_t Offset = static_cast<int64_t>(Target->getInputOffset());
252       bool UseDetected = false;
253       for (auto I = RU.expr_begin(*RU.getStateBefore(*TargetInst)),
254                 E = RU.expr_end();
255            I != E; ++I) {
256         MCInst &Use = **I;
257         BitVector UsedRegs = BitVector(BC.MRI->getNumRegs(), false);
258         BC.MIB->getTouchedRegs(Use, UsedRegs);
259         if (!UsedRegs[Reg])
260           continue;
261         UseDetected = true;
262         int64_t Output;
263         std::pair<MCPhysReg, int64_t> Input1 = std::make_pair(Reg, 0);
264         std::pair<MCPhysReg, int64_t> Input2 = std::make_pair(0, 0);
265         if (!BC.MIB->evaluateStackOffsetExpr(Use, Output, Input1, Input2)) {
266           LLVM_DEBUG(dbgs() << "Evaluate stack offset expr failed.\n");
267           return false;
268         }
269         if (Offset + Output < 0 ||
270             Offset + Output > static_cast<int64_t>(Function.getSize())) {
271           LLVM_DEBUG({
272             dbgs() << "Detected out-of-range PIC reference in " << Function
273                    << "\nReturn address load: ";
274             BC.InstPrinter->printInst(TargetInst, 0, "", *BC.STI, dbgs());
275             dbgs() << "\nUse: ";
276             BC.InstPrinter->printInst(&Use, 0, "", *BC.STI, dbgs());
277             dbgs() << "\n";
278             Function.dump();
279           });
280           return false;
281         }
282         LLVM_DEBUG({
283           dbgs() << "Validated access: ";
284           BC.InstPrinter->printInst(&Use, 0, "", *BC.STI, dbgs());
285           dbgs() << "\n";
286         });
287       }
288       if (!UseDetected) {
289         LLVM_DEBUG(dbgs() << "No use detected.\n");
290         return false;
291       }
292     }
293   }
294   return true;
295 }
296 
297 void ValidateInternalCalls::runOnFunctions(BinaryContext &BC) {
298   if (!BC.isX86())
299     return;
300 
301   // Look for functions that need validation. This should be pretty rare.
302   std::set<BinaryFunction *> NeedsValidation;
303   for (auto &BFI : BC.getBinaryFunctions()) {
304     BinaryFunction &Function = BFI.second;
305     for (BinaryBasicBlock &BB : Function) {
306       for (MCInst &Inst : BB) {
307         if (getInternalCallTarget(Function, Inst)) {
308           NeedsValidation.insert(&Function);
309           Function.setSimple(false);
310           break;
311         }
312       }
313     }
314   }
315 
316   // Skip validation for non-relocation mode
317   if (!BC.HasRelocations)
318     return;
319 
320   // Since few functions need validation, we can work with our most expensive
321   // algorithms here. Fix the CFG treating internal calls as unconditional
322   // jumps. This optimistically assumes this call is a PIC trick to get the PC
323   // value, so it is not really a call, but a jump. If we find that it's not the
324   // case, we mark this function as non-simple and stop processing it.
325   std::set<BinaryFunction *> Invalid;
326   for (BinaryFunction *Function : NeedsValidation) {
327     LLVM_DEBUG(dbgs() << "Validating " << *Function << "\n");
328     if (!analyzeFunction(*Function))
329       Invalid.insert(Function);
330     clearAnnotations(*Function);
331   }
332 
333   if (!Invalid.empty()) {
334     errs() << "BOLT-WARNING: will skip the following function(s) as unsupported"
335               " internal calls were detected:\n";
336     for (BinaryFunction *Function : Invalid) {
337       errs() << "              " << *Function << "\n";
338       Function->setIgnored();
339     }
340   }
341 }
342 
343 } // namespace bolt
344 } // namespace llvm
345