//===-- AMDGPULowerKernelArguments.cpp ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file This pass replaces accesses to kernel arguments with loads at known
/// offsets from the kernarg base pointer.
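///
/// A rough sketch of the intended rewrite (illustrative IR, not taken from an
/// actual test): a use of an i32 kernel argument %x becomes something like
///
///   %segment = call i8 addrspace(4)* @llvm.amdgcn.kernarg.segment.ptr()
///   %gep = getelementptr inbounds i8, i8 addrspace(4)* %segment, i64 <offset>
///   %cast = bitcast i8 addrspace(4)* %gep to i32 addrspace(4)*
///   %x.load = load i32, i32 addrspace(4)* %cast, align 4, !invariant.load !0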
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "GCNSubtarget.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/Target/TargetMachine.h"

#define DEBUG_TYPE "amdgpu-lower-kernel-arguments"

using namespace llvm;

namespace {

class AMDGPULowerKernelArguments : public FunctionPass {
public:
  static char ID;

  AMDGPULowerKernelArguments() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
    AU.setPreservesAll();
  }
};

} // end anonymous namespace

// Skip past the static allocas at the start of the entry block.
static BasicBlock::iterator getInsertPt(BasicBlock &BB) {
  BasicBlock::iterator InsPt = BB.getFirstInsertionPt();
  for (BasicBlock::iterator E = BB.end(); InsPt != E; ++InsPt) {
    AllocaInst *AI = dyn_cast<AllocaInst>(&*InsPt);

    // If this is a dynamic alloca, the value may depend on the loaded
    // kernargs, so loads will need to be inserted before it.
    if (!AI || !AI->isStaticAlloca())
      break;
  }

  return InsPt;
}

bool AMDGPULowerKernelArguments::runOnFunction(Function &F) {
  CallingConv::ID CC = F.getCallingConv();
  if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty())
    return false;

  auto &TPC = getAnalysis<TargetPassConfig>();

  const TargetMachine &TM = TPC.getTM<TargetMachine>();
  const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
  LLVMContext &Ctx = F.getParent()->getContext();
  const DataLayout &DL = F.getParent()->getDataLayout();
  BasicBlock &EntryBlock = *F.begin();
  IRBuilder<> Builder(&*getInsertPt(EntryBlock));

  const Align KernArgBaseAlign(16); // FIXME: Increase if necessary
  const uint64_t BaseOffset = ST.getExplicitKernelArgOffset(F);

  Align MaxAlign;
  // FIXME: Alignment is broken with explicit arg offset.
  const uint64_t TotalKernArgSize = ST.getKernArgSegmentSize(F, MaxAlign);
  if (TotalKernArgSize == 0)
    return false;

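  // Materialize the kernarg segment base pointer once; every argument load
  // created below is addressed as an offset from this pointer.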
  CallInst *KernArgSegment =
      Builder.CreateIntrinsic(Intrinsic::amdgcn_kernarg_segment_ptr, {}, {},
                              nullptr, F.getName() + ".kernarg.segment");

  KernArgSegment->addRetAttr(Attribute::NonNull);
  KernArgSegment->addRetAttr(
      Attribute::getWithDereferenceableBytes(Ctx, TotalKernArgSize));

  unsigned AS = KernArgSegment->getType()->getPointerAddressSpace();
  uint64_t ExplicitArgOffset = 0;

  for (Argument &Arg : F.args()) {
    const bool IsByRef = Arg.hasByRefAttr();
    Type *ArgTy = IsByRef ? Arg.getParamByRefType() : Arg.getType();
    MaybeAlign ParamAlign = IsByRef ? Arg.getParamAlign() : None;
    Align ABITypeAlign = DL.getValueOrABITypeAlignment(ParamAlign, ArgTy);

    uint64_t Size = DL.getTypeSizeInBits(ArgTy);
    uint64_t AllocSize = DL.getTypeAllocSize(ArgTy);

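    // EltOffset is this argument's ABI-aligned byte offset from the kernarg
    // base; ExplicitArgOffset tracks the running end of the explicit argument
    // block so the next argument is placed after this one.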
    uint64_t EltOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + BaseOffset;
    ExplicitArgOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + AllocSize;

    if (Arg.use_empty())
      continue;

    // If this is byref, the loads are already explicit in the function. We
    // just need to rewrite the pointer values.
    if (IsByRef) {
      Value *ArgOffsetPtr = Builder.CreateConstInBoundsGEP1_64(
          Builder.getInt8Ty(), KernArgSegment, EltOffset,
          Arg.getName() + ".byval.kernarg.offset");

      Value *CastOffsetPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
          ArgOffsetPtr, Arg.getType());
      Arg.replaceAllUsesWith(CastOffsetPtr);
      continue;
    }

    if (PointerType *PT = dyn_cast<PointerType>(ArgTy)) {
      // FIXME: Hack. We rely on AssertZext to be able to fold DS addressing
      // modes on SI to know the high bits are 0 so pointer adds don't wrap. We
      // can't represent this with range metadata because it's only allowed for
      // integer types.
      if ((PT->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS ||
           PT->getAddressSpace() == AMDGPUAS::REGION_ADDRESS) &&
          !ST.hasUsableDSOffset())
        continue;

      // FIXME: We can replace this with equivalent alias.scope/noalias
      // metadata, but this appears to be a lot of work.
      if (Arg.hasNoAliasAttr())
        continue;
    }

    auto *VT = dyn_cast<FixedVectorType>(ArgTy);
    bool IsV3 = VT && VT->getNumElements() == 3;
    bool DoShiftOpt = Size < 32 && !ArgTy->isAggregateType();

    VectorType *V4Ty = nullptr;

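    // For sub-dword scalars, compute the offset of the dword that contains the
    // argument (AlignDownOffset) and the argument's byte offset within that
    // dword (OffsetDiff); the containing dword is what actually gets loaded.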
    int64_t AlignDownOffset = alignDown(EltOffset, 4);
    int64_t OffsetDiff = EltOffset - AlignDownOffset;
    Align AdjustedAlign = commonAlignment(
        KernArgBaseAlign, DoShiftOpt ? AlignDownOffset : EltOffset);

    Value *ArgPtr;
    Type *AdjustedArgTy;
    if (DoShiftOpt) { // FIXME: Handle aggregate types
      // Since we don't have sub-dword scalar loads, avoid doing an extload by
      // loading earlier than the argument address, and extracting the relevant
      // bits.
      //
      // Additionally widen any sub-dword load to i32 even if suitably aligned,
      // so that CSE between different argument loads works easily.
      ArgPtr = Builder.CreateConstInBoundsGEP1_64(
          Builder.getInt8Ty(), KernArgSegment, AlignDownOffset,
          Arg.getName() + ".kernarg.offset.align.down");
      AdjustedArgTy = Builder.getInt32Ty();
    } else {
      ArgPtr = Builder.CreateConstInBoundsGEP1_64(
          Builder.getInt8Ty(), KernArgSegment, EltOffset,
          Arg.getName() + ".kernarg.offset");
      AdjustedArgTy = ArgTy;
    }

    if (IsV3 && Size >= 32) {
      V4Ty = FixedVectorType::get(VT->getElementType(), 4);
      // Use the hack that clang uses to avoid SelectionDAG ruining v3 loads
      AdjustedArgTy = V4Ty;
    }

    ArgPtr = Builder.CreateBitCast(ArgPtr, AdjustedArgTy->getPointerTo(AS),
                                   ArgPtr->getName() + ".cast");
    LoadInst *Load =
        Builder.CreateAlignedLoad(AdjustedArgTy, ArgPtr, AdjustedAlign);
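    // The kernarg segment is read-only for the lifetime of the dispatch, so
    // the load can be marked invariant and freely CSE'd or hoisted.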
    Load->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(Ctx, {}));

    MDBuilder MDB(Ctx);

    if (isa<PointerType>(ArgTy)) {
      if (Arg.hasNonNullAttr())
        Load->setMetadata(LLVMContext::MD_nonnull, MDNode::get(Ctx, {}));

      uint64_t DerefBytes = Arg.getDereferenceableBytes();
      if (DerefBytes != 0) {
        Load->setMetadata(
          LLVMContext::MD_dereferenceable,
          MDNode::get(Ctx,
                      MDB.createConstant(
                        ConstantInt::get(Builder.getInt64Ty(), DerefBytes))));
      }

      uint64_t DerefOrNullBytes = Arg.getDereferenceableOrNullBytes();
      if (DerefOrNullBytes != 0) {
        Load->setMetadata(
          LLVMContext::MD_dereferenceable_or_null,
          MDNode::get(Ctx,
                      MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
                                                          DerefOrNullBytes))));
      }

      unsigned ParamAlign = Arg.getParamAlignment();
      if (ParamAlign != 0) {
        Load->setMetadata(
          LLVMContext::MD_align,
          MDNode::get(Ctx,
                      MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
                                                          ParamAlign))));
      }
    }

    // TODO: Convert noalias arg to !noalias

    if (DoShiftOpt) {
      Value *ExtractBits = OffsetDiff == 0 ?
        Load : Builder.CreateLShr(Load, OffsetDiff * 8);

      IntegerType *ArgIntTy = Builder.getIntNTy(Size);
      Value *Trunc = Builder.CreateTrunc(ExtractBits, ArgIntTy);
      Value *NewVal = Builder.CreateBitCast(Trunc, ArgTy,
                                            Arg.getName() + ".load");
      Arg.replaceAllUsesWith(NewVal);
    } else if (IsV3) {
      Value *Shuf = Builder.CreateShuffleVector(Load, ArrayRef<int>{0, 1, 2},
                                                Arg.getName() + ".load");
      Arg.replaceAllUsesWith(Shuf);
    } else {
      Load->setName(Arg.getName() + ".load");
      Arg.replaceAllUsesWith(Load);
    }
  }

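  // Publish the strongest alignment we can claim for the segment pointer so
  // later passes can rely on it when folding the argument loads.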
  KernArgSegment->addRetAttr(
      Attribute::getWithAlignment(Ctx, std::max(KernArgBaseAlign, MaxAlign)));

  return true;
}

INITIALIZE_PASS_BEGIN(AMDGPULowerKernelArguments, DEBUG_TYPE,
                      "AMDGPU Lower Kernel Arguments", false, false)
INITIALIZE_PASS_END(AMDGPULowerKernelArguments, DEBUG_TYPE,
                    "AMDGPU Lower Kernel Arguments", false, false)

char AMDGPULowerKernelArguments::ID = 0;

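// Factory used when this pass is added to the AMDGPU codegen pipeline; the
// pass name registered above (DEBUG_TYPE) is also the flag that runs it
// standalone through opt with the legacy pass manager.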
FunctionPass *llvm::createAMDGPULowerKernelArgumentsPass() {
  return new AMDGPULowerKernelArguments();
}