1 //===- bolt/Passes/RetpolineInsertion.cpp ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements RetpolineInsertion class, which replaces indirect
10 // branches (calls and jumps) with calls to retpolines to protect against branch
11 // target injection attacks.
// A unique retpoline is created for each register that holds the callee
// address. If the callee address is in memory, %r11 is used (when available)
// to hold the callee address before the retpoline is called; otherwise, an
// address-pattern-specific retpoline is called, and the callee address is
// loaded inside the retpoline itself.
// Use the r11-availability option to control when %r11 is assumed to be
// available; by default, %r11 is assumed not to be available.
// An lfence instruction is added to the speculation loop of each retpoline by
// default; this can be controlled with the retpoline-lfence option.
21 //
22 //===----------------------------------------------------------------------===//
23 
24 #include "bolt/Passes/RetpolineInsertion.h"
25 #include "llvm/Support/raw_ostream.h"
26 
27 #define DEBUG_TYPE "bolt-retpoline"
28 
29 using namespace llvm;
30 using namespace bolt;
31 namespace opts {
32 
33 extern cl::OptionCategory BoltCategory;
34 
35 llvm::cl::opt<bool> InsertRetpolines("insert-retpolines",
36                                      cl::desc("run retpoline insertion pass"),
37                                      cl::cat(BoltCategory));
38 
39 llvm::cl::opt<bool>
40 RetpolineLfence("retpoline-lfence",
41   cl::desc("determine if lfence instruction should exist in the retpoline"),
42   cl::init(true),
43   cl::ZeroOrMore,
44   cl::Hidden,
45   cl::cat(BoltCategory));
46 
47 cl::opt<RetpolineInsertion::AvailabilityOptions>
48 R11Availability("r11-availability",
49   cl::desc("determine the availablity of r11 before indirect branches"),
50   cl::init(RetpolineInsertion::AvailabilityOptions::NEVER),
51   cl::values(
52     clEnumValN(RetpolineInsertion::AvailabilityOptions::NEVER,
53       "never", "r11 not available"),
54     clEnumValN(RetpolineInsertion::AvailabilityOptions::ALWAYS,
55       "always", "r11 avaialable before calls and jumps"),
56     clEnumValN(RetpolineInsertion::AvailabilityOptions::ABI,
57       "abi", "r11 avaialable before calls but not before jumps")),
58   cl::ZeroOrMore,
59   cl::cat(BoltCategory));
60 
61 } // namespace opts
62 
63 namespace llvm {
64 namespace bolt {
65 
66 // Retpoline function structure:
67 // BB0: call BB2
68 // BB1: pause
69 //      lfence
70 //      jmp BB1
71 // BB2: mov %reg, (%rsp)
72 //      ret
73 // or
74 // BB2: push %r11
75 //      mov Address, %r11
76 //      mov %r11, 8(%rsp)
77 //      pop %r11
78 //      ret
79 BinaryFunction *createNewRetpoline(BinaryContext &BC,
80                                    const std::string &RetpolineTag,
81                                    const IndirectBranchInfo &BrInfo,
82                                    bool R11Available) {
83   auto &MIB = *BC.MIB;
84   MCContext &Ctx = *BC.Ctx.get();
85   LLVM_DEBUG(dbgs() << "BOLT-DEBUG: Creating a new retpoline function["
86                     << RetpolineTag << "]\n");
87 
88   BinaryFunction *NewRetpoline =
89       BC.createInjectedBinaryFunction(RetpolineTag, true);
90   std::vector<std::unique_ptr<BinaryBasicBlock>> NewBlocks(3);
91   for (int I = 0; I < 3; I++) {
92     MCSymbol *Symbol =
93         Ctx.createNamedTempSymbol(Twine(RetpolineTag + "_BB" + to_string(I)));
94     NewBlocks[I] = NewRetpoline->createBasicBlock(
95         BinaryBasicBlock::INVALID_OFFSET, Symbol);
96     NewBlocks[I].get()->setCFIState(0);
97   }
98 
99   BinaryBasicBlock &BB0 = *NewBlocks[0].get();
100   BinaryBasicBlock &BB1 = *NewBlocks[1].get();
101   BinaryBasicBlock &BB2 = *NewBlocks[2].get();
102 
103   BB0.addSuccessor(&BB2, 0, 0);
104   BB1.addSuccessor(&BB1, 0, 0);
105 
106   // Build BB0
107   MCInst DirectCall;
108   MIB.createDirectCall(DirectCall, BB2.getLabel(), &Ctx, /*IsTailCall*/ false);
109   BB0.addInstruction(DirectCall);
110 
111   // Build BB1
112   MCInst Pause;
113   MIB.createPause(Pause);
114   BB1.addInstruction(Pause);
115 
116   if (opts::RetpolineLfence) {
117     MCInst Lfence;
118     MIB.createLfence(Lfence);
119     BB1.addInstruction(Lfence);
120   }
121 
122   InstructionListType Seq;
123   MIB.createShortJmp(Seq, BB1.getLabel(), &Ctx);
124   BB1.addInstructions(Seq.begin(), Seq.end());
125 
126   // Build BB2
127   if (BrInfo.isMem()) {
128     if (R11Available) {
129       MCInst StoreToStack;
130       MIB.createSaveToStack(StoreToStack, MIB.getStackPointer(), 0,
131                             MIB.getX86R11(), 8);
132       BB2.addInstruction(StoreToStack);
133     } else {
134       MCInst PushR11;
135       MIB.createPushRegister(PushR11, MIB.getX86R11(), 8);
136       BB2.addInstruction(PushR11);
137 
138       MCInst LoadCalleeAddrs;
139       const IndirectBranchInfo::MemOpInfo &MemRef = BrInfo.Memory;
140       MIB.createLoad(LoadCalleeAddrs, MemRef.BaseRegNum, MemRef.ScaleValue,
141                      MemRef.IndexRegNum, MemRef.DispValue, MemRef.DispExpr,
142                      MemRef.SegRegNum, MIB.getX86R11(), 8);
143 
144       BB2.addInstruction(LoadCalleeAddrs);
145 
146       MCInst StoreToStack;
147       MIB.createSaveToStack(StoreToStack, MIB.getStackPointer(), 8,
148                             MIB.getX86R11(), 8);
149       BB2.addInstruction(StoreToStack);
150 
151       MCInst PopR11;
152       MIB.createPopRegister(PopR11, MIB.getX86R11(), 8);
153       BB2.addInstruction(PopR11);
154     }
155   } else if (BrInfo.isReg()) {
156     MCInst StoreToStack;
157     MIB.createSaveToStack(StoreToStack, MIB.getStackPointer(), 0,
158                           BrInfo.BranchReg, 8);
159     BB2.addInstruction(StoreToStack);
160   } else {
161     llvm_unreachable("not expected");
162   }
163 
164   // return
165   MCInst Return;
166   MIB.createReturn(Return);
167   BB2.addInstruction(Return);
168   NewRetpoline->insertBasicBlocks(nullptr, std::move(NewBlocks),
169                                   /* UpdateLayout */ true,
170                                   /* UpdateCFIState */ false);
171 
172   NewRetpoline->updateState(BinaryFunction::State::CFG_Finalized);
173   return NewRetpoline;
174 }
175 
176 std::string createRetpolineFunctionTag(BinaryContext &BC,
177                                        const IndirectBranchInfo &BrInfo,
178                                        bool R11Available) {
179   if (BrInfo.isReg())
180     return "__retpoline_r" + to_string(BrInfo.BranchReg) + "_";
181 
182   // Memory Branch
183   if (R11Available)
184     return "__retpoline_r11";
185 
186   std::string Tag = "__retpoline_mem_";
187 
188   const IndirectBranchInfo::MemOpInfo &MemRef = BrInfo.Memory;
189 
190   std::string DispExprStr;
191   if (MemRef.DispExpr) {
192     llvm::raw_string_ostream Ostream(DispExprStr);
193     MemRef.DispExpr->print(Ostream, BC.AsmInfo.get());
194     Ostream.flush();
195   }
196 
197   Tag += MemRef.BaseRegNum != BC.MIB->getNoRegister()
198              ? "r" + to_string(MemRef.BaseRegNum)
199              : "";
200 
201   Tag +=
202       MemRef.DispExpr ? "+" + DispExprStr : "+" + to_string(MemRef.DispValue);
203 
204   Tag += MemRef.IndexRegNum != BC.MIB->getNoRegister()
205              ? "+" + to_string(MemRef.ScaleValue) + "*" +
206                    to_string(MemRef.IndexRegNum)
207              : "";
208 
209   Tag += MemRef.SegRegNum != BC.MIB->getNoRegister()
210              ? "_seg_" + to_string(MemRef.SegRegNum)
211              : "";
212 
213   return Tag;
214 }
215 
216 BinaryFunction *RetpolineInsertion::getOrCreateRetpoline(
217     BinaryContext &BC, const IndirectBranchInfo &BrInfo, bool R11Available) {
218   const std::string RetpolineTag =
219       createRetpolineFunctionTag(BC, BrInfo, R11Available);
220 
221   if (CreatedRetpolines.count(RetpolineTag))
222     return CreatedRetpolines[RetpolineTag];
223 
224   return CreatedRetpolines[RetpolineTag] =
225              createNewRetpoline(BC, RetpolineTag, BrInfo, R11Available);
226 }
227 
228 void createBranchReplacement(BinaryContext &BC,
229                              const IndirectBranchInfo &BrInfo,
230                              bool R11Available,
231                              InstructionListType &Replacement,
232                              const MCSymbol *RetpolineSymbol) {
233   auto &MIB = *BC.MIB;
234   // Load the branch address in r11 if available
235   if (BrInfo.isMem() && R11Available) {
236     const IndirectBranchInfo::MemOpInfo &MemRef = BrInfo.Memory;
237     MCInst LoadCalleeAddrs;
238     MIB.createLoad(LoadCalleeAddrs, MemRef.BaseRegNum, MemRef.ScaleValue,
239                    MemRef.IndexRegNum, MemRef.DispValue, MemRef.DispExpr,
240                    MemRef.SegRegNum, MIB.getX86R11(), 8);
241     Replacement.push_back(LoadCalleeAddrs);
242   }
243 
244   // Call the retpoline
245   MCInst RetpolineCall;
246   MIB.createDirectCall(RetpolineCall, RetpolineSymbol, BC.Ctx.get(),
247                        BrInfo.isJump() || BrInfo.isTailCall());
248 
249   Replacement.push_back(RetpolineCall);
250 }
251 
252 IndirectBranchInfo::IndirectBranchInfo(MCInst &Inst, MCPlusBuilder &MIB) {
253   IsCall = MIB.isCall(Inst);
254   IsTailCall = MIB.isTailCall(Inst);
255 
256   if (MIB.isBranchOnMem(Inst)) {
257     IsMem = true;
258     if (!MIB.evaluateX86MemoryOperand(Inst, &Memory.BaseRegNum,
259                                       &Memory.ScaleValue,
260                                       &Memory.IndexRegNum, &Memory.DispValue,
261                                       &Memory.SegRegNum, &Memory.DispExpr))
262       llvm_unreachable("not expected");
263   } else if (MIB.isBranchOnReg(Inst)) {
264     assert(MCPlus::getNumPrimeOperands(Inst) == 1 && "expect 1 operand");
265     BranchReg = Inst.getOperand(0).getReg();
266   } else {
267     llvm_unreachable("unexpected instruction");
268   }
269 }
270 
271 void RetpolineInsertion::runOnFunctions(BinaryContext &BC) {
272   if (!opts::InsertRetpolines)
273     return;
274 
275   assert(BC.isX86() &&
276          "retpoline insertion not supported for target architecture");
277 
278   assert(BC.HasRelocations && "retpoline mode not supported in non-reloc");
279 
280   auto &MIB = *BC.MIB;
281   uint32_t RetpolinedBranches = 0;
282   for (auto &It : BC.getBinaryFunctions()) {
283     BinaryFunction &Function = It.second;
284     for (BinaryBasicBlock &BB : Function) {
285       for (auto It = BB.begin(); It != BB.end(); ++It) {
286         MCInst &Inst = *It;
287 
288         if (!MIB.isIndirectCall(Inst) && !MIB.isIndirectBranch(Inst))
289           continue;
290 
291         IndirectBranchInfo BrInfo(Inst, MIB);
292         bool R11Available = false;
293         BinaryFunction *TargetRetpoline;
294         InstructionListType Replacement;
295 
296         // Determine if r11 is available before this instruction
297         if (BrInfo.isMem()) {
298           if (MIB.hasAnnotation(Inst, "PLTCall"))
299             R11Available = true;
300           else if (opts::R11Availability == AvailabilityOptions::ALWAYS)
301             R11Available = true;
302           else if (opts::R11Availability == AvailabilityOptions::ABI)
303             R11Available = BrInfo.isCall();
304         }
305 
306         // If the instruction addressing pattern uses rsp and the retpoline
307         // loads the callee address then displacement needs to be updated
308         if (BrInfo.isMem() && !R11Available) {
309           IndirectBranchInfo::MemOpInfo &MemRef = BrInfo.Memory;
310           int Addend = (BrInfo.isJump() || BrInfo.isTailCall()) ? 8 : 16;
311           if (MemRef.BaseRegNum == MIB.getStackPointer())
312             MemRef.DispValue += Addend;
313           if (MemRef.IndexRegNum == MIB.getStackPointer())
314             MemRef.DispValue += Addend * MemRef.ScaleValue;
315         }
316 
317         TargetRetpoline = getOrCreateRetpoline(BC, BrInfo, R11Available);
318 
319         createBranchReplacement(BC, BrInfo, R11Available, Replacement,
320                                 TargetRetpoline->getSymbol());
321 
322         It = BB.replaceInstruction(It, Replacement.begin(), Replacement.end());
323         RetpolinedBranches++;
324       }
325     }
326   }
327   outs() << "BOLT-INFO: The number of created retpoline functions is : "
328          << CreatedRetpolines.size()
329          << "\nBOLT-INFO: The number of retpolined branches is : "
330          << RetpolinedBranches << "\n";
331 }
332 
333 } // namespace bolt
334 } // namespace llvm
335