1 //===-- AMDGPUPromoteAlloca.cpp - Promote Allocas -------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass eliminates allocas by either converting them into vectors or
11 // by migrating them to local address space.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "AMDGPU.h"
16 #include "AMDGPUSubtarget.h"
17 #include "llvm/Analysis/ValueTracking.h"
18 #include "llvm/IR/IRBuilder.h"
19 #include "llvm/IR/InstVisitor.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 #define DEBUG_TYPE "amdgpu-promote-alloca"
24 
25 using namespace llvm;
26 
27 namespace {
28 
// Function pass that eliminates allocas, first by attempting to rewrite small
// arrays as SSA vector values (tryPromoteAllocaToVector), and otherwise by
// moving the storage into the local (LDS) address space (visitAlloca).
class AMDGPUPromoteAlloca : public FunctionPass,
                       public InstVisitor<AMDGPUPromoteAlloca> {

  static char ID;
  // Current module; cached in doInitialization() so visitAlloca() can query
  // the DataLayout and create module-level globals/intrinsic declarations.
  Module *Mod;
  // Subtarget, used to query the total local memory size.
  const AMDGPUSubtarget &ST;
  // Bytes of local memory still available for promotion in the current
  // function; recomputed at the start of runOnFunction().
  int LocalMemAvailable;

public:
  AMDGPUPromoteAlloca(const AMDGPUSubtarget &st) : FunctionPass(ID), ST(st),
                                                   LocalMemAvailable(0) { }
  bool doInitialization(Module &M) override;
  bool runOnFunction(Function &F) override;
  const char *getPassName() const override { return "AMDGPU Promote Alloca"; }
  // InstVisitor hook: attempts to promote a single alloca.
  void visitAlloca(AllocaInst &I);
};
45 
46 } // End anonymous namespace
47 
// Pass identification; the address of ID is used as the unique pass ID.
char AMDGPUPromoteAlloca::ID = 0;
49 
// Cache the module for later per-function processing.  Returns false: the
// module itself is not modified here.
bool AMDGPUPromoteAlloca::doInitialization(Module &M) {
  Mod = &M;
  return false;
}
54 
55 bool AMDGPUPromoteAlloca::runOnFunction(Function &F) {
56 
57   FunctionType *FTy = F.getFunctionType();
58 
59   LocalMemAvailable = ST.getLocalMemorySize();
60 
61 
62   // If the function has any arguments in the local address space, then it's
63   // possible these arguments require the entire local memory space, so
64   // we cannot use local memory in the pass.
65   for (unsigned i = 0, e = FTy->getNumParams(); i != e; ++i) {
66     Type *ParamTy = FTy->getParamType(i);
67     if (ParamTy->isPointerTy() &&
68         ParamTy->getPointerAddressSpace() == AMDGPUAS::LOCAL_ADDRESS) {
69       LocalMemAvailable = 0;
70       DEBUG(dbgs() << "Function has local memory argument.  Promoting to "
71                       "local memory disabled.\n");
72       break;
73     }
74   }
75 
76   if (LocalMemAvailable > 0) {
77     // Check how much local memory is being used by global objects
78     for (Module::global_iterator I = Mod->global_begin(),
79                                  E = Mod->global_end(); I != E; ++I) {
80       GlobalVariable *GV = &*I;
81       if (GV->getType()->getAddressSpace() != AMDGPUAS::LOCAL_ADDRESS)
82         continue;
83       for (Value::use_iterator U = GV->use_begin(),
84                                UE = GV->use_end(); U != UE; ++U) {
85         Instruction *Use = dyn_cast<Instruction>(*U);
86         if (!Use)
87           continue;
88         if (Use->getParent()->getParent() == &F)
89           LocalMemAvailable -=
90               Mod->getDataLayout().getTypeAllocSize(GV->getValueType());
91       }
92     }
93   }
94 
95   LocalMemAvailable = std::max(0, LocalMemAvailable);
96   DEBUG(dbgs() << LocalMemAvailable << "bytes free in local memory.\n");
97 
98   visit(F);
99 
100   return false;
101 }
102 
103 static VectorType *arrayTypeToVecType(Type *ArrayTy) {
104   return VectorType::get(ArrayTy->getArrayElementType(),
105                          ArrayTy->getArrayNumElements());
106 }
107 
108 static Value *
109 calculateVectorIndex(Value *Ptr,
110                      const std::map<GetElementPtrInst *, Value *> &GEPIdx) {
111   if (isa<AllocaInst>(Ptr))
112     return Constant::getNullValue(Type::getInt32Ty(Ptr->getContext()));
113 
114   GetElementPtrInst *GEP = cast<GetElementPtrInst>(Ptr);
115 
116   auto I = GEPIdx.find(GEP);
117   return I == GEPIdx.end() ? nullptr : I->second;
118 }
119 
120 static Value* GEPToVectorIndex(GetElementPtrInst *GEP) {
121   // FIXME we only support simple cases
122   if (GEP->getNumOperands() != 3)
123     return NULL;
124 
125   ConstantInt *I0 = dyn_cast<ConstantInt>(GEP->getOperand(1));
126   if (!I0 || !I0->isZero())
127     return NULL;
128 
129   return GEP->getOperand(2);
130 }
131 
132 // Not an instruction handled below to turn into a vector.
133 //
134 // TODO: Check isTriviallyVectorizable for calls and handle other
135 // instructions.
136 static bool canVectorizeInst(Instruction *Inst, User *User) {
137   switch (Inst->getOpcode()) {
138   case Instruction::Load:
139   case Instruction::BitCast:
140   case Instruction::AddrSpaceCast:
141     return true;
142   case Instruction::Store: {
143     // Must be the stored pointer operand, not a stored value.
144     StoreInst *SI = cast<StoreInst>(Inst);
145     return SI->getPointerOperand() == User;
146   }
147   default:
148     return false;
149   }
150 }
151 
152 static bool tryPromoteAllocaToVector(AllocaInst *Alloca) {
153   Type *AllocaTy = Alloca->getAllocatedType();
154 
155   DEBUG(dbgs() << "Alloca Candidate for vectorization \n");
156 
157   // FIXME: There is no reason why we can't support larger arrays, we
158   // are just being conservative for now.
159   if (!AllocaTy->isArrayTy() ||
160       AllocaTy->getArrayElementType()->isVectorTy() ||
161       AllocaTy->getArrayNumElements() > 4) {
162 
163     DEBUG(dbgs() << "  Cannot convert type to vector");
164     return false;
165   }
166 
167   std::map<GetElementPtrInst*, Value*> GEPVectorIdx;
168   std::vector<Value*> WorkList;
169   for (User *AllocaUser : Alloca->users()) {
170     GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(AllocaUser);
171     if (!GEP) {
172       if (!canVectorizeInst(cast<Instruction>(AllocaUser), Alloca))
173         return false;
174 
175       WorkList.push_back(AllocaUser);
176       continue;
177     }
178 
179     Value *Index = GEPToVectorIndex(GEP);
180 
181     // If we can't compute a vector index from this GEP, then we can't
182     // promote this alloca to vector.
183     if (!Index) {
184       DEBUG(dbgs() << "  Cannot compute vector index for GEP " << *GEP << '\n');
185       return false;
186     }
187 
188     GEPVectorIdx[GEP] = Index;
189     for (User *GEPUser : AllocaUser->users()) {
190       if (!canVectorizeInst(cast<Instruction>(GEPUser), AllocaUser))
191         return false;
192 
193       WorkList.push_back(GEPUser);
194     }
195   }
196 
197   VectorType *VectorTy = arrayTypeToVecType(AllocaTy);
198 
199   DEBUG(dbgs() << "  Converting alloca to vector "
200         << *AllocaTy << " -> " << *VectorTy << '\n');
201 
202   for (std::vector<Value*>::iterator I = WorkList.begin(),
203                                      E = WorkList.end(); I != E; ++I) {
204     Instruction *Inst = cast<Instruction>(*I);
205     IRBuilder<> Builder(Inst);
206     switch (Inst->getOpcode()) {
207     case Instruction::Load: {
208       Value *Ptr = Inst->getOperand(0);
209       Value *Index = calculateVectorIndex(Ptr, GEPVectorIdx);
210       Value *BitCast = Builder.CreateBitCast(Alloca, VectorTy->getPointerTo(0));
211       Value *VecValue = Builder.CreateLoad(BitCast);
212       Value *ExtractElement = Builder.CreateExtractElement(VecValue, Index);
213       Inst->replaceAllUsesWith(ExtractElement);
214       Inst->eraseFromParent();
215       break;
216     }
217     case Instruction::Store: {
218       Value *Ptr = Inst->getOperand(1);
219       Value *Index = calculateVectorIndex(Ptr, GEPVectorIdx);
220       Value *BitCast = Builder.CreateBitCast(Alloca, VectorTy->getPointerTo(0));
221       Value *VecValue = Builder.CreateLoad(BitCast);
222       Value *NewVecValue = Builder.CreateInsertElement(VecValue,
223                                                        Inst->getOperand(0),
224                                                        Index);
225       Builder.CreateStore(NewVecValue, BitCast);
226       Inst->eraseFromParent();
227       break;
228     }
229     case Instruction::BitCast:
230     case Instruction::AddrSpaceCast:
231       break;
232 
233     default:
234       Inst->dump();
235       llvm_unreachable("Inconsistency in instructions promotable to vector");
236     }
237   }
238   return true;
239 }
240 
241 static bool collectUsesWithPtrTypes(Value *Val, std::vector<Value*> &WorkList) {
242   bool Success = true;
243   for (User *User : Val->users()) {
244     if(std::find(WorkList.begin(), WorkList.end(), User) != WorkList.end())
245       continue;
246     if (CallInst *CI = dyn_cast<CallInst>(User)) {
247       // TODO: We might be able to handle some cases where the callee is a
248       // constantexpr bitcast of a function.
249       if (!CI->getCalledFunction())
250         return false;
251 
252       WorkList.push_back(User);
253       continue;
254     }
255 
256     // FIXME: Correctly handle ptrtoint instructions.
257     Instruction *UseInst = dyn_cast<Instruction>(User);
258     if (UseInst && UseInst->getOpcode() == Instruction::PtrToInt)
259       return false;
260 
261     if (StoreInst *SI = dyn_cast_or_null<StoreInst>(UseInst)) {
262       // Reject if the stored value is not the pointer operand.
263       if (SI->getPointerOperand() != Val)
264         return false;
265     }
266 
267     if (!User->getType()->isPointerTy())
268       continue;
269 
270     WorkList.push_back(User);
271 
272     Success &= collectUsesWithPtrTypes(User, WorkList);
273   }
274   return Success;
275 }
276 
277 void AMDGPUPromoteAlloca::visitAlloca(AllocaInst &I) {
278   if (!I.isStaticAlloca())
279     return;
280 
281   IRBuilder<> Builder(&I);
282 
283   // First try to replace the alloca with a vector
284   Type *AllocaTy = I.getAllocatedType();
285 
286   DEBUG(dbgs() << "Trying to promote " << I << '\n');
287 
288   if (tryPromoteAllocaToVector(&I))
289     return;
290 
291   DEBUG(dbgs() << " alloca is not a candidate for vectorization.\n");
292 
293   // FIXME: This is the maximum work group size.  We should try to get
294   // value from the reqd_work_group_size function attribute if it is
295   // available.
296   unsigned WorkGroupSize = 256;
297   int AllocaSize =
298       WorkGroupSize * Mod->getDataLayout().getTypeAllocSize(AllocaTy);
299 
300   if (AllocaSize > LocalMemAvailable) {
301     DEBUG(dbgs() << " Not enough local memory to promote alloca.\n");
302     return;
303   }
304 
305   std::vector<Value*> WorkList;
306 
307   if (!collectUsesWithPtrTypes(&I, WorkList)) {
308     DEBUG(dbgs() << " Do not know how to convert all uses\n");
309     return;
310   }
311 
312   DEBUG(dbgs() << "Promoting alloca to local memory\n");
313   LocalMemAvailable -= AllocaSize;
314 
315   Type *GVTy = ArrayType::get(I.getAllocatedType(), 256);
316   GlobalVariable *GV = new GlobalVariable(
317       *Mod, GVTy, false, GlobalValue::ExternalLinkage, 0, I.getName(), 0,
318       GlobalVariable::NotThreadLocal, AMDGPUAS::LOCAL_ADDRESS);
319 
320   FunctionType *FTy = FunctionType::get(
321       Type::getInt32Ty(Mod->getContext()), false);
322   AttributeSet AttrSet;
323   AttrSet.addAttribute(Mod->getContext(), 0, Attribute::ReadNone);
324 
325   Value *ReadLocalSizeY = Mod->getOrInsertFunction(
326       "llvm.r600.read.local.size.y", FTy, AttrSet);
327   Value *ReadLocalSizeZ = Mod->getOrInsertFunction(
328       "llvm.r600.read.local.size.z", FTy, AttrSet);
329   Value *ReadTIDIGX = Mod->getOrInsertFunction(
330       "llvm.r600.read.tidig.x", FTy, AttrSet);
331   Value *ReadTIDIGY = Mod->getOrInsertFunction(
332       "llvm.r600.read.tidig.y", FTy, AttrSet);
333   Value *ReadTIDIGZ = Mod->getOrInsertFunction(
334       "llvm.r600.read.tidig.z", FTy, AttrSet);
335 
336   Value *TCntY = Builder.CreateCall(ReadLocalSizeY, {});
337   Value *TCntZ = Builder.CreateCall(ReadLocalSizeZ, {});
338   Value *TIdX = Builder.CreateCall(ReadTIDIGX, {});
339   Value *TIdY = Builder.CreateCall(ReadTIDIGY, {});
340   Value *TIdZ = Builder.CreateCall(ReadTIDIGZ, {});
341 
342   Value *Tmp0 = Builder.CreateMul(TCntY, TCntZ);
343   Tmp0 = Builder.CreateMul(Tmp0, TIdX);
344   Value *Tmp1 = Builder.CreateMul(TIdY, TCntZ);
345   Value *TID = Builder.CreateAdd(Tmp0, Tmp1);
346   TID = Builder.CreateAdd(TID, TIdZ);
347 
348   std::vector<Value*> Indices;
349   Indices.push_back(Constant::getNullValue(Type::getInt32Ty(Mod->getContext())));
350   Indices.push_back(TID);
351 
352   Value *Offset = Builder.CreateGEP(GVTy, GV, Indices);
353   I.mutateType(Offset->getType());
354   I.replaceAllUsesWith(Offset);
355   I.eraseFromParent();
356 
357   for (std::vector<Value*>::iterator i = WorkList.begin(),
358                                      e = WorkList.end(); i != e; ++i) {
359     Value *V = *i;
360     CallInst *Call = dyn_cast<CallInst>(V);
361     if (!Call) {
362       Type *EltTy = V->getType()->getPointerElementType();
363       PointerType *NewTy = PointerType::get(EltTy, AMDGPUAS::LOCAL_ADDRESS);
364 
365       // The operand's value should be corrected on its own.
366       if (isa<AddrSpaceCastInst>(V))
367         continue;
368 
369       // FIXME: It doesn't really make sense to try to do this for all
370       // instructions.
371       V->mutateType(NewTy);
372       continue;
373     }
374 
375     IntrinsicInst *Intr = dyn_cast<IntrinsicInst>(Call);
376     if (!Intr) {
377       std::vector<Type*> ArgTypes;
378       for (unsigned ArgIdx = 0, ArgEnd = Call->getNumArgOperands();
379                                 ArgIdx != ArgEnd; ++ArgIdx) {
380         ArgTypes.push_back(Call->getArgOperand(ArgIdx)->getType());
381       }
382       Function *F = Call->getCalledFunction();
383       FunctionType *NewType = FunctionType::get(Call->getType(), ArgTypes,
384                                                 F->isVarArg());
385       Constant *C = Mod->getOrInsertFunction((F->getName() + ".local").str(),
386                                              NewType, F->getAttributes());
387       Function *NewF = cast<Function>(C);
388       Call->setCalledFunction(NewF);
389       continue;
390     }
391 
392     Builder.SetInsertPoint(Intr);
393     switch (Intr->getIntrinsicID()) {
394     case Intrinsic::lifetime_start:
395     case Intrinsic::lifetime_end:
396       // These intrinsics are for address space 0 only
397       Intr->eraseFromParent();
398       continue;
399     case Intrinsic::memcpy: {
400       MemCpyInst *MemCpy = cast<MemCpyInst>(Intr);
401       Builder.CreateMemCpy(MemCpy->getRawDest(), MemCpy->getRawSource(),
402                            MemCpy->getLength(), MemCpy->getAlignment(),
403                            MemCpy->isVolatile());
404       Intr->eraseFromParent();
405       continue;
406     }
407     case Intrinsic::memset: {
408       MemSetInst *MemSet = cast<MemSetInst>(Intr);
409       Builder.CreateMemSet(MemSet->getRawDest(), MemSet->getValue(),
410                            MemSet->getLength(), MemSet->getAlignment(),
411                            MemSet->isVolatile());
412       Intr->eraseFromParent();
413       continue;
414     }
415     case Intrinsic::invariant_start:
416     case Intrinsic::invariant_end:
417     case Intrinsic::invariant_group_barrier:
418       Intr->eraseFromParent();
419       // FIXME: I think the invariant marker should still theoretically apply,
420       // but the intrinsics need to be changed to accept pointers with any
421       // address space.
422       continue;
423     default:
424       Intr->dump();
425       llvm_unreachable("Don't know how to promote alloca intrinsic use.");
426     }
427   }
428 }
429 
// Factory used by the AMDGPU target machine to add this pass to the pipeline.
FunctionPass *llvm::createAMDGPUPromoteAlloca(const AMDGPUSubtarget &ST) {
  return new AMDGPUPromoteAlloca(ST);
}
433