1 //===- LowerMemIntrinsics.cpp ----------------------------------*- C++ -*--===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 
10 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
11 #include "llvm/IR/IRBuilder.h"
12 #include "llvm/IR/IntrinsicInst.h"
13 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
14 
15 using namespace llvm;
16 
17 void llvm::createMemCpyLoop(Instruction *InsertBefore,
18                             Value *SrcAddr, Value *DstAddr, Value *CopyLen,
19                             unsigned SrcAlign, unsigned DestAlign,
20                             bool SrcIsVolatile, bool DstIsVolatile) {
21   Type *TypeOfCopyLen = CopyLen->getType();
22 
23   BasicBlock *OrigBB = InsertBefore->getParent();
24   Function *F = OrigBB->getParent();
25   BasicBlock *NewBB =
26     InsertBefore->getParent()->splitBasicBlock(InsertBefore, "split");
27   BasicBlock *LoopBB = BasicBlock::Create(F->getContext(), "loadstoreloop",
28                                           F, NewBB);
29 
30   OrigBB->getTerminator()->setSuccessor(0, LoopBB);
31   IRBuilder<> Builder(OrigBB->getTerminator());
32 
33   // SrcAddr and DstAddr are expected to be pointer types,
34   // so no check is made here.
35   unsigned SrcAS = cast<PointerType>(SrcAddr->getType())->getAddressSpace();
36   unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace();
37 
38   // Cast pointers to (char *)
39   SrcAddr = Builder.CreateBitCast(SrcAddr, Builder.getInt8PtrTy(SrcAS));
40   DstAddr = Builder.CreateBitCast(DstAddr, Builder.getInt8PtrTy(DstAS));
41 
42   IRBuilder<> LoopBuilder(LoopBB);
43   PHINode *LoopIndex = LoopBuilder.CreatePHI(TypeOfCopyLen, 0);
44   LoopIndex->addIncoming(ConstantInt::get(TypeOfCopyLen, 0), OrigBB);
45 
46   // load from SrcAddr+LoopIndex
47   // TODO: we can leverage the align parameter of llvm.memcpy for more efficient
48   // word-sized loads and stores.
49   Value *Element =
50     LoopBuilder.CreateLoad(LoopBuilder.CreateInBoundsGEP(
51                              LoopBuilder.getInt8Ty(), SrcAddr, LoopIndex),
52                            SrcIsVolatile);
53   // store at DstAddr+LoopIndex
54   LoopBuilder.CreateStore(Element,
55                           LoopBuilder.CreateInBoundsGEP(LoopBuilder.getInt8Ty(),
56                                                         DstAddr, LoopIndex),
57                           DstIsVolatile);
58 
59   // The value for LoopIndex coming from backedge is (LoopIndex + 1)
60   Value *NewIndex =
61     LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(TypeOfCopyLen, 1));
62   LoopIndex->addIncoming(NewIndex, LoopBB);
63 
64   LoopBuilder.CreateCondBr(LoopBuilder.CreateICmpULT(NewIndex, CopyLen), LoopBB,
65                            NewBB);
66 }
67 
68 // Lower memmove to IR. memmove is required to correctly copy overlapping memory
69 // regions; therefore, it has to check the relative positions of the source and
70 // destination pointers and choose the copy direction accordingly.
71 //
72 // The code below is an IR rendition of this C function:
73 //
74 // void* memmove(void* dst, const void* src, size_t n) {
75 //   unsigned char* d = dst;
76 //   const unsigned char* s = src;
77 //   if (s < d) {
78 //     // copy backwards
79 //     while (n--) {
80 //       d[n] = s[n];
81 //     }
82 //   } else {
83 //     // copy forward
84 //     for (size_t i = 0; i < n; ++i) {
85 //       d[i] = s[i];
86 //     }
87 //   }
88 //   return dst;
89 // }
90 static void createMemMoveLoop(Instruction *InsertBefore,
91                               Value *SrcAddr, Value *DstAddr, Value *CopyLen,
92                               unsigned SrcAlign, unsigned DestAlign,
93                               bool SrcIsVolatile, bool DstIsVolatile) {
94   Type *TypeOfCopyLen = CopyLen->getType();
95   BasicBlock *OrigBB = InsertBefore->getParent();
96   Function *F = OrigBB->getParent();
97 
98   // Create the a comparison of src and dst, based on which we jump to either
99   // the forward-copy part of the function (if src >= dst) or the backwards-copy
100   // part (if src < dst).
101   // SplitBlockAndInsertIfThenElse conveniently creates the basic if-then-else
102   // structure. Its block terminators (unconditional branches) are replaced by
103   // the appropriate conditional branches when the loop is built.
104   ICmpInst *PtrCompare = new ICmpInst(InsertBefore, ICmpInst::ICMP_ULT,
105                                       SrcAddr, DstAddr, "compare_src_dst");
106   TerminatorInst *ThenTerm, *ElseTerm;
107   SplitBlockAndInsertIfThenElse(PtrCompare, InsertBefore, &ThenTerm,
108                                 &ElseTerm);
109 
110   // Each part of the function consists of two blocks:
111   //   copy_backwards:        used to skip the loop when n == 0
112   //   copy_backwards_loop:   the actual backwards loop BB
113   //   copy_forward:          used to skip the loop when n == 0
114   //   copy_forward_loop:     the actual forward loop BB
115   BasicBlock *CopyBackwardsBB = ThenTerm->getParent();
116   CopyBackwardsBB->setName("copy_backwards");
117   BasicBlock *CopyForwardBB = ElseTerm->getParent();
118   CopyForwardBB->setName("copy_forward");
119   BasicBlock *ExitBB = InsertBefore->getParent();
120   ExitBB->setName("memmove_done");
121 
122   // Initial comparison of n == 0 that lets us skip the loops altogether. Shared
123   // between both backwards and forward copy clauses.
124   ICmpInst *CompareN =
125       new ICmpInst(OrigBB->getTerminator(), ICmpInst::ICMP_EQ, CopyLen,
126                    ConstantInt::get(TypeOfCopyLen, 0), "compare_n_to_0");
127 
128   // Copying backwards.
129   BasicBlock *LoopBB =
130     BasicBlock::Create(F->getContext(), "copy_backwards_loop", F, CopyForwardBB);
131   IRBuilder<> LoopBuilder(LoopBB);
132   PHINode *LoopPhi = LoopBuilder.CreatePHI(TypeOfCopyLen, 0);
133   Value *IndexPtr = LoopBuilder.CreateSub(
134       LoopPhi, ConstantInt::get(TypeOfCopyLen, 1), "index_ptr");
135   Value *Element = LoopBuilder.CreateLoad(
136       LoopBuilder.CreateInBoundsGEP(SrcAddr, IndexPtr), "element");
137   LoopBuilder.CreateStore(Element,
138                           LoopBuilder.CreateInBoundsGEP(DstAddr, IndexPtr));
139   LoopBuilder.CreateCondBr(
140       LoopBuilder.CreateICmpEQ(IndexPtr, ConstantInt::get(TypeOfCopyLen, 0)),
141       ExitBB, LoopBB);
142   LoopPhi->addIncoming(IndexPtr, LoopBB);
143   LoopPhi->addIncoming(CopyLen, CopyBackwardsBB);
144   BranchInst::Create(ExitBB, LoopBB, CompareN, ThenTerm);
145   ThenTerm->eraseFromParent();
146 
147   // Copying forward.
148   BasicBlock *FwdLoopBB =
149     BasicBlock::Create(F->getContext(), "copy_forward_loop", F, ExitBB);
150   IRBuilder<> FwdLoopBuilder(FwdLoopBB);
151   PHINode *FwdCopyPhi = FwdLoopBuilder.CreatePHI(TypeOfCopyLen, 0, "index_ptr");
152   Value *FwdElement = FwdLoopBuilder.CreateLoad(
153       FwdLoopBuilder.CreateInBoundsGEP(SrcAddr, FwdCopyPhi), "element");
154   FwdLoopBuilder.CreateStore(
155       FwdElement, FwdLoopBuilder.CreateInBoundsGEP(DstAddr, FwdCopyPhi));
156   Value *FwdIndexPtr = FwdLoopBuilder.CreateAdd(
157       FwdCopyPhi, ConstantInt::get(TypeOfCopyLen, 1), "index_increment");
158   FwdLoopBuilder.CreateCondBr(FwdLoopBuilder.CreateICmpEQ(FwdIndexPtr, CopyLen),
159                               ExitBB, FwdLoopBB);
160   FwdCopyPhi->addIncoming(FwdIndexPtr, FwdLoopBB);
161   FwdCopyPhi->addIncoming(ConstantInt::get(TypeOfCopyLen, 0), CopyForwardBB);
162 
163   BranchInst::Create(ExitBB, FwdLoopBB, CompareN, ElseTerm);
164   ElseTerm->eraseFromParent();
165 }
166 
167 static void createMemSetLoop(Instruction *InsertBefore,
168                              Value *DstAddr, Value *CopyLen, Value *SetValue,
169                              unsigned Align, bool IsVolatile) {
170   BasicBlock *OrigBB = InsertBefore->getParent();
171   Function *F = OrigBB->getParent();
172   BasicBlock *NewBB =
173       OrigBB->splitBasicBlock(InsertBefore, "split");
174   BasicBlock *LoopBB
175     = BasicBlock::Create(F->getContext(), "loadstoreloop", F, NewBB);
176 
177   OrigBB->getTerminator()->setSuccessor(0, LoopBB);
178   IRBuilder<> Builder(OrigBB->getTerminator());
179 
180   // Cast pointer to the type of value getting stored
181   unsigned dstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace();
182   DstAddr = Builder.CreateBitCast(DstAddr,
183                                   PointerType::get(SetValue->getType(), dstAS));
184 
185   IRBuilder<> LoopBuilder(LoopBB);
186   PHINode *LoopIndex = LoopBuilder.CreatePHI(CopyLen->getType(), 0);
187   LoopIndex->addIncoming(ConstantInt::get(CopyLen->getType(), 0), OrigBB);
188 
189   LoopBuilder.CreateStore(
190       SetValue,
191       LoopBuilder.CreateInBoundsGEP(SetValue->getType(), DstAddr, LoopIndex),
192       IsVolatile);
193 
194   Value *NewIndex =
195       LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(CopyLen->getType(), 1));
196   LoopIndex->addIncoming(NewIndex, LoopBB);
197 
198   LoopBuilder.CreateCondBr(LoopBuilder.CreateICmpULT(NewIndex, CopyLen), LoopBB,
199                            NewBB);
200 }
201 
202 void llvm::expandMemCpyAsLoop(MemCpyInst *Memcpy) {
203   createMemCpyLoop(/* InsertBefore */ Memcpy,
204                    /* SrcAddr */ Memcpy->getRawSource(),
205                    /* DstAddr */ Memcpy->getRawDest(),
206                    /* CopyLen */ Memcpy->getLength(),
207                    /* SrcAlign */ Memcpy->getAlignment(),
208                    /* DestAlign */ Memcpy->getAlignment(),
209                    /* SrcIsVolatile */ Memcpy->isVolatile(),
210                    /* DstIsVolatile */ Memcpy->isVolatile());
211 }
212 
213 void llvm::expandMemMoveAsLoop(MemMoveInst *Memmove) {
214   createMemMoveLoop(/* InsertBefore */ Memmove,
215                     /* SrcAddr */ Memmove->getRawSource(),
216                     /* DstAddr */ Memmove->getRawDest(),
217                     /* CopyLen */ Memmove->getLength(),
218                     /* SrcAlign */ Memmove->getAlignment(),
219                     /* DestAlign */ Memmove->getAlignment(),
220                     /* SrcIsVolatile */ Memmove->isVolatile(),
221                     /* DstIsVolatile */ Memmove->isVolatile());
222 }
223 
224 void llvm::expandMemSetAsLoop(MemSetInst *Memset) {
225   createMemSetLoop(/* InsertBefore */ Memset,
226                    /* DstAddr */ Memset->getRawDest(),
227                    /* CopyLen */ Memset->getLength(),
228                    /* SetValue */ Memset->getValue(),
229                    /* Alignment */ Memset->getAlignment(),
230                    Memset->isVolatile());
231 }
232