//===-- AMDGPULowerKernelArguments.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file This pass replaces accesses to kernel arguments with loads from
/// offsets from the kernarg base pointer.
//
//===----------------------------------------------------------------------===//
138c4a3523SMatt Arsenault 
148c4a3523SMatt Arsenault #include "AMDGPU.h"
158c4a3523SMatt Arsenault #include "AMDGPUSubtarget.h"
168c4a3523SMatt Arsenault #include "AMDGPUTargetMachine.h"
178c4a3523SMatt Arsenault #include "llvm/ADT/StringRef.h"
188c4a3523SMatt Arsenault #include "llvm/Analysis/Loads.h"
198c4a3523SMatt Arsenault #include "llvm/CodeGen/Passes.h"
208c4a3523SMatt Arsenault #include "llvm/CodeGen/TargetPassConfig.h"
218c4a3523SMatt Arsenault #include "llvm/IR/Attributes.h"
228c4a3523SMatt Arsenault #include "llvm/IR/BasicBlock.h"
238c4a3523SMatt Arsenault #include "llvm/IR/Constants.h"
248c4a3523SMatt Arsenault #include "llvm/IR/DerivedTypes.h"
258c4a3523SMatt Arsenault #include "llvm/IR/Function.h"
268c4a3523SMatt Arsenault #include "llvm/IR/IRBuilder.h"
278c4a3523SMatt Arsenault #include "llvm/IR/InstrTypes.h"
288c4a3523SMatt Arsenault #include "llvm/IR/Instruction.h"
298c4a3523SMatt Arsenault #include "llvm/IR/Instructions.h"
308c4a3523SMatt Arsenault #include "llvm/IR/LLVMContext.h"
318c4a3523SMatt Arsenault #include "llvm/IR/MDBuilder.h"
328c4a3523SMatt Arsenault #include "llvm/IR/Metadata.h"
338c4a3523SMatt Arsenault #include "llvm/IR/Operator.h"
348c4a3523SMatt Arsenault #include "llvm/IR/Type.h"
358c4a3523SMatt Arsenault #include "llvm/IR/Value.h"
368c4a3523SMatt Arsenault #include "llvm/Pass.h"
378c4a3523SMatt Arsenault #include "llvm/Support/Casting.h"
388c4a3523SMatt Arsenault 
398c4a3523SMatt Arsenault #define DEBUG_TYPE "amdgpu-lower-kernel-arguments"
408c4a3523SMatt Arsenault 
418c4a3523SMatt Arsenault using namespace llvm;
428c4a3523SMatt Arsenault 
438c4a3523SMatt Arsenault namespace {
448c4a3523SMatt Arsenault 
458c4a3523SMatt Arsenault class AMDGPULowerKernelArguments : public FunctionPass{
468c4a3523SMatt Arsenault public:
478c4a3523SMatt Arsenault   static char ID;
488c4a3523SMatt Arsenault 
498c4a3523SMatt Arsenault   AMDGPULowerKernelArguments() : FunctionPass(ID) {}
508c4a3523SMatt Arsenault 
518c4a3523SMatt Arsenault   bool runOnFunction(Function &F) override;
528c4a3523SMatt Arsenault 
538c4a3523SMatt Arsenault   void getAnalysisUsage(AnalysisUsage &AU) const override {
548c4a3523SMatt Arsenault     AU.addRequired<TargetPassConfig>();
558c4a3523SMatt Arsenault     AU.setPreservesAll();
568c4a3523SMatt Arsenault  }
578c4a3523SMatt Arsenault };
588c4a3523SMatt Arsenault 
598c4a3523SMatt Arsenault } // end anonymous namespace
608c4a3523SMatt Arsenault 
618c4a3523SMatt Arsenault bool AMDGPULowerKernelArguments::runOnFunction(Function &F) {
628c4a3523SMatt Arsenault   CallingConv::ID CC = F.getCallingConv();
638c4a3523SMatt Arsenault   if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty())
648c4a3523SMatt Arsenault     return false;
658c4a3523SMatt Arsenault 
668c4a3523SMatt Arsenault   auto &TPC = getAnalysis<TargetPassConfig>();
678c4a3523SMatt Arsenault 
688c4a3523SMatt Arsenault   const TargetMachine &TM = TPC.getTM<TargetMachine>();
695bfbae5cSTom Stellard   const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
708c4a3523SMatt Arsenault   LLVMContext &Ctx = F.getParent()->getContext();
718c4a3523SMatt Arsenault   const DataLayout &DL = F.getParent()->getDataLayout();
728c4a3523SMatt Arsenault   BasicBlock &EntryBlock = *F.begin();
738c4a3523SMatt Arsenault   IRBuilder<> Builder(&*EntryBlock.begin());
748c4a3523SMatt Arsenault 
75f5be3ad7SMatt Arsenault   const unsigned KernArgBaseAlign = 16; // FIXME: Increase if necessary
768c4a3523SMatt Arsenault   const uint64_t BaseOffset = ST.getExplicitKernelArgOffset(F);
778c4a3523SMatt Arsenault 
784bec7d42SMatt Arsenault   unsigned MaxAlign;
798c4a3523SMatt Arsenault   // FIXME: Alignment is broken broken with explicit arg offset.;
804bec7d42SMatt Arsenault   const uint64_t TotalKernArgSize = ST.getKernArgSegmentSize(F, MaxAlign);
81513e0c0eSMatt Arsenault   if (TotalKernArgSize == 0)
82513e0c0eSMatt Arsenault     return false;
838c4a3523SMatt Arsenault 
848c4a3523SMatt Arsenault   CallInst *KernArgSegment =
8557f5d0a8SNeil Henning       Builder.CreateIntrinsic(Intrinsic::amdgcn_kernarg_segment_ptr, {}, {},
8657f5d0a8SNeil Henning                               nullptr, F.getName() + ".kernarg.segment");
878c4a3523SMatt Arsenault 
888c4a3523SMatt Arsenault   KernArgSegment->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
898c4a3523SMatt Arsenault   KernArgSegment->addAttribute(AttributeList::ReturnIndex,
908c4a3523SMatt Arsenault     Attribute::getWithDereferenceableBytes(Ctx, TotalKernArgSize));
918c4a3523SMatt Arsenault 
928c4a3523SMatt Arsenault   unsigned AS = KernArgSegment->getType()->getPointerAddressSpace();
93f5be3ad7SMatt Arsenault   uint64_t ExplicitArgOffset = 0;
94f5be3ad7SMatt Arsenault 
958c4a3523SMatt Arsenault   for (Argument &Arg : F.args()) {
96f5be3ad7SMatt Arsenault     Type *ArgTy = Arg.getType();
97f5be3ad7SMatt Arsenault     unsigned Align = DL.getABITypeAlignment(ArgTy);
98f5be3ad7SMatt Arsenault     unsigned Size = DL.getTypeSizeInBits(ArgTy);
99f5be3ad7SMatt Arsenault     unsigned AllocSize = DL.getTypeAllocSize(ArgTy);
100f5be3ad7SMatt Arsenault 
101f5be3ad7SMatt Arsenault     uint64_t EltOffset = alignTo(ExplicitArgOffset, Align) + BaseOffset;
102f5be3ad7SMatt Arsenault     ExplicitArgOffset = alignTo(ExplicitArgOffset, Align) + AllocSize;
103f5be3ad7SMatt Arsenault 
1048c4a3523SMatt Arsenault     if (Arg.use_empty())
1058c4a3523SMatt Arsenault       continue;
1068c4a3523SMatt Arsenault 
1078c4a3523SMatt Arsenault     if (PointerType *PT = dyn_cast<PointerType>(ArgTy)) {
1088c4a3523SMatt Arsenault       // FIXME: Hack. We rely on AssertZext to be able to fold DS addressing
1098c4a3523SMatt Arsenault       // modes on SI to know the high bits are 0 so pointer adds don't wrap. We
1108c4a3523SMatt Arsenault       // can't represent this with range metadata because it's only allowed for
1118c4a3523SMatt Arsenault       // integer types.
1128c4a3523SMatt Arsenault       if (PT->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS &&
1138c4a3523SMatt Arsenault           ST.getGeneration() == AMDGPUSubtarget::SOUTHERN_ISLANDS)
1148c4a3523SMatt Arsenault         continue;
1158c4a3523SMatt Arsenault 
1168c4a3523SMatt Arsenault       // FIXME: We can replace this with equivalent alias.scope/noalias
1178c4a3523SMatt Arsenault       // metadata, but this appears to be a lot of work.
1188c4a3523SMatt Arsenault       if (Arg.hasNoAliasAttr())
1198c4a3523SMatt Arsenault         continue;
1208c4a3523SMatt Arsenault     }
1218c4a3523SMatt Arsenault 
1228c4a3523SMatt Arsenault     VectorType *VT = dyn_cast<VectorType>(ArgTy);
1238c4a3523SMatt Arsenault     bool IsV3 = VT && VT->getNumElements() == 3;
124b5613ecfSMatt Arsenault     bool DoShiftOpt = Size < 32 && !ArgTy->isAggregateType();
125b5613ecfSMatt Arsenault 
1268c4a3523SMatt Arsenault     VectorType *V4Ty = nullptr;
1278c4a3523SMatt Arsenault 
1288c4a3523SMatt Arsenault     int64_t AlignDownOffset = alignDown(EltOffset, 4);
1298c4a3523SMatt Arsenault     int64_t OffsetDiff = EltOffset - AlignDownOffset;
130b5613ecfSMatt Arsenault     unsigned AdjustedAlign = MinAlign(DoShiftOpt ? AlignDownOffset : EltOffset,
131b5613ecfSMatt Arsenault                                       KernArgBaseAlign);
1328c4a3523SMatt Arsenault 
1338c4a3523SMatt Arsenault     Value *ArgPtr;
134b5613ecfSMatt Arsenault     if (DoShiftOpt) { // FIXME: Handle aggregate types
1358c4a3523SMatt Arsenault       // Since we don't have sub-dword scalar loads, avoid doing an extload by
1368c4a3523SMatt Arsenault       // loading earlier than the argument address, and extracting the relevant
1378c4a3523SMatt Arsenault       // bits.
1388c4a3523SMatt Arsenault       //
1398c4a3523SMatt Arsenault       // Additionally widen any sub-dword load to i32 even if suitably aligned,
1408c4a3523SMatt Arsenault       // so that CSE between different argument loads works easily.
1418c4a3523SMatt Arsenault 
142f5be3ad7SMatt Arsenault       ArgPtr = Builder.CreateConstInBoundsGEP1_64(
143f5be3ad7SMatt Arsenault         KernArgSegment,
144f5be3ad7SMatt Arsenault         AlignDownOffset,
1458c4a3523SMatt Arsenault         Arg.getName() + ".kernarg.offset.align.down");
146f5be3ad7SMatt Arsenault       ArgPtr = Builder.CreateBitCast(ArgPtr,
147f5be3ad7SMatt Arsenault                                      Builder.getInt32Ty()->getPointerTo(AS),
148f5be3ad7SMatt Arsenault                                      ArgPtr->getName() + ".cast");
1498c4a3523SMatt Arsenault     } else {
150f5be3ad7SMatt Arsenault       ArgPtr = Builder.CreateConstInBoundsGEP1_64(
151f5be3ad7SMatt Arsenault         KernArgSegment,
152b5613ecfSMatt Arsenault         EltOffset,
1538c4a3523SMatt Arsenault         Arg.getName() + ".kernarg.offset");
154f5be3ad7SMatt Arsenault       ArgPtr = Builder.CreateBitCast(ArgPtr, ArgTy->getPointerTo(AS),
155f5be3ad7SMatt Arsenault                                      ArgPtr->getName() + ".cast");
1568c4a3523SMatt Arsenault     }
1578c4a3523SMatt Arsenault 
1588c4a3523SMatt Arsenault     if (IsV3 && Size >= 32) {
1598c4a3523SMatt Arsenault       V4Ty = VectorType::get(VT->getVectorElementType(), 4);
1608c4a3523SMatt Arsenault       // Use the hack that clang uses to avoid SelectionDAG ruining v3 loads
1618c4a3523SMatt Arsenault       ArgPtr = Builder.CreateBitCast(ArgPtr, V4Ty->getPointerTo(AS));
1628c4a3523SMatt Arsenault     }
1638c4a3523SMatt Arsenault 
1648c4a3523SMatt Arsenault     LoadInst *Load = Builder.CreateAlignedLoad(ArgPtr, AdjustedAlign);
1658c4a3523SMatt Arsenault     Load->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(Ctx, {}));
1668c4a3523SMatt Arsenault 
1678c4a3523SMatt Arsenault     MDBuilder MDB(Ctx);
1688c4a3523SMatt Arsenault 
1698c4a3523SMatt Arsenault     if (isa<PointerType>(ArgTy)) {
1708c4a3523SMatt Arsenault       if (Arg.hasNonNullAttr())
1718c4a3523SMatt Arsenault         Load->setMetadata(LLVMContext::MD_nonnull, MDNode::get(Ctx, {}));
1728c4a3523SMatt Arsenault 
1738c4a3523SMatt Arsenault       uint64_t DerefBytes = Arg.getDereferenceableBytes();
1748c4a3523SMatt Arsenault       if (DerefBytes != 0) {
1758c4a3523SMatt Arsenault         Load->setMetadata(
1768c4a3523SMatt Arsenault           LLVMContext::MD_dereferenceable,
1778c4a3523SMatt Arsenault           MDNode::get(Ctx,
1788c4a3523SMatt Arsenault                       MDB.createConstant(
1798c4a3523SMatt Arsenault                         ConstantInt::get(Builder.getInt64Ty(), DerefBytes))));
1808c4a3523SMatt Arsenault       }
1818c4a3523SMatt Arsenault 
1828c4a3523SMatt Arsenault       uint64_t DerefOrNullBytes = Arg.getDereferenceableOrNullBytes();
1838c4a3523SMatt Arsenault       if (DerefOrNullBytes != 0) {
1848c4a3523SMatt Arsenault         Load->setMetadata(
1858c4a3523SMatt Arsenault           LLVMContext::MD_dereferenceable_or_null,
1868c4a3523SMatt Arsenault           MDNode::get(Ctx,
1878c4a3523SMatt Arsenault                       MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
1888c4a3523SMatt Arsenault                                                           DerefOrNullBytes))));
1898c4a3523SMatt Arsenault       }
1908c4a3523SMatt Arsenault 
1918c4a3523SMatt Arsenault       unsigned ParamAlign = Arg.getParamAlignment();
1928c4a3523SMatt Arsenault       if (ParamAlign != 0) {
1938c4a3523SMatt Arsenault         Load->setMetadata(
1948c4a3523SMatt Arsenault           LLVMContext::MD_align,
1958c4a3523SMatt Arsenault           MDNode::get(Ctx,
1968c4a3523SMatt Arsenault                       MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
1978c4a3523SMatt Arsenault                                                           ParamAlign))));
1988c4a3523SMatt Arsenault       }
1998c4a3523SMatt Arsenault     }
2008c4a3523SMatt Arsenault 
2018c4a3523SMatt Arsenault     // TODO: Convert noalias arg to !noalias
2028c4a3523SMatt Arsenault 
203b5613ecfSMatt Arsenault     if (DoShiftOpt) {
2048c4a3523SMatt Arsenault       Value *ExtractBits = OffsetDiff == 0 ?
2058c4a3523SMatt Arsenault         Load : Builder.CreateLShr(Load, OffsetDiff * 8);
2068c4a3523SMatt Arsenault 
2078c4a3523SMatt Arsenault       IntegerType *ArgIntTy = Builder.getIntNTy(Size);
2088c4a3523SMatt Arsenault       Value *Trunc = Builder.CreateTrunc(ExtractBits, ArgIntTy);
2098c4a3523SMatt Arsenault       Value *NewVal = Builder.CreateBitCast(Trunc, ArgTy,
2108c4a3523SMatt Arsenault                                             Arg.getName() + ".load");
2118c4a3523SMatt Arsenault       Arg.replaceAllUsesWith(NewVal);
2128c4a3523SMatt Arsenault     } else if (IsV3) {
2138c4a3523SMatt Arsenault       Value *Shuf = Builder.CreateShuffleVector(Load, UndefValue::get(V4Ty),
2148c4a3523SMatt Arsenault                                                 {0, 1, 2},
2158c4a3523SMatt Arsenault                                                 Arg.getName() + ".load");
2168c4a3523SMatt Arsenault       Arg.replaceAllUsesWith(Shuf);
2178c4a3523SMatt Arsenault     } else {
2188c4a3523SMatt Arsenault       Load->setName(Arg.getName() + ".load");
2198c4a3523SMatt Arsenault       Arg.replaceAllUsesWith(Load);
2208c4a3523SMatt Arsenault     }
2218c4a3523SMatt Arsenault   }
2228c4a3523SMatt Arsenault 
223f5be3ad7SMatt Arsenault   KernArgSegment->addAttribute(
224f5be3ad7SMatt Arsenault     AttributeList::ReturnIndex,
225f5be3ad7SMatt Arsenault     Attribute::getWithAlignment(Ctx, std::max(KernArgBaseAlign, MaxAlign)));
226f5be3ad7SMatt Arsenault 
2278c4a3523SMatt Arsenault   return true;
2288c4a3523SMatt Arsenault }
2298c4a3523SMatt Arsenault 
2308c4a3523SMatt Arsenault INITIALIZE_PASS_BEGIN(AMDGPULowerKernelArguments, DEBUG_TYPE,
2318c4a3523SMatt Arsenault                       "AMDGPU Lower Kernel Arguments", false, false)
2328c4a3523SMatt Arsenault INITIALIZE_PASS_END(AMDGPULowerKernelArguments, DEBUG_TYPE, "AMDGPU Lower Kernel Arguments",
2338c4a3523SMatt Arsenault                     false, false)
2348c4a3523SMatt Arsenault 
2358c4a3523SMatt Arsenault char AMDGPULowerKernelArguments::ID = 0;
2368c4a3523SMatt Arsenault 
2378c4a3523SMatt Arsenault FunctionPass *llvm::createAMDGPULowerKernelArgumentsPass() {
2388c4a3523SMatt Arsenault   return new AMDGPULowerKernelArguments();
2398c4a3523SMatt Arsenault }
240