//===-- AMDGPULowerKernelArguments.cpp -------------------------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
/// \file This pass replaces accesses to kernel arguments with loads from
/// offsets from the kernarg base pointer.
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPUSubtarget.h"
#include "AMDGPUTargetMachine.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/DivergenceAnalysis.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"

#define DEBUG_TYPE "amdgpu-lower-kernel-arguments"

using namespace llvm;

namespace {

class AMDGPULowerKernelArguments : public FunctionPass {
public:
  static char ID;

  AMDGPULowerKernelArguments() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
    AU.setPreservesAll();
  }
};

} // end anonymous namespace

bool AMDGPULowerKernelArguments::runOnFunction(Function &F) {
  CallingConv::ID CC = F.getCallingConv();
  if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty())
    return false;

  auto &TPC = getAnalysis<TargetPassConfig>();

  const TargetMachine &TM = TPC.getTM<TargetMachine>();
  const SISubtarget &ST = TM.getSubtarget<SISubtarget>(F);
  LLVMContext &Ctx = F.getParent()->getContext();
  const DataLayout &DL = F.getParent()->getDataLayout();
  BasicBlock &EntryBlock = *F.begin();
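  // All new instructions are inserted at the top of the entry block, ahead of
  // any existing uses of the arguments.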
  IRBuilder<> Builder(&*EntryBlock.begin());

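  // First pass: collect the in-memory type of each argument so we can build a
  // struct type that mirrors the layout of the kernarg segment.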
  SmallVector<Type *, 16> ArgTypes;
  for (Argument &Arg : F.args()) {
    Type *ArgTy = Arg.getType();
    unsigned Size = DL.getTypeStoreSizeInBits(ArgTy);
    bool IsExtArg = Size < 32 && (Arg.hasZExtAttr() || Arg.hasSExtAttr()) &&
                    !ST.isAmdHsaOS();

    // Clover seems to always pad i8/i16 ext args to i32, but doesn't properly
    // align them. Model such arguments as arrays so the struct element gets
    // the padded 4-byte size while keeping the original alignment.
    ArgTypes.push_back(IsExtArg ? ArrayType::get(ArgTy, 32 / Size) : ArgTy);
  }

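  // Laying the arguments out as a struct lets StructLayout compute each
  // argument's byte offset within the kernarg segment.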
  StructType *ArgStructTy = StructType::create(Ctx, ArgTypes, F.getName());
  const StructLayout *Layout = DL.getStructLayout(ArgStructTy);

  // Minimum alignment for the kernarg segment is 16.
  unsigned KernArgBaseAlign =
      std::max(16u, DL.getABITypeAlignment(ArgStructTy));
  const uint64_t BaseOffset = ST.getExplicitKernelArgOffset(F);

  // FIXME: Alignment is broken with explicit arg offset.
  const uint64_t TotalKernArgSize = BaseOffset +
    ST.getKernArgSegmentSize(F, DL.getTypeAllocSize(ArgStructTy));
  if (TotalKernArgSize == 0)
    return false;

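  // llvm.amdgcn.kernarg.segment.ptr gives the base address that every
  // argument load below is offset from.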
  CallInst *KernArgSegment =
    Builder.CreateIntrinsic(Intrinsic::amdgcn_kernarg_segment_ptr, nullptr,
                            F.getName() + ".kernarg.segment");

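  // Mark the segment pointer as nonnull, dereferenceable out to the full
  // argument area, and aligned, so later passes can reason about the loads.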
  KernArgSegment->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
  KernArgSegment->addAttribute(AttributeList::ReturnIndex,
    Attribute::getWithDereferenceableBytes(Ctx, TotalKernArgSize));
  KernArgSegment->addAttribute(AttributeList::ReturnIndex,
    Attribute::getWithAlignment(Ctx, KernArgBaseAlign));

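  // If the explicit arguments don't start at the segment base, step past the
  // leading area and account for the reduced alignment of the offset pointer.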
  Value *KernArgBase = KernArgSegment;
  if (BaseOffset != 0) {
    KernArgBase = Builder.CreateConstInBoundsGEP1_64(KernArgBase, BaseOffset);
    KernArgBaseAlign = MinAlign(KernArgBaseAlign, BaseOffset);
  }

  unsigned AS = KernArgSegment->getType()->getPointerAddressSpace();
  Value *CastStruct = Builder.CreateBitCast(KernArgBase,
                                            ArgStructTy->getPointerTo(AS));
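  // Second pass: replace every used argument with a load from its slot in the
  // struct.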
  for (Argument &Arg : F.args()) {
    if (Arg.use_empty())
      continue;

    Type *ArgTy = Arg.getType();
    if (PointerType *PT = dyn_cast<PointerType>(ArgTy)) {
      // FIXME: Hack. We rely on AssertZext to be able to fold DS addressing
      // modes on SI to know the high bits are 0 so pointer adds don't wrap. We
      // can't represent this with range metadata because it's only allowed for
      // integer types.
      if (PT->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS &&
          ST.getGeneration() == AMDGPUSubtarget::SOUTHERN_ISLANDS)
        continue;

      // FIXME: We can replace this with equivalent alias.scope/noalias
      // metadata, but this appears to be a lot of work.
      if (Arg.hasNoAliasAttr())
        continue;
    }

    VectorType *VT = dyn_cast<VectorType>(ArgTy);
    bool IsV3 = VT && VT->getNumElements() == 3;
    VectorType *V4Ty = nullptr;

    unsigned Size = DL.getTypeSizeInBits(ArgTy);
    bool IsExtArg = Size < 32 && (Arg.hasZExtAttr() || Arg.hasSExtAttr()) &&
                    !ST.isAmdHsaOS();
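    // Find the dword-aligned address at or below the argument's offset;
    // OffsetDiff is how many bytes past that address the argument starts.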
    int64_t EltOffset = Layout->getElementOffset(Arg.getArgNo());
    int64_t AlignDownOffset = alignDown(EltOffset, 4);
    int64_t OffsetDiff = EltOffset - AlignDownOffset;
    unsigned AdjustedAlign = MinAlign(KernArgBaseAlign, AlignDownOffset);

    Value *ArgPtr;
    // FIXME: Handle aggregate types.
    if (Size < 32 && !ArgTy->isAggregateType()) {
      // Since we don't have sub-dword scalar loads, avoid doing an extload by
      // loading earlier than the argument address, and extracting the relevant
      // bits.
      //
      // Additionally widen any sub-dword load to i32 even if suitably aligned,
      // so that CSE between different argument loads works easily.
      ArgPtr = Builder.CreateConstGEP1_64(KernArgBase, AlignDownOffset);
      ArgPtr = Builder.CreateBitCast(
        ArgPtr,
        Builder.getInt32Ty()->getPointerTo(AS),
        Arg.getName() + ".kernarg.offset.align.down");
    } else {
      ArgPtr = Builder.CreateStructGEP(CastStruct, Arg.getArgNo(),
                                       Arg.getName() + ".kernarg.offset");
    }

    assert((!IsExtArg || !IsV3) && "incompatible situation");

    if (IsV3 && Size >= 32) {
      V4Ty = VectorType::get(VT->getVectorElementType(), 4);
      // Use the hack that clang uses to avoid SelectionDAG ruining v3 loads.
      ArgPtr = Builder.CreateBitCast(ArgPtr, V4Ty->getPointerTo(AS));
    }
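    // Kernel arguments are never written during kernel execution, so the
    // loads can be marked invariant.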
    LoadInst *Load = Builder.CreateAlignedLoad(ArgPtr, AdjustedAlign);
    Load->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(Ctx, {}));

    MDBuilder MDB(Ctx);

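    // Transfer pointer-argument attributes to equivalent load metadata so the
    // information isn't lost when the argument is replaced.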
    if (isa<PointerType>(ArgTy)) {
      if (Arg.hasNonNullAttr())
        Load->setMetadata(LLVMContext::MD_nonnull, MDNode::get(Ctx, {}));

      uint64_t DerefBytes = Arg.getDereferenceableBytes();
      if (DerefBytes != 0) {
        Load->setMetadata(
          LLVMContext::MD_dereferenceable,
          MDNode::get(Ctx,
                      MDB.createConstant(
                        ConstantInt::get(Builder.getInt64Ty(), DerefBytes))));
      }

      uint64_t DerefOrNullBytes = Arg.getDereferenceableOrNullBytes();
      if (DerefOrNullBytes != 0) {
        Load->setMetadata(
          LLVMContext::MD_dereferenceable_or_null,
          MDNode::get(Ctx,
                      MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
                                                          DerefOrNullBytes))));
      }

      unsigned ParamAlign = Arg.getParamAlignment();
      if (ParamAlign != 0) {
        Load->setMetadata(
          LLVMContext::MD_align,
          MDNode::get(Ctx,
                      MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
                                                          ParamAlign))));
      }
    }

    // TODO: Convert noalias arg to !noalias

    if (Size < 32 && !ArgTy->isAggregateType()) {
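      // An ext arg occupies only the low bits of its dword, so the loaded
      // value is known to lie in the extended range; record that as !range.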
      if (IsExtArg && OffsetDiff == 0) {
        Type *I32Ty = Builder.getInt32Ty();
        bool IsSext = Arg.hasSExtAttr();
        Metadata *LowAndHigh[] = {
          ConstantAsMetadata::get(
            ConstantInt::get(I32Ty, IsSext ? minIntN(Size) : 0)),
          ConstantAsMetadata::get(
            ConstantInt::get(I32Ty,
                             IsSext ? maxIntN(Size) + 1 : maxUIntN(Size) + 1))
        };

        Load->setMetadata(LLVMContext::MD_range, MDNode::get(Ctx, LowAndHigh));
      }
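      // Recover the argument from the widened i32 load: shift out any low
      // bytes belonging to earlier arguments, then truncate and bitcast back
      // to the original type.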
      Value *ExtractBits = OffsetDiff == 0 ?
        Load : Builder.CreateLShr(Load, OffsetDiff * 8);

      IntegerType *ArgIntTy = Builder.getIntNTy(Size);
      Value *Trunc = Builder.CreateTrunc(ExtractBits, ArgIntTy);
      Value *NewVal = Builder.CreateBitCast(Trunc, ArgTy,
                                            Arg.getName() + ".load");
      Arg.replaceAllUsesWith(NewVal);
    } else if (IsV3) {
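      // The load was widened to v4; shuffle back down to the original three
      // elements.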
      Value *Shuf = Builder.CreateShuffleVector(Load, UndefValue::get(V4Ty),
                                                {0, 1, 2},
                                                Arg.getName() + ".load");
      Arg.replaceAllUsesWith(Shuf);
    } else {
      Load->setName(Arg.getName() + ".load");
      Arg.replaceAllUsesWith(Load);
    }
  }

  return true;
}

INITIALIZE_PASS_BEGIN(AMDGPULowerKernelArguments, DEBUG_TYPE,
                      "AMDGPU Lower Kernel Arguments", false, false)
INITIALIZE_PASS_END(AMDGPULowerKernelArguments, DEBUG_TYPE,
                    "AMDGPU Lower Kernel Arguments", false, false)

char AMDGPULowerKernelArguments::ID = 0;

FunctionPass *llvm::createAMDGPULowerKernelArgumentsPass() {
  return new AMDGPULowerKernelArguments();
}