1 //===- bolt/Passes/RetpolineInsertion.cpp ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements RetpolineInsertion class, which replaces indirect
10 // branches (calls and jumps) with calls to retpolines to protect against branch
11 // target injection attacks.
// A unique retpoline is created for each register holding the address of the
// callee. If the callee address is in memory, %r11 is used (when available) to
// hold the address of the callee before calling the retpoline; otherwise, an
// address-pattern-specific retpoline is called, and the callee address is
// loaded inside the retpoline.
// The user can control when %r11 is assumed to be available with the
// r11-availability option; by default, %r11 is assumed not to be available.
// Adding an lfence instruction to the body of the speculation code is enabled
// by default and can be controlled with the retpoline-lfence option.
21 //
22 //===----------------------------------------------------------------------===//
23 
24 #include "bolt/Passes/RetpolineInsertion.h"
25 #include "llvm/Support/raw_ostream.h"
26 
27 #define DEBUG_TYPE "bolt-retpoline"
28 
29 using namespace llvm;
30 using namespace bolt;
31 namespace opts {
32 
33 extern cl::OptionCategory BoltCategory;
34 
35 llvm::cl::opt<bool> InsertRetpolines("insert-retpolines",
36                                      cl::desc("run retpoline insertion pass"),
37                                      cl::cat(BoltCategory));
38 
39 llvm::cl::opt<bool>
40 RetpolineLfence("retpoline-lfence",
41   cl::desc("determine if lfence instruction should exist in the retpoline"),
42   cl::init(true),
43   cl::ZeroOrMore,
44   cl::Hidden,
45   cl::cat(BoltCategory));
46 
47 cl::opt<RetpolineInsertion::AvailabilityOptions>
48 R11Availability("r11-availability",
49   cl::desc("determine the availablity of r11 before indirect branches"),
50   cl::init(RetpolineInsertion::AvailabilityOptions::NEVER),
51   cl::values(
52     clEnumValN(RetpolineInsertion::AvailabilityOptions::NEVER,
53       "never", "r11 not available"),
54     clEnumValN(RetpolineInsertion::AvailabilityOptions::ALWAYS,
55       "always", "r11 avaialable before calls and jumps"),
56     clEnumValN(RetpolineInsertion::AvailabilityOptions::ABI,
57       "abi", "r11 avaialable before calls but not before jumps")),
58   cl::ZeroOrMore,
59   cl::cat(BoltCategory));
60 
61 } // namespace opts
62 
63 namespace llvm {
64 namespace bolt {
65 
66 // Retpoline function structure:
67 // BB0: call BB2
68 // BB1: pause
69 //      lfence
70 //      jmp BB1
71 // BB2: mov %reg, (%rsp)
72 //      ret
73 // or
74 // BB2: push %r11
75 //      mov Address, %r11
76 //      mov %r11, 8(%rsp)
77 //      pop %r11
78 //      ret
createNewRetpoline(BinaryContext & BC,const std::string & RetpolineTag,const IndirectBranchInfo & BrInfo,bool R11Available)79 BinaryFunction *createNewRetpoline(BinaryContext &BC,
80                                    const std::string &RetpolineTag,
81                                    const IndirectBranchInfo &BrInfo,
82                                    bool R11Available) {
83   auto &MIB = *BC.MIB;
84   MCContext &Ctx = *BC.Ctx.get();
85   LLVM_DEBUG(dbgs() << "BOLT-DEBUG: Creating a new retpoline function["
86                     << RetpolineTag << "]\n");
87 
88   BinaryFunction *NewRetpoline =
89       BC.createInjectedBinaryFunction(RetpolineTag, true);
90   std::vector<std::unique_ptr<BinaryBasicBlock>> NewBlocks(3);
91   for (int I = 0; I < 3; I++) {
92     MCSymbol *Symbol =
93         Ctx.createNamedTempSymbol(Twine(RetpolineTag + "_BB" + to_string(I)));
94     NewBlocks[I] = NewRetpoline->createBasicBlock(Symbol);
95     NewBlocks[I].get()->setCFIState(0);
96   }
97 
98   BinaryBasicBlock &BB0 = *NewBlocks[0].get();
99   BinaryBasicBlock &BB1 = *NewBlocks[1].get();
100   BinaryBasicBlock &BB2 = *NewBlocks[2].get();
101 
102   BB0.addSuccessor(&BB2, 0, 0);
103   BB1.addSuccessor(&BB1, 0, 0);
104 
105   // Build BB0
106   MCInst DirectCall;
107   MIB.createDirectCall(DirectCall, BB2.getLabel(), &Ctx, /*IsTailCall*/ false);
108   BB0.addInstruction(DirectCall);
109 
110   // Build BB1
111   MCInst Pause;
112   MIB.createPause(Pause);
113   BB1.addInstruction(Pause);
114 
115   if (opts::RetpolineLfence) {
116     MCInst Lfence;
117     MIB.createLfence(Lfence);
118     BB1.addInstruction(Lfence);
119   }
120 
121   InstructionListType Seq;
122   MIB.createShortJmp(Seq, BB1.getLabel(), &Ctx);
123   BB1.addInstructions(Seq.begin(), Seq.end());
124 
125   // Build BB2
126   if (BrInfo.isMem()) {
127     if (R11Available) {
128       MCInst StoreToStack;
129       MIB.createSaveToStack(StoreToStack, MIB.getStackPointer(), 0,
130                             MIB.getX86R11(), 8);
131       BB2.addInstruction(StoreToStack);
132     } else {
133       MCInst PushR11;
134       MIB.createPushRegister(PushR11, MIB.getX86R11(), 8);
135       BB2.addInstruction(PushR11);
136 
137       MCInst LoadCalleeAddrs;
138       const IndirectBranchInfo::MemOpInfo &MemRef = BrInfo.Memory;
139       MIB.createLoad(LoadCalleeAddrs, MemRef.BaseRegNum, MemRef.ScaleValue,
140                      MemRef.IndexRegNum, MemRef.DispValue, MemRef.DispExpr,
141                      MemRef.SegRegNum, MIB.getX86R11(), 8);
142 
143       BB2.addInstruction(LoadCalleeAddrs);
144 
145       MCInst StoreToStack;
146       MIB.createSaveToStack(StoreToStack, MIB.getStackPointer(), 8,
147                             MIB.getX86R11(), 8);
148       BB2.addInstruction(StoreToStack);
149 
150       MCInst PopR11;
151       MIB.createPopRegister(PopR11, MIB.getX86R11(), 8);
152       BB2.addInstruction(PopR11);
153     }
154   } else if (BrInfo.isReg()) {
155     MCInst StoreToStack;
156     MIB.createSaveToStack(StoreToStack, MIB.getStackPointer(), 0,
157                           BrInfo.BranchReg, 8);
158     BB2.addInstruction(StoreToStack);
159   } else {
160     llvm_unreachable("not expected");
161   }
162 
163   // return
164   MCInst Return;
165   MIB.createReturn(Return);
166   BB2.addInstruction(Return);
167   NewRetpoline->insertBasicBlocks(nullptr, std::move(NewBlocks),
168                                   /* UpdateLayout */ true,
169                                   /* UpdateCFIState */ false);
170 
171   NewRetpoline->updateState(BinaryFunction::State::CFG_Finalized);
172   return NewRetpoline;
173 }
174 
createRetpolineFunctionTag(BinaryContext & BC,const IndirectBranchInfo & BrInfo,bool R11Available)175 std::string createRetpolineFunctionTag(BinaryContext &BC,
176                                        const IndirectBranchInfo &BrInfo,
177                                        bool R11Available) {
178   if (BrInfo.isReg())
179     return "__retpoline_r" + to_string(BrInfo.BranchReg) + "_";
180 
181   // Memory Branch
182   if (R11Available)
183     return "__retpoline_r11";
184 
185   std::string Tag = "__retpoline_mem_";
186 
187   const IndirectBranchInfo::MemOpInfo &MemRef = BrInfo.Memory;
188 
189   std::string DispExprStr;
190   if (MemRef.DispExpr) {
191     llvm::raw_string_ostream Ostream(DispExprStr);
192     MemRef.DispExpr->print(Ostream, BC.AsmInfo.get());
193     Ostream.flush();
194   }
195 
196   Tag += MemRef.BaseRegNum != BC.MIB->getNoRegister()
197              ? "r" + to_string(MemRef.BaseRegNum)
198              : "";
199 
200   Tag +=
201       MemRef.DispExpr ? "+" + DispExprStr : "+" + to_string(MemRef.DispValue);
202 
203   Tag += MemRef.IndexRegNum != BC.MIB->getNoRegister()
204              ? "+" + to_string(MemRef.ScaleValue) + "*" +
205                    to_string(MemRef.IndexRegNum)
206              : "";
207 
208   Tag += MemRef.SegRegNum != BC.MIB->getNoRegister()
209              ? "_seg_" + to_string(MemRef.SegRegNum)
210              : "";
211 
212   return Tag;
213 }
214 
getOrCreateRetpoline(BinaryContext & BC,const IndirectBranchInfo & BrInfo,bool R11Available)215 BinaryFunction *RetpolineInsertion::getOrCreateRetpoline(
216     BinaryContext &BC, const IndirectBranchInfo &BrInfo, bool R11Available) {
217   const std::string RetpolineTag =
218       createRetpolineFunctionTag(BC, BrInfo, R11Available);
219 
220   if (CreatedRetpolines.count(RetpolineTag))
221     return CreatedRetpolines[RetpolineTag];
222 
223   return CreatedRetpolines[RetpolineTag] =
224              createNewRetpoline(BC, RetpolineTag, BrInfo, R11Available);
225 }
226 
createBranchReplacement(BinaryContext & BC,const IndirectBranchInfo & BrInfo,bool R11Available,InstructionListType & Replacement,const MCSymbol * RetpolineSymbol)227 void createBranchReplacement(BinaryContext &BC,
228                              const IndirectBranchInfo &BrInfo,
229                              bool R11Available,
230                              InstructionListType &Replacement,
231                              const MCSymbol *RetpolineSymbol) {
232   auto &MIB = *BC.MIB;
233   // Load the branch address in r11 if available
234   if (BrInfo.isMem() && R11Available) {
235     const IndirectBranchInfo::MemOpInfo &MemRef = BrInfo.Memory;
236     MCInst LoadCalleeAddrs;
237     MIB.createLoad(LoadCalleeAddrs, MemRef.BaseRegNum, MemRef.ScaleValue,
238                    MemRef.IndexRegNum, MemRef.DispValue, MemRef.DispExpr,
239                    MemRef.SegRegNum, MIB.getX86R11(), 8);
240     Replacement.push_back(LoadCalleeAddrs);
241   }
242 
243   // Call the retpoline
244   MCInst RetpolineCall;
245   MIB.createDirectCall(RetpolineCall, RetpolineSymbol, BC.Ctx.get(),
246                        BrInfo.isJump() || BrInfo.isTailCall());
247 
248   Replacement.push_back(RetpolineCall);
249 }
250 
IndirectBranchInfo(MCInst & Inst,MCPlusBuilder & MIB)251 IndirectBranchInfo::IndirectBranchInfo(MCInst &Inst, MCPlusBuilder &MIB) {
252   IsCall = MIB.isCall(Inst);
253   IsTailCall = MIB.isTailCall(Inst);
254 
255   if (MIB.isBranchOnMem(Inst)) {
256     IsMem = true;
257     if (!MIB.evaluateX86MemoryOperand(Inst, &Memory.BaseRegNum,
258                                       &Memory.ScaleValue,
259                                       &Memory.IndexRegNum, &Memory.DispValue,
260                                       &Memory.SegRegNum, &Memory.DispExpr))
261       llvm_unreachable("not expected");
262   } else if (MIB.isBranchOnReg(Inst)) {
263     assert(MCPlus::getNumPrimeOperands(Inst) == 1 && "expect 1 operand");
264     BranchReg = Inst.getOperand(0).getReg();
265   } else {
266     llvm_unreachable("unexpected instruction");
267   }
268 }
269 
runOnFunctions(BinaryContext & BC)270 void RetpolineInsertion::runOnFunctions(BinaryContext &BC) {
271   if (!opts::InsertRetpolines)
272     return;
273 
274   assert(BC.isX86() &&
275          "retpoline insertion not supported for target architecture");
276 
277   assert(BC.HasRelocations && "retpoline mode not supported in non-reloc");
278 
279   auto &MIB = *BC.MIB;
280   uint32_t RetpolinedBranches = 0;
281   for (auto &It : BC.getBinaryFunctions()) {
282     BinaryFunction &Function = It.second;
283     for (BinaryBasicBlock &BB : Function) {
284       for (auto It = BB.begin(); It != BB.end(); ++It) {
285         MCInst &Inst = *It;
286 
287         if (!MIB.isIndirectCall(Inst) && !MIB.isIndirectBranch(Inst))
288           continue;
289 
290         IndirectBranchInfo BrInfo(Inst, MIB);
291         bool R11Available = false;
292         BinaryFunction *TargetRetpoline;
293         InstructionListType Replacement;
294 
295         // Determine if r11 is available before this instruction
296         if (BrInfo.isMem()) {
297           if (MIB.hasAnnotation(Inst, "PLTCall"))
298             R11Available = true;
299           else if (opts::R11Availability == AvailabilityOptions::ALWAYS)
300             R11Available = true;
301           else if (opts::R11Availability == AvailabilityOptions::ABI)
302             R11Available = BrInfo.isCall();
303         }
304 
305         // If the instruction addressing pattern uses rsp and the retpoline
306         // loads the callee address then displacement needs to be updated
307         if (BrInfo.isMem() && !R11Available) {
308           IndirectBranchInfo::MemOpInfo &MemRef = BrInfo.Memory;
309           int Addend = (BrInfo.isJump() || BrInfo.isTailCall()) ? 8 : 16;
310           if (MemRef.BaseRegNum == MIB.getStackPointer())
311             MemRef.DispValue += Addend;
312           if (MemRef.IndexRegNum == MIB.getStackPointer())
313             MemRef.DispValue += Addend * MemRef.ScaleValue;
314         }
315 
316         TargetRetpoline = getOrCreateRetpoline(BC, BrInfo, R11Available);
317 
318         createBranchReplacement(BC, BrInfo, R11Available, Replacement,
319                                 TargetRetpoline->getSymbol());
320 
321         It = BB.replaceInstruction(It, Replacement.begin(), Replacement.end());
322         RetpolinedBranches++;
323       }
324     }
325   }
326   outs() << "BOLT-INFO: The number of created retpoline functions is : "
327          << CreatedRetpolines.size()
328          << "\nBOLT-INFO: The number of retpolined branches is : "
329          << RetpolinedBranches << "\n";
330 }
331 
332 } // namespace bolt
333 } // namespace llvm
334