1 //===-- ImplicitNullChecks.cpp - Fold null checks into memory accesses ----===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass turns explicit null checks of the form
11 //
12 //   test %r10, %r10
13 //   je throw_npe
14 //   movl (%r10), %esi
15 //   ...
16 //
17 // to
18 //
19 //   faulting_load_op("movl (%r10), %esi", throw_npe)
20 //   ...
21 //
22 // With the help of a runtime that understands the .fault_maps section,
23 // faulting_load_op branches to throw_npe if executing movl (%r10), %esi incurs
24 // a page fault.
25 //
26 //===----------------------------------------------------------------------===//
27 
28 #include "llvm/ADT/DenseSet.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/Statistic.h"
31 #include "llvm/CodeGen/Passes.h"
32 #include "llvm/CodeGen/MachineFunction.h"
33 #include "llvm/CodeGen/MachineMemOperand.h"
34 #include "llvm/CodeGen/MachineOperand.h"
35 #include "llvm/CodeGen/MachineFunctionPass.h"
36 #include "llvm/CodeGen/MachineInstrBuilder.h"
37 #include "llvm/CodeGen/MachineRegisterInfo.h"
38 #include "llvm/CodeGen/MachineModuleInfo.h"
39 #include "llvm/IR/BasicBlock.h"
40 #include "llvm/IR/Instruction.h"
41 #include "llvm/IR/LLVMContext.h"
42 #include "llvm/Support/CommandLine.h"
43 #include "llvm/Support/Debug.h"
44 #include "llvm/Target/TargetSubtargetInfo.h"
45 #include "llvm/Target/TargetInstrInfo.h"
46 
47 using namespace llvm;
48 
// Page size, in bytes, that analyzeBlockForNullChecks compares a candidate
// memory operation's offset (relative to the null-checked register) against
// when deciding whether that operation may be folded into the null check.
// NOTE(review): the transformation assumes the runtime leaves the page(s)
// around address zero unmapped so that such an access faults -- confirm for
// each target/runtime before changing this value.
static cl::opt<int> PageSize("imp-null-check-page-size",
                             cl::desc("The page size of the target in bytes"),
                             cl::init(4096));

// Defined ahead of STATISTIC below, which (per LLVM's Statistic.h) uses
// DEBUG_TYPE to group the counter.
#define DEBUG_TYPE "implicit-null-checks"

// Incremented once for every null check rewritten by rewriteNullChecks.
STATISTIC(NumImplicitNullChecks,
          "Number of explicit null checks made implicit");
57 
58 namespace {
59 
60 class ImplicitNullChecks : public MachineFunctionPass {
61   /// Represents one null check that can be made implicit.
62   class NullCheck {
63     // The memory operation the null check can be folded into.
64     MachineInstr *MemOperation;
65 
66     // The instruction actually doing the null check (Ptr != 0).
67     MachineInstr *CheckOperation;
68 
69     // The block the check resides in.
70     MachineBasicBlock *CheckBlock;
71 
72     // The block branched to if the pointer is non-null.
73     MachineBasicBlock *NotNullSucc;
74 
75     // The block branched to if the pointer is null.
76     MachineBasicBlock *NullSucc;
77 
78   public:
79     explicit NullCheck(MachineInstr *memOperation, MachineInstr *checkOperation,
80                        MachineBasicBlock *checkBlock,
81                        MachineBasicBlock *notNullSucc,
82                        MachineBasicBlock *nullSucc)
83         : MemOperation(memOperation), CheckOperation(checkOperation),
84           CheckBlock(checkBlock), NotNullSucc(notNullSucc), NullSucc(nullSucc) {
85     }
86 
87     MachineInstr *getMemOperation() const { return MemOperation; }
88 
89     MachineInstr *getCheckOperation() const { return CheckOperation; }
90 
91     MachineBasicBlock *getCheckBlock() const { return CheckBlock; }
92 
93     MachineBasicBlock *getNotNullSucc() const { return NotNullSucc; }
94 
95     MachineBasicBlock *getNullSucc() const { return NullSucc; }
96   };
97 
98   const TargetInstrInfo *TII = nullptr;
99   const TargetRegisterInfo *TRI = nullptr;
100   MachineModuleInfo *MMI = nullptr;
101 
102   bool analyzeBlockForNullChecks(MachineBasicBlock &MBB,
103                                  SmallVectorImpl<NullCheck> &NullCheckList);
104   MachineInstr *insertFaultingLoad(MachineInstr *LoadMI, MachineBasicBlock *MBB,
105                                    MachineBasicBlock *HandlerMBB);
106   void rewriteNullChecks(ArrayRef<NullCheck> NullCheckList);
107 
108 public:
109   static char ID;
110 
111   ImplicitNullChecks() : MachineFunctionPass(ID) {
112     initializeImplicitNullChecksPass(*PassRegistry::getPassRegistry());
113   }
114 
115   bool runOnMachineFunction(MachineFunction &MF) override;
116 
117   MachineFunctionProperties getRequiredProperties() const override {
118     return MachineFunctionProperties().set(
119         MachineFunctionProperties::Property::AllVRegsAllocated);
120   }
121 };
122 
123 /// \brief Detect re-ordering hazards and dependencies.
124 ///
125 /// This class keeps track of defs and uses, and can be queried if a given
126 /// machine instruction can be re-ordered from after the machine instructions
127 /// seen so far to before them.
128 class HazardDetector {
129   DenseSet<unsigned> RegDefs;
130   DenseSet<unsigned> RegUses;
131   const TargetRegisterInfo &TRI;
132   bool hasSeenClobber;
133 
134 public:
135   explicit HazardDetector(const TargetRegisterInfo &TRI) :
136     TRI(TRI), hasSeenClobber(false) {}
137 
138   /// \brief Make a note of \p MI for later queries to isSafeToHoist.
139   ///
140   /// May clobber this HazardDetector instance.  \see isClobbered.
141   void rememberInstruction(MachineInstr *MI);
142 
143   /// \brief Return true if it is safe to hoist \p MI from after all the
144   /// instructions seen so far (via rememberInstruction) to before it.
145   bool isSafeToHoist(MachineInstr *MI);
146 
147   /// \brief Return true if this instance of HazardDetector has been clobbered
148   /// (i.e. has no more useful information).
149   ///
150   /// A HazardDetecter is clobbered when it sees a construct it cannot
151   /// understand, and it would have to return a conservative answer for all
152   /// future queries.  Having a separate clobbered state lets the client code
153   /// bail early, without making queries about all of the future instructions
154   /// (which would have returned the most conservative answer anyway).
155   ///
156   /// Calling rememberInstruction or isSafeToHoist on a clobbered HazardDetector
157   /// is an error.
158   bool isClobbered() { return hasSeenClobber; }
159 };
160 }
161 
162 
163 void HazardDetector::rememberInstruction(MachineInstr *MI) {
164   assert(!isClobbered() &&
165          "Don't add instructions to a clobbered hazard detector");
166 
167   if (MI->mayStore() || MI->hasUnmodeledSideEffects()) {
168     hasSeenClobber = true;
169     return;
170   }
171 
172   for (auto *MMO : MI->memoperands()) {
173     // Right now we don't want to worry about LLVM's memory model.
174     if (!MMO->isUnordered()) {
175       hasSeenClobber = true;
176       return;
177     }
178   }
179 
180   for (auto &MO : MI->operands()) {
181     if (!MO.isReg() || !MO.getReg())
182       continue;
183 
184     if (MO.isDef())
185       RegDefs.insert(MO.getReg());
186     else
187       RegUses.insert(MO.getReg());
188   }
189 }
190 
191 bool HazardDetector::isSafeToHoist(MachineInstr *MI) {
192   assert(!isClobbered() && "isSafeToHoist cannot do anything useful!");
193 
194   // Right now we don't want to worry about LLVM's memory model.  This can be
195   // made more precise later.
196   for (auto *MMO : MI->memoperands())
197     if (!MMO->isUnordered())
198       return false;
199 
200   for (auto &MO : MI->operands()) {
201     if (MO.isReg() && MO.getReg()) {
202       for (unsigned Reg : RegDefs)
203         if (TRI.regsOverlap(Reg, MO.getReg()))
204           return false;  // We found a write-after-write or read-after-write
205 
206       if (MO.isDef())
207         for (unsigned Reg : RegUses)
208           if (TRI.regsOverlap(Reg, MO.getReg()))
209             return false;  // We found a write-after-read
210     }
211   }
212 
213   return true;
214 }
215 
216 bool ImplicitNullChecks::runOnMachineFunction(MachineFunction &MF) {
217   TII = MF.getSubtarget().getInstrInfo();
218   TRI = MF.getRegInfo().getTargetRegisterInfo();
219   MMI = &MF.getMMI();
220 
221   SmallVector<NullCheck, 16> NullCheckList;
222 
223   for (auto &MBB : MF)
224     analyzeBlockForNullChecks(MBB, NullCheckList);
225 
226   if (!NullCheckList.empty())
227     rewriteNullChecks(NullCheckList);
228 
229   return !NullCheckList.empty();
230 }
231 
232 /// Analyze MBB to check if its terminating branch can be turned into an
233 /// implicit null check.  If yes, append a description of the said null check to
234 /// NullCheckList and return true, else return false.
235 bool ImplicitNullChecks::analyzeBlockForNullChecks(
236     MachineBasicBlock &MBB, SmallVectorImpl<NullCheck> &NullCheckList) {
237   typedef TargetInstrInfo::MachineBranchPredicate MachineBranchPredicate;
238 
239   MDNode *BranchMD = nullptr;
240   if (auto *BB = MBB.getBasicBlock())
241     BranchMD = BB->getTerminator()->getMetadata(LLVMContext::MD_make_implicit);
242 
243   if (!BranchMD)
244     return false;
245 
246   MachineBranchPredicate MBP;
247 
248   if (TII->AnalyzeBranchPredicate(MBB, MBP, true))
249     return false;
250 
251   // Is the predicate comparing an integer to zero?
252   if (!(MBP.LHS.isReg() && MBP.RHS.isImm() && MBP.RHS.getImm() == 0 &&
253         (MBP.Predicate == MachineBranchPredicate::PRED_NE ||
254          MBP.Predicate == MachineBranchPredicate::PRED_EQ)))
255     return false;
256 
257   // If we cannot erase the test instruction itself, then making the null check
258   // implicit does not buy us much.
259   if (!MBP.SingleUseCondition)
260     return false;
261 
262   MachineBasicBlock *NotNullSucc, *NullSucc;
263 
264   if (MBP.Predicate == MachineBranchPredicate::PRED_NE) {
265     NotNullSucc = MBP.TrueDest;
266     NullSucc = MBP.FalseDest;
267   } else {
268     NotNullSucc = MBP.FalseDest;
269     NullSucc = MBP.TrueDest;
270   }
271 
272   // We handle the simplest case for now.  We can potentially do better by using
273   // the machine dominator tree.
274   if (NotNullSucc->pred_size() != 1)
275     return false;
276 
277   // Starting with a code fragment like:
278   //
279   //   test %RAX, %RAX
280   //   jne LblNotNull
281   //
282   //  LblNull:
283   //   callq throw_NullPointerException
284   //
285   //  LblNotNull:
286   //   Inst0
287   //   Inst1
288   //   ...
289   //   Def = Load (%RAX + <offset>)
290   //   ...
291   //
292   //
293   // we want to end up with
294   //
295   //   Def = FaultingLoad (%RAX + <offset>), LblNull
296   //   jmp LblNotNull ;; explicit or fallthrough
297   //
298   //  LblNotNull:
299   //   Inst0
300   //   Inst1
301   //   ...
302   //
303   //  LblNull:
304   //   callq throw_NullPointerException
305   //
306   //
307   // To see why this is legal, consider the two possibilities:
308   //
309   //  1. %RAX is null: since we constrain <offset> to be less than PageSize, the
310   //     load instruction dereferences the null page, causing a segmentation
311   //     fault.
312   //
313   //  2. %RAX is not null: in this case we know that the load cannot fault, as
314   //     otherwise the load would've faulted in the original program too and the
315   //     original program would've been undefined.
316   //
317   // This reasoning cannot be extended to justify hoisting through arbitrary
318   // control flow.  For instance, in the example below (in pseudo-C)
319   //
320   //    if (ptr == null) { throw_npe(); unreachable; }
321   //    if (some_cond) { return 42; }
322   //    v = ptr->field;  // LD
323   //    ...
324   //
325   // we cannot (without code duplication) use the load marked "LD" to null check
326   // ptr -- clause (2) above does not apply in this case.  In the above program
327   // the safety of ptr->field can be dependent on some_cond; and, for instance,
328   // ptr could be some non-null invalid reference that never gets loaded from
329   // because some_cond is always true.
330 
331   unsigned PointerReg = MBP.LHS.getReg();
332 
333   HazardDetector HD(*TRI);
334 
335   for (auto MII = NotNullSucc->begin(), MIE = NotNullSucc->end(); MII != MIE;
336        ++MII) {
337     MachineInstr *MI = &*MII;
338     unsigned BaseReg;
339     int64_t Offset;
340     if (TII->getMemOpBaseRegImmOfs(MI, BaseReg, Offset, TRI))
341       if (MI->mayLoad() && !MI->isPredicable() && BaseReg == PointerReg &&
342           Offset < PageSize && MI->getDesc().getNumDefs() <= 1 &&
343           HD.isSafeToHoist(MI)) {
344         NullCheckList.emplace_back(MI, MBP.ConditionDef, &MBB, NotNullSucc,
345                                    NullSucc);
346         return true;
347       }
348 
349     HD.rememberInstruction(MI);
350     if (HD.isClobbered())
351       return false;
352   }
353 
354   return false;
355 }
356 
357 /// Wrap a machine load instruction, LoadMI, into a FAULTING_LOAD_OP machine
358 /// instruction.  The FAULTING_LOAD_OP instruction does the same load as LoadMI
359 /// (defining the same register), and branches to HandlerMBB if the load
360 /// faults.  The FAULTING_LOAD_OP instruction is inserted at the end of MBB.
361 MachineInstr *
362 ImplicitNullChecks::insertFaultingLoad(MachineInstr *LoadMI,
363                                        MachineBasicBlock *MBB,
364                                        MachineBasicBlock *HandlerMBB) {
365   const unsigned NoRegister = 0; // Guaranteed to be the NoRegister value for
366                                  // all targets.
367 
368   DebugLoc DL;
369   unsigned NumDefs = LoadMI->getDesc().getNumDefs();
370   assert(NumDefs <= 1 && "other cases unhandled!");
371 
372   unsigned DefReg = NoRegister;
373   if (NumDefs != 0) {
374     DefReg = LoadMI->defs().begin()->getReg();
375     assert(std::distance(LoadMI->defs().begin(), LoadMI->defs().end()) == 1 &&
376            "expected exactly one def!");
377   }
378 
379   auto MIB = BuildMI(MBB, DL, TII->get(TargetOpcode::FAULTING_LOAD_OP), DefReg)
380                  .addMBB(HandlerMBB)
381                  .addImm(LoadMI->getOpcode());
382 
383   for (auto &MO : LoadMI->uses())
384     MIB.addOperand(MO);
385 
386   MIB.setMemRefs(LoadMI->memoperands_begin(), LoadMI->memoperands_end());
387 
388   return MIB;
389 }
390 
391 /// Rewrite the null checks in NullCheckList into implicit null checks.
392 void ImplicitNullChecks::rewriteNullChecks(
393     ArrayRef<ImplicitNullChecks::NullCheck> NullCheckList) {
394   DebugLoc DL;
395 
396   for (auto &NC : NullCheckList) {
397     // Remove the conditional branch dependent on the null check.
398     unsigned BranchesRemoved = TII->RemoveBranch(*NC.getCheckBlock());
399     (void)BranchesRemoved;
400     assert(BranchesRemoved > 0 && "expected at least one branch!");
401 
402     // Insert a faulting load where the conditional branch was originally.  We
403     // check earlier ensures that this bit of code motion is legal.  We do not
404     // touch the successors list for any basic block since we haven't changed
405     // control flow, we've just made it implicit.
406     MachineInstr *FaultingLoad = insertFaultingLoad(
407         NC.getMemOperation(), NC.getCheckBlock(), NC.getNullSucc());
408     // Now the values defined by MemOperation, if any, are live-in of
409     // the block of MemOperation.
410     // The original load operation may define implicit-defs alongside
411     // the loaded value.
412     MachineBasicBlock *MBB = NC.getMemOperation()->getParent();
413     for (const MachineOperand &MO : FaultingLoad->operands()) {
414       if (!MO.isReg() || !MO.isDef())
415         continue;
416       unsigned Reg = MO.getReg();
417       if (!Reg || MBB->isLiveIn(Reg))
418         continue;
419       MBB->addLiveIn(Reg);
420     }
421     NC.getMemOperation()->eraseFromParent();
422     NC.getCheckOperation()->eraseFromParent();
423 
424     // Insert an *unconditional* branch to not-null successor.
425     TII->InsertBranch(*NC.getCheckBlock(), NC.getNotNullSucc(), nullptr,
426                       /*Cond=*/None, DL);
427 
428     NumImplicitNullChecks++;
429   }
430 }
431 
// Pass identity.  The address of ID identifies the pass; ImplicitNullChecksID
// re-exports it in the llvm namespace (the declaration lives outside this file,
// presumably in CodeGen/Passes.h -- verify) so other code can name the pass.
char ImplicitNullChecks::ID = 0;
char &llvm::ImplicitNullChecksID = ImplicitNullChecks::ID;
// Register the pass under the command-line name "implicit-null-checks".  The
// two trailing `false` arguments are presumably the CFG-only / is-analysis
// flags of the INITIALIZE_PASS macros -- TODO confirm against PassSupport.h.
INITIALIZE_PASS_BEGIN(ImplicitNullChecks, "implicit-null-checks",
                      "Implicit null checks", false, false)
INITIALIZE_PASS_END(ImplicitNullChecks, "implicit-null-checks",
                    "Implicit null checks", false, false)
438