1 //===- bolt/Passes/RetpolineInsertion.cpp ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements RetpolineInsertion class, which replaces indirect
10 // branches (calls and jumps) with calls to retpolines to protect against branch
11 // target injection attacks.
// A unique retpoline is created for each register holding the address of the
// callee. If the callee address is in memory, %r11 is used (when available) to
// hold the address of the callee before calling the retpoline; otherwise, a
// retpoline specific to the memory addressing pattern is called, and the
// callee address is loaded inside that retpoline.
// The user controls when %r11 is assumed to be available with the
// r11-availability option; by default, %r11 is assumed to be unavailable.
// Adding an lfence instruction to the body of the speculation loop is enabled
// by default and can be controlled with the retpoline-lfence option.
21 //
22 //===----------------------------------------------------------------------===//
23 
24 #include "bolt/Passes/RetpolineInsertion.h"
25 #include "llvm/Support/raw_ostream.h"
26 
27 #define DEBUG_TYPE "bolt-retpoline"
28 
29 using namespace llvm;
30 using namespace bolt;
31 namespace opts {
32 
33 extern cl::OptionCategory BoltCategory;
34 
35 llvm::cl::opt<bool>
36 InsertRetpolines("insert-retpolines",
37   cl::desc("run retpoline insertion pass"),
38   cl::init(false),
39   cl::ZeroOrMore,
40   cl::cat(BoltCategory));
41 
42 llvm::cl::opt<bool>
43 RetpolineLfence("retpoline-lfence",
44   cl::desc("determine if lfence instruction should exist in the retpoline"),
45   cl::init(true),
46   cl::ZeroOrMore,
47   cl::Hidden,
48   cl::cat(BoltCategory));
49 
50 cl::opt<RetpolineInsertion::AvailabilityOptions>
51 R11Availability("r11-availability",
52   cl::desc("determine the availablity of r11 before indirect branches"),
53   cl::init(RetpolineInsertion::AvailabilityOptions::NEVER),
54   cl::values(
55     clEnumValN(RetpolineInsertion::AvailabilityOptions::NEVER,
56       "never", "r11 not available"),
57     clEnumValN(RetpolineInsertion::AvailabilityOptions::ALWAYS,
58       "always", "r11 avaialable before calls and jumps"),
59     clEnumValN(RetpolineInsertion::AvailabilityOptions::ABI,
60       "abi", "r11 avaialable before calls but not before jumps")),
61   cl::ZeroOrMore,
62   cl::cat(BoltCategory));
63 
64 } // namespace opts
65 
66 namespace llvm {
67 namespace bolt {
68 
69 // Retpoline function structure:
70 // BB0: call BB2
71 // BB1: pause
72 //      lfence
73 //      jmp BB1
74 // BB2: mov %reg, (%rsp)
75 //      ret
76 // or
77 // BB2: push %r11
78 //      mov Address, %r11
79 //      mov %r11, 8(%rsp)
80 //      pop %r11
81 //      ret
82 BinaryFunction *createNewRetpoline(BinaryContext &BC,
83                                    const std::string &RetpolineTag,
84                                    const IndirectBranchInfo &BrInfo,
85                                    bool R11Available) {
86   auto &MIB = *BC.MIB;
87   MCContext &Ctx = *BC.Ctx.get();
88   LLVM_DEBUG(dbgs() << "BOLT-DEBUG: Creating a new retpoline function["
89                     << RetpolineTag << "]\n");
90 
91   BinaryFunction *NewRetpoline =
92       BC.createInjectedBinaryFunction(RetpolineTag, true);
93   std::vector<std::unique_ptr<BinaryBasicBlock>> NewBlocks(3);
94   for (int I = 0; I < 3; I++) {
95     MCSymbol *Symbol =
96         Ctx.createNamedTempSymbol(Twine(RetpolineTag + "_BB" + to_string(I)));
97     NewBlocks[I] = NewRetpoline->createBasicBlock(
98         BinaryBasicBlock::INVALID_OFFSET, Symbol);
99     NewBlocks[I].get()->setCFIState(0);
100   }
101 
102   BinaryBasicBlock &BB0 = *NewBlocks[0].get();
103   BinaryBasicBlock &BB1 = *NewBlocks[1].get();
104   BinaryBasicBlock &BB2 = *NewBlocks[2].get();
105 
106   BB0.addSuccessor(&BB2, 0, 0);
107   BB1.addSuccessor(&BB1, 0, 0);
108 
109   // Build BB0
110   MCInst DirectCall;
111   MIB.createDirectCall(DirectCall, BB2.getLabel(), &Ctx, /*IsTailCall*/ false);
112   BB0.addInstruction(DirectCall);
113 
114   // Build BB1
115   MCInst Pause;
116   MIB.createPause(Pause);
117   BB1.addInstruction(Pause);
118 
119   if (opts::RetpolineLfence) {
120     MCInst Lfence;
121     MIB.createLfence(Lfence);
122     BB1.addInstruction(Lfence);
123   }
124 
125   InstructionListType Seq;
126   MIB.createShortJmp(Seq, BB1.getLabel(), &Ctx);
127   BB1.addInstructions(Seq.begin(), Seq.end());
128 
129   // Build BB2
130   if (BrInfo.isMem()) {
131     if (R11Available) {
132       MCInst StoreToStack;
133       MIB.createSaveToStack(StoreToStack, MIB.getStackPointer(), 0,
134                             MIB.getX86R11(), 8);
135       BB2.addInstruction(StoreToStack);
136     } else {
137       MCInst PushR11;
138       MIB.createPushRegister(PushR11, MIB.getX86R11(), 8);
139       BB2.addInstruction(PushR11);
140 
141       MCInst LoadCalleeAddrs;
142       const IndirectBranchInfo::MemOpInfo &MemRef = BrInfo.Memory;
143       MIB.createLoad(LoadCalleeAddrs, MemRef.BaseRegNum, MemRef.ScaleValue,
144                      MemRef.IndexRegNum, MemRef.DispValue, MemRef.DispExpr,
145                      MemRef.SegRegNum, MIB.getX86R11(), 8);
146 
147       BB2.addInstruction(LoadCalleeAddrs);
148 
149       MCInst StoreToStack;
150       MIB.createSaveToStack(StoreToStack, MIB.getStackPointer(), 8,
151                             MIB.getX86R11(), 8);
152       BB2.addInstruction(StoreToStack);
153 
154       MCInst PopR11;
155       MIB.createPopRegister(PopR11, MIB.getX86R11(), 8);
156       BB2.addInstruction(PopR11);
157     }
158   } else if (BrInfo.isReg()) {
159     MCInst StoreToStack;
160     MIB.createSaveToStack(StoreToStack, MIB.getStackPointer(), 0,
161                           BrInfo.BranchReg, 8);
162     BB2.addInstruction(StoreToStack);
163   } else {
164     llvm_unreachable("not expected");
165   }
166 
167   // return
168   MCInst Return;
169   MIB.createReturn(Return);
170   BB2.addInstruction(Return);
171   NewRetpoline->insertBasicBlocks(nullptr, move(NewBlocks),
172                                   /* UpdateLayout */ true,
173                                   /* UpdateCFIState */ false);
174 
175   NewRetpoline->updateState(BinaryFunction::State::CFG_Finalized);
176   return NewRetpoline;
177 }
178 
179 std::string createRetpolineFunctionTag(BinaryContext &BC,
180                                        const IndirectBranchInfo &BrInfo,
181                                        bool R11Available) {
182   if (BrInfo.isReg())
183     return "__retpoline_r" + to_string(BrInfo.BranchReg) + "_";
184 
185   // Memory Branch
186   if (R11Available)
187     return "__retpoline_r11";
188 
189   std::string Tag = "__retpoline_mem_";
190 
191   const IndirectBranchInfo::MemOpInfo &MemRef = BrInfo.Memory;
192 
193   std::string DispExprStr;
194   if (MemRef.DispExpr) {
195     llvm::raw_string_ostream Ostream(DispExprStr);
196     MemRef.DispExpr->print(Ostream, BC.AsmInfo.get());
197     Ostream.flush();
198   }
199 
200   Tag += MemRef.BaseRegNum != BC.MIB->getNoRegister()
201              ? "r" + to_string(MemRef.BaseRegNum)
202              : "";
203 
204   Tag +=
205       MemRef.DispExpr ? "+" + DispExprStr : "+" + to_string(MemRef.DispValue);
206 
207   Tag += MemRef.IndexRegNum != BC.MIB->getNoRegister()
208              ? "+" + to_string(MemRef.ScaleValue) + "*" +
209                    to_string(MemRef.IndexRegNum)
210              : "";
211 
212   Tag += MemRef.SegRegNum != BC.MIB->getNoRegister()
213              ? "_seg_" + to_string(MemRef.SegRegNum)
214              : "";
215 
216   return Tag;
217 }
218 
219 BinaryFunction *RetpolineInsertion::getOrCreateRetpoline(
220     BinaryContext &BC, const IndirectBranchInfo &BrInfo, bool R11Available) {
221   const std::string RetpolineTag =
222       createRetpolineFunctionTag(BC, BrInfo, R11Available);
223 
224   if (CreatedRetpolines.count(RetpolineTag))
225     return CreatedRetpolines[RetpolineTag];
226 
227   return CreatedRetpolines[RetpolineTag] =
228              createNewRetpoline(BC, RetpolineTag, BrInfo, R11Available);
229 }
230 
231 void createBranchReplacement(BinaryContext &BC,
232                              const IndirectBranchInfo &BrInfo,
233                              bool R11Available,
234                              InstructionListType &Replacement,
235                              const MCSymbol *RetpolineSymbol) {
236   auto &MIB = *BC.MIB;
237   // Load the branch address in r11 if available
238   if (BrInfo.isMem() && R11Available) {
239     const IndirectBranchInfo::MemOpInfo &MemRef = BrInfo.Memory;
240     MCInst LoadCalleeAddrs;
241     MIB.createLoad(LoadCalleeAddrs, MemRef.BaseRegNum, MemRef.ScaleValue,
242                    MemRef.IndexRegNum, MemRef.DispValue, MemRef.DispExpr,
243                    MemRef.SegRegNum, MIB.getX86R11(), 8);
244     Replacement.push_back(LoadCalleeAddrs);
245   }
246 
247   // Call the retpoline
248   MCInst RetpolineCall;
249   MIB.createDirectCall(RetpolineCall, RetpolineSymbol, BC.Ctx.get(),
250                        BrInfo.isJump() || BrInfo.isTailCall());
251 
252   Replacement.push_back(RetpolineCall);
253 }
254 
255 IndirectBranchInfo::IndirectBranchInfo(MCInst &Inst, MCPlusBuilder &MIB) {
256   IsCall = MIB.isCall(Inst);
257   IsTailCall = MIB.isTailCall(Inst);
258 
259   if (MIB.isBranchOnMem(Inst)) {
260     IsMem = true;
261     if (!MIB.evaluateX86MemoryOperand(Inst, &Memory.BaseRegNum,
262                                       &Memory.ScaleValue,
263                                       &Memory.IndexRegNum, &Memory.DispValue,
264                                       &Memory.SegRegNum, &Memory.DispExpr))
265       llvm_unreachable("not expected");
266   } else if (MIB.isBranchOnReg(Inst)) {
267     assert(MCPlus::getNumPrimeOperands(Inst) == 1 && "expect 1 operand");
268     BranchReg = Inst.getOperand(0).getReg();
269   } else {
270     llvm_unreachable("unexpected instruction");
271   }
272 }
273 
274 void RetpolineInsertion::runOnFunctions(BinaryContext &BC) {
275   if (!opts::InsertRetpolines)
276     return;
277 
278   assert(BC.isX86() &&
279          "retpoline insertion not supported for target architecture");
280 
281   assert(BC.HasRelocations && "retpoline mode not supported in non-reloc");
282 
283   auto &MIB = *BC.MIB;
284   uint32_t RetpolinedBranches = 0;
285   for (auto &It : BC.getBinaryFunctions()) {
286     BinaryFunction &Function = It.second;
287     for (BinaryBasicBlock &BB : Function) {
288       for (auto It = BB.begin(); It != BB.end(); ++It) {
289         MCInst &Inst = *It;
290 
291         if (!MIB.isIndirectCall(Inst) && !MIB.isIndirectBranch(Inst))
292           continue;
293 
294         IndirectBranchInfo BrInfo(Inst, MIB);
295         bool R11Available = false;
296         BinaryFunction *TargetRetpoline;
297         InstructionListType Replacement;
298 
299         // Determine if r11 is available before this instruction
300         if (BrInfo.isMem()) {
301           if (MIB.hasAnnotation(Inst, "PLTCall"))
302             R11Available = true;
303           else if (opts::R11Availability == AvailabilityOptions::ALWAYS)
304             R11Available = true;
305           else if (opts::R11Availability == AvailabilityOptions::ABI)
306             R11Available = BrInfo.isCall();
307         }
308 
309         // If the instruction addressing pattern uses rsp and the retpoline
310         // loads the callee address then displacement needs to be updated
311         if (BrInfo.isMem() && !R11Available) {
312           IndirectBranchInfo::MemOpInfo &MemRef = BrInfo.Memory;
313           int Addend = (BrInfo.isJump() || BrInfo.isTailCall()) ? 8 : 16;
314           if (MemRef.BaseRegNum == MIB.getStackPointer())
315             MemRef.DispValue += Addend;
316           if (MemRef.IndexRegNum == MIB.getStackPointer())
317             MemRef.DispValue += Addend * MemRef.ScaleValue;
318         }
319 
320         TargetRetpoline = getOrCreateRetpoline(BC, BrInfo, R11Available);
321 
322         createBranchReplacement(BC, BrInfo, R11Available, Replacement,
323                                 TargetRetpoline->getSymbol());
324 
325         It = BB.replaceInstruction(It, Replacement.begin(), Replacement.end());
326         RetpolinedBranches++;
327       }
328     }
329   }
330   outs() << "BOLT-INFO: The number of created retpoline functions is : "
331          << CreatedRetpolines.size()
332          << "\nBOLT-INFO: The number of retpolined branches is : "
333          << RetpolinedBranches << "\n";
334 }
335 
336 } // namespace bolt
337 } // namespace llvm
338