1 //===- BoundsChecking.cpp - Instrumentation for run-time bounds checking --===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // This file implements a pass that instruments the code to perform run-time 11 // bounds checking on loads, stores, and other memory intrinsics. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #define DEBUG_TYPE "bounds-checking" 16 #include "llvm/IRBuilder.h" 17 #include "llvm/Intrinsics.h" 18 #include "llvm/Pass.h" 19 #include "llvm/ADT/Statistic.h" 20 #include "llvm/Analysis/MemoryBuiltins.h" 21 #include "llvm/Support/CommandLine.h" 22 #include "llvm/Support/Debug.h" 23 #include "llvm/Support/InstIterator.h" 24 #include "llvm/Support/TargetFolder.h" 25 #include "llvm/Support/raw_ostream.h" 26 #include "llvm/Target/TargetData.h" 27 #include "llvm/Transforms/Instrumentation.h" 28 using namespace llvm; 29 30 static cl::opt<bool> SingleTrapBB("bounds-checking-single-trap", 31 cl::desc("Use one trap block per function")); 32 33 STATISTIC(ChecksAdded, "Bounds checks added"); 34 STATISTIC(ChecksSkipped, "Bounds checks skipped"); 35 STATISTIC(ChecksUnable, "Bounds checks unable to add"); 36 37 typedef IRBuilder<true, TargetFolder> BuilderTy; 38 39 namespace { 40 struct BoundsChecking : public FunctionPass { 41 static char ID; 42 43 BoundsChecking(unsigned _Penalty = 5) : FunctionPass(ID), Penalty(_Penalty){ 44 initializeBoundsCheckingPass(*PassRegistry::getPassRegistry()); 45 } 46 47 virtual bool runOnFunction(Function &F); 48 49 virtual void getAnalysisUsage(AnalysisUsage &AU) const { 50 AU.addRequired<TargetData>(); 51 } 52 53 private: 54 const TargetData *TD; 55 ObjectSizeOffsetEvaluator *ObjSizeEval; 56 BuilderTy *Builder; 57 Instruction *Inst; 58 BasicBlock *TrapBB; 59 unsigned Penalty; 60 61 BasicBlock *getTrapBB(); 62 void emitBranchToTrap(Value *Cmp = 0); 63 bool computeAllocSize(Value *Ptr, APInt &Offset, Value* &OffsetValue, 64 APInt &Size, Value* &SizeValue); 65 bool instrument(Value *Ptr, Value *Val); 66 }; 67 } 68 69 char BoundsChecking::ID = 0; 70 INITIALIZE_PASS(BoundsChecking, "bounds-checking", "Run-time bounds checking", 71 false, false) 72 73 74 /// getTrapBB - create a basic block that traps. All overflowing conditions 75 /// branch to this block. There's only one trap block per function. 76 BasicBlock *BoundsChecking::getTrapBB() { 77 if (TrapBB && SingleTrapBB) 78 return TrapBB; 79 80 Function *Fn = Inst->getParent()->getParent(); 81 BasicBlock::iterator PrevInsertPoint = Builder->GetInsertPoint(); 82 TrapBB = BasicBlock::Create(Fn->getContext(), "trap", Fn); 83 Builder->SetInsertPoint(TrapBB); 84 85 llvm::Value *F = Intrinsic::getDeclaration(Fn->getParent(), Intrinsic::trap); 86 CallInst *TrapCall = Builder->CreateCall(F); 87 TrapCall->setDoesNotReturn(); 88 TrapCall->setDoesNotThrow(); 89 TrapCall->setDebugLoc(Inst->getDebugLoc()); 90 Builder->CreateUnreachable(); 91 92 Builder->SetInsertPoint(PrevInsertPoint); 93 return TrapBB; 94 } 95 96 97 /// emitBranchToTrap - emit a branch instruction to a trap block. 98 /// If Cmp is non-null, perform a jump only if its value evaluates to true. 99 void BoundsChecking::emitBranchToTrap(Value *Cmp) { 100 // check if the comparison is always false 101 ConstantInt *C = dyn_cast_or_null<ConstantInt>(Cmp); 102 if (C) { 103 ++ChecksSkipped; 104 if (!C->getZExtValue()) 105 return; 106 else 107 Cmp = 0; // unconditional branch 108 } 109 110 Instruction *Inst = Builder->GetInsertPoint(); 111 BasicBlock *OldBB = Inst->getParent(); 112 BasicBlock *Cont = OldBB->splitBasicBlock(Inst); 113 OldBB->getTerminator()->eraseFromParent(); 114 115 if (Cmp) 116 BranchInst::Create(getTrapBB(), Cont, Cmp, OldBB); 117 else 118 BranchInst::Create(getTrapBB(), OldBB); 119 } 120 121 122 /// instrument - adds run-time bounds checks to memory accessing instructions. 123 /// Ptr is the pointer that will be read/written, and InstVal is either the 124 /// result from the load or the value being stored. It is used to determine the 125 /// size of memory block that is touched. 126 /// Returns true if any change was made to the IR, false otherwise. 127 bool BoundsChecking::instrument(Value *Ptr, Value *InstVal) { 128 uint64_t NeededSize = TD->getTypeStoreSize(InstVal->getType()); 129 DEBUG(dbgs() << "Instrument " << *Ptr << " for " << Twine(NeededSize) 130 << " bytes\n"); 131 132 SizeOffsetEvalType SizeOffset = ObjSizeEval->compute(Ptr); 133 134 if (!ObjSizeEval->bothKnown(SizeOffset)) { 135 ++ChecksUnable; 136 return false; 137 } 138 139 Value *Size = SizeOffset.first; 140 Value *Offset = SizeOffset.second; 141 ConstantInt *SizeCI = dyn_cast<ConstantInt>(Size); 142 143 IntegerType *IntTy = TD->getIntPtrType(Inst->getContext()); 144 Value *NeededSizeVal = ConstantInt::get(IntTy, NeededSize); 145 146 // three checks are required to ensure safety: 147 // . Offset >= 0 (since the offset is given from the base ptr) 148 // . Size >= Offset (unsigned) 149 // . Size - Offset >= NeededSize (unsigned) 150 // 151 // optimization: if Size >= 0 (signed), skip 1st check 152 // FIXME: add NSW/NUW here? -- we dont care if the subtraction overflows 153 Value *ObjSize = Builder->CreateSub(Size, Offset); 154 Value *Cmp2 = Builder->CreateICmpULT(Size, Offset); 155 Value *Cmp3 = Builder->CreateICmpULT(ObjSize, NeededSizeVal); 156 Value *Or = Builder->CreateOr(Cmp2, Cmp3); 157 if (!SizeCI || SizeCI->getValue().slt(0)) { 158 Value *Cmp1 = Builder->CreateICmpSLT(Offset, ConstantInt::get(IntTy, 0)); 159 Or = Builder->CreateOr(Cmp1, Or); 160 } 161 emitBranchToTrap(Or); 162 163 ++ChecksAdded; 164 return true; 165 } 166 167 bool BoundsChecking::runOnFunction(Function &F) { 168 TD = &getAnalysis<TargetData>(); 169 170 TrapBB = 0; 171 BuilderTy TheBuilder(F.getContext(), TargetFolder(TD)); 172 Builder = &TheBuilder; 173 ObjectSizeOffsetEvaluator TheObjSizeEval(TD, F.getContext()); 174 ObjSizeEval = &TheObjSizeEval; 175 176 // check HANDLE_MEMORY_INST in include/llvm/Instruction.def for memory 177 // touching instructions 178 std::vector<Instruction*> WorkList; 179 for (inst_iterator i = inst_begin(F), e = inst_end(F); i != e; ++i) { 180 Instruction *I = &*i; 181 if (isa<LoadInst>(I) || isa<StoreInst>(I) || isa<AtomicCmpXchgInst>(I) || 182 isa<AtomicRMWInst>(I)) 183 WorkList.push_back(I); 184 } 185 186 bool MadeChange = false; 187 for (std::vector<Instruction*>::iterator i = WorkList.begin(), 188 e = WorkList.end(); i != e; ++i) { 189 Inst = *i; 190 191 Builder->SetInsertPoint(Inst); 192 if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) { 193 MadeChange |= instrument(LI->getPointerOperand(), LI); 194 } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { 195 MadeChange |= instrument(SI->getPointerOperand(), SI->getValueOperand()); 196 } else if (AtomicCmpXchgInst *AI = dyn_cast<AtomicCmpXchgInst>(Inst)) { 197 MadeChange |= instrument(AI->getPointerOperand(),AI->getCompareOperand()); 198 } else if (AtomicRMWInst *AI = dyn_cast<AtomicRMWInst>(Inst)) { 199 MadeChange |= instrument(AI->getPointerOperand(), AI->getValOperand()); 200 } else { 201 llvm_unreachable("unknown Instruction type"); 202 } 203 } 204 return MadeChange; 205 } 206 207 FunctionPass *llvm::createBoundsCheckingPass(unsigned Penalty) { 208 return new BoundsChecking(Penalty); 209 } 210