18c4a3523SMatt Arsenault //===-- AMDGPULowerKernelArguments.cpp ------------------------------------------===//
28c4a3523SMatt Arsenault //
32946cd70SChandler Carruth // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42946cd70SChandler Carruth // See https://llvm.org/LICENSE.txt for license information.
52946cd70SChandler Carruth // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
68c4a3523SMatt Arsenault //
78c4a3523SMatt Arsenault //===----------------------------------------------------------------------===//
88c4a3523SMatt Arsenault //
98c4a3523SMatt Arsenault /// \file This pass replaces accesses to kernel arguments with loads from
108c4a3523SMatt Arsenault /// offsets from the kernarg base pointer.
118c4a3523SMatt Arsenault //
128c4a3523SMatt Arsenault //===----------------------------------------------------------------------===//
138c4a3523SMatt Arsenault 
148c4a3523SMatt Arsenault #include "AMDGPU.h"
158c4a3523SMatt Arsenault #include "AMDGPUSubtarget.h"
168c4a3523SMatt Arsenault #include "AMDGPUTargetMachine.h"
178c4a3523SMatt Arsenault #include "llvm/ADT/StringRef.h"
188c4a3523SMatt Arsenault #include "llvm/Analysis/Loads.h"
198c4a3523SMatt Arsenault #include "llvm/CodeGen/Passes.h"
208c4a3523SMatt Arsenault #include "llvm/CodeGen/TargetPassConfig.h"
218c4a3523SMatt Arsenault #include "llvm/IR/Attributes.h"
228c4a3523SMatt Arsenault #include "llvm/IR/BasicBlock.h"
238c4a3523SMatt Arsenault #include "llvm/IR/Constants.h"
248c4a3523SMatt Arsenault #include "llvm/IR/DerivedTypes.h"
258c4a3523SMatt Arsenault #include "llvm/IR/Function.h"
268c4a3523SMatt Arsenault #include "llvm/IR/IRBuilder.h"
278c4a3523SMatt Arsenault #include "llvm/IR/InstrTypes.h"
288c4a3523SMatt Arsenault #include "llvm/IR/Instruction.h"
298c4a3523SMatt Arsenault #include "llvm/IR/Instructions.h"
308c4a3523SMatt Arsenault #include "llvm/IR/LLVMContext.h"
318c4a3523SMatt Arsenault #include "llvm/IR/MDBuilder.h"
328c4a3523SMatt Arsenault #include "llvm/IR/Metadata.h"
338c4a3523SMatt Arsenault #include "llvm/IR/Operator.h"
348c4a3523SMatt Arsenault #include "llvm/IR/Type.h"
358c4a3523SMatt Arsenault #include "llvm/IR/Value.h"
368c4a3523SMatt Arsenault #include "llvm/Pass.h"
378c4a3523SMatt Arsenault #include "llvm/Support/Casting.h"
388c4a3523SMatt Arsenault 
398c4a3523SMatt Arsenault #define DEBUG_TYPE "amdgpu-lower-kernel-arguments"
408c4a3523SMatt Arsenault 
418c4a3523SMatt Arsenault using namespace llvm;
428c4a3523SMatt Arsenault 
438c4a3523SMatt Arsenault namespace {
448c4a3523SMatt Arsenault 
458c4a3523SMatt Arsenault class AMDGPULowerKernelArguments : public FunctionPass{
468c4a3523SMatt Arsenault public:
478c4a3523SMatt Arsenault   static char ID;
488c4a3523SMatt Arsenault 
498c4a3523SMatt Arsenault   AMDGPULowerKernelArguments() : FunctionPass(ID) {}
508c4a3523SMatt Arsenault 
518c4a3523SMatt Arsenault   bool runOnFunction(Function &F) override;
528c4a3523SMatt Arsenault 
538c4a3523SMatt Arsenault   void getAnalysisUsage(AnalysisUsage &AU) const override {
548c4a3523SMatt Arsenault     AU.addRequired<TargetPassConfig>();
558c4a3523SMatt Arsenault     AU.setPreservesAll();
568c4a3523SMatt Arsenault  }
578c4a3523SMatt Arsenault };
588c4a3523SMatt Arsenault 
598c4a3523SMatt Arsenault } // end anonymous namespace
608c4a3523SMatt Arsenault 
618c4a3523SMatt Arsenault bool AMDGPULowerKernelArguments::runOnFunction(Function &F) {
628c4a3523SMatt Arsenault   CallingConv::ID CC = F.getCallingConv();
638c4a3523SMatt Arsenault   if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty())
648c4a3523SMatt Arsenault     return false;
658c4a3523SMatt Arsenault 
668c4a3523SMatt Arsenault   auto &TPC = getAnalysis<TargetPassConfig>();
678c4a3523SMatt Arsenault 
688c4a3523SMatt Arsenault   const TargetMachine &TM = TPC.getTM<TargetMachine>();
695bfbae5cSTom Stellard   const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
708c4a3523SMatt Arsenault   LLVMContext &Ctx = F.getParent()->getContext();
718c4a3523SMatt Arsenault   const DataLayout &DL = F.getParent()->getDataLayout();
728c4a3523SMatt Arsenault   BasicBlock &EntryBlock = *F.begin();
738c4a3523SMatt Arsenault   IRBuilder<> Builder(&*EntryBlock.begin());
748c4a3523SMatt Arsenault 
75*b65fa483SGuillaume Chatelet   const Align KernArgBaseAlign(16); // FIXME: Increase if necessary
768c4a3523SMatt Arsenault   const uint64_t BaseOffset = ST.getExplicitKernelArgOffset(F);
778c4a3523SMatt Arsenault 
78*b65fa483SGuillaume Chatelet   Align MaxAlign;
798c4a3523SMatt Arsenault   // FIXME: Alignment is broken broken with explicit arg offset.;
804bec7d42SMatt Arsenault   const uint64_t TotalKernArgSize = ST.getKernArgSegmentSize(F, MaxAlign);
81513e0c0eSMatt Arsenault   if (TotalKernArgSize == 0)
82513e0c0eSMatt Arsenault     return false;
838c4a3523SMatt Arsenault 
848c4a3523SMatt Arsenault   CallInst *KernArgSegment =
8557f5d0a8SNeil Henning       Builder.CreateIntrinsic(Intrinsic::amdgcn_kernarg_segment_ptr, {}, {},
8657f5d0a8SNeil Henning                               nullptr, F.getName() + ".kernarg.segment");
878c4a3523SMatt Arsenault 
888c4a3523SMatt Arsenault   KernArgSegment->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
898c4a3523SMatt Arsenault   KernArgSegment->addAttribute(AttributeList::ReturnIndex,
908c4a3523SMatt Arsenault     Attribute::getWithDereferenceableBytes(Ctx, TotalKernArgSize));
918c4a3523SMatt Arsenault 
928c4a3523SMatt Arsenault   unsigned AS = KernArgSegment->getType()->getPointerAddressSpace();
93f5be3ad7SMatt Arsenault   uint64_t ExplicitArgOffset = 0;
94f5be3ad7SMatt Arsenault 
958c4a3523SMatt Arsenault   for (Argument &Arg : F.args()) {
96f5be3ad7SMatt Arsenault     Type *ArgTy = Arg.getType();
97*b65fa483SGuillaume Chatelet     unsigned ABITypeAlign = DL.getABITypeAlignment(ArgTy);
98f5be3ad7SMatt Arsenault     unsigned Size = DL.getTypeSizeInBits(ArgTy);
99f5be3ad7SMatt Arsenault     unsigned AllocSize = DL.getTypeAllocSize(ArgTy);
100f5be3ad7SMatt Arsenault 
101*b65fa483SGuillaume Chatelet     uint64_t EltOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + BaseOffset;
102*b65fa483SGuillaume Chatelet     ExplicitArgOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + AllocSize;
103f5be3ad7SMatt Arsenault 
1048c4a3523SMatt Arsenault     if (Arg.use_empty())
1058c4a3523SMatt Arsenault       continue;
1068c4a3523SMatt Arsenault 
1078c4a3523SMatt Arsenault     if (PointerType *PT = dyn_cast<PointerType>(ArgTy)) {
1088c4a3523SMatt Arsenault       // FIXME: Hack. We rely on AssertZext to be able to fold DS addressing
1098c4a3523SMatt Arsenault       // modes on SI to know the high bits are 0 so pointer adds don't wrap. We
1108c4a3523SMatt Arsenault       // can't represent this with range metadata because it's only allowed for
1118c4a3523SMatt Arsenault       // integer types.
112cdd191d9SMatt Arsenault       if ((PT->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS ||
113cdd191d9SMatt Arsenault            PT->getAddressSpace() == AMDGPUAS::REGION_ADDRESS) &&
114e4c2e9b0SMatt Arsenault           !ST.hasUsableDSOffset())
1158c4a3523SMatt Arsenault         continue;
1168c4a3523SMatt Arsenault 
1178c4a3523SMatt Arsenault       // FIXME: We can replace this with equivalent alias.scope/noalias
1188c4a3523SMatt Arsenault       // metadata, but this appears to be a lot of work.
1198c4a3523SMatt Arsenault       if (Arg.hasNoAliasAttr())
1208c4a3523SMatt Arsenault         continue;
1218c4a3523SMatt Arsenault     }
1228c4a3523SMatt Arsenault 
1238c4a3523SMatt Arsenault     VectorType *VT = dyn_cast<VectorType>(ArgTy);
1248c4a3523SMatt Arsenault     bool IsV3 = VT && VT->getNumElements() == 3;
125b5613ecfSMatt Arsenault     bool DoShiftOpt = Size < 32 && !ArgTy->isAggregateType();
126b5613ecfSMatt Arsenault 
1278c4a3523SMatt Arsenault     VectorType *V4Ty = nullptr;
1288c4a3523SMatt Arsenault 
1298c4a3523SMatt Arsenault     int64_t AlignDownOffset = alignDown(EltOffset, 4);
1308c4a3523SMatt Arsenault     int64_t OffsetDiff = EltOffset - AlignDownOffset;
131*b65fa483SGuillaume Chatelet     Align AdjustedAlign = commonAlignment(
132*b65fa483SGuillaume Chatelet         KernArgBaseAlign, DoShiftOpt ? AlignDownOffset : EltOffset);
1338c4a3523SMatt Arsenault 
1348c4a3523SMatt Arsenault     Value *ArgPtr;
13514359ef1SJames Y Knight     Type *AdjustedArgTy;
136b5613ecfSMatt Arsenault     if (DoShiftOpt) { // FIXME: Handle aggregate types
1378c4a3523SMatt Arsenault       // Since we don't have sub-dword scalar loads, avoid doing an extload by
1388c4a3523SMatt Arsenault       // loading earlier than the argument address, and extracting the relevant
1398c4a3523SMatt Arsenault       // bits.
1408c4a3523SMatt Arsenault       //
1418c4a3523SMatt Arsenault       // Additionally widen any sub-dword load to i32 even if suitably aligned,
1428c4a3523SMatt Arsenault       // so that CSE between different argument loads works easily.
143f5be3ad7SMatt Arsenault       ArgPtr = Builder.CreateConstInBoundsGEP1_64(
1447716075aSJames Y Knight           Builder.getInt8Ty(), KernArgSegment, AlignDownOffset,
1458c4a3523SMatt Arsenault           Arg.getName() + ".kernarg.offset.align.down");
14614359ef1SJames Y Knight       AdjustedArgTy = Builder.getInt32Ty();
1478c4a3523SMatt Arsenault     } else {
148f5be3ad7SMatt Arsenault       ArgPtr = Builder.CreateConstInBoundsGEP1_64(
1497716075aSJames Y Knight           Builder.getInt8Ty(), KernArgSegment, EltOffset,
1508c4a3523SMatt Arsenault           Arg.getName() + ".kernarg.offset");
15114359ef1SJames Y Knight       AdjustedArgTy = ArgTy;
1528c4a3523SMatt Arsenault     }
1538c4a3523SMatt Arsenault 
1548c4a3523SMatt Arsenault     if (IsV3 && Size >= 32) {
1558c4a3523SMatt Arsenault       V4Ty = VectorType::get(VT->getVectorElementType(), 4);
1568c4a3523SMatt Arsenault       // Use the hack that clang uses to avoid SelectionDAG ruining v3 loads
15714359ef1SJames Y Knight       AdjustedArgTy = V4Ty;
1588c4a3523SMatt Arsenault     }
1598c4a3523SMatt Arsenault 
16014359ef1SJames Y Knight     ArgPtr = Builder.CreateBitCast(ArgPtr, AdjustedArgTy->getPointerTo(AS),
16114359ef1SJames Y Knight                                    ArgPtr->getName() + ".cast");
16214359ef1SJames Y Knight     LoadInst *Load =
163*b65fa483SGuillaume Chatelet         Builder.CreateAlignedLoad(AdjustedArgTy, ArgPtr, AdjustedAlign.value());
1648c4a3523SMatt Arsenault     Load->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(Ctx, {}));
1658c4a3523SMatt Arsenault 
1668c4a3523SMatt Arsenault     MDBuilder MDB(Ctx);
1678c4a3523SMatt Arsenault 
1688c4a3523SMatt Arsenault     if (isa<PointerType>(ArgTy)) {
1698c4a3523SMatt Arsenault       if (Arg.hasNonNullAttr())
1708c4a3523SMatt Arsenault         Load->setMetadata(LLVMContext::MD_nonnull, MDNode::get(Ctx, {}));
1718c4a3523SMatt Arsenault 
1728c4a3523SMatt Arsenault       uint64_t DerefBytes = Arg.getDereferenceableBytes();
1738c4a3523SMatt Arsenault       if (DerefBytes != 0) {
1748c4a3523SMatt Arsenault         Load->setMetadata(
1758c4a3523SMatt Arsenault           LLVMContext::MD_dereferenceable,
1768c4a3523SMatt Arsenault           MDNode::get(Ctx,
1778c4a3523SMatt Arsenault                       MDB.createConstant(
1788c4a3523SMatt Arsenault                         ConstantInt::get(Builder.getInt64Ty(), DerefBytes))));
1798c4a3523SMatt Arsenault       }
1808c4a3523SMatt Arsenault 
1818c4a3523SMatt Arsenault       uint64_t DerefOrNullBytes = Arg.getDereferenceableOrNullBytes();
1828c4a3523SMatt Arsenault       if (DerefOrNullBytes != 0) {
1838c4a3523SMatt Arsenault         Load->setMetadata(
1848c4a3523SMatt Arsenault           LLVMContext::MD_dereferenceable_or_null,
1858c4a3523SMatt Arsenault           MDNode::get(Ctx,
1868c4a3523SMatt Arsenault                       MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
1878c4a3523SMatt Arsenault                                                           DerefOrNullBytes))));
1888c4a3523SMatt Arsenault       }
1898c4a3523SMatt Arsenault 
1908c4a3523SMatt Arsenault       unsigned ParamAlign = Arg.getParamAlignment();
1918c4a3523SMatt Arsenault       if (ParamAlign != 0) {
1928c4a3523SMatt Arsenault         Load->setMetadata(
1938c4a3523SMatt Arsenault           LLVMContext::MD_align,
1948c4a3523SMatt Arsenault           MDNode::get(Ctx,
1958c4a3523SMatt Arsenault                       MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
1968c4a3523SMatt Arsenault                                                           ParamAlign))));
1978c4a3523SMatt Arsenault       }
1988c4a3523SMatt Arsenault     }
1998c4a3523SMatt Arsenault 
2008c4a3523SMatt Arsenault     // TODO: Convert noalias arg to !noalias
2018c4a3523SMatt Arsenault 
202b5613ecfSMatt Arsenault     if (DoShiftOpt) {
2038c4a3523SMatt Arsenault       Value *ExtractBits = OffsetDiff == 0 ?
2048c4a3523SMatt Arsenault         Load : Builder.CreateLShr(Load, OffsetDiff * 8);
2058c4a3523SMatt Arsenault 
2068c4a3523SMatt Arsenault       IntegerType *ArgIntTy = Builder.getIntNTy(Size);
2078c4a3523SMatt Arsenault       Value *Trunc = Builder.CreateTrunc(ExtractBits, ArgIntTy);
2088c4a3523SMatt Arsenault       Value *NewVal = Builder.CreateBitCast(Trunc, ArgTy,
2098c4a3523SMatt Arsenault                                             Arg.getName() + ".load");
2108c4a3523SMatt Arsenault       Arg.replaceAllUsesWith(NewVal);
2118c4a3523SMatt Arsenault     } else if (IsV3) {
2128c4a3523SMatt Arsenault       Value *Shuf = Builder.CreateShuffleVector(Load, UndefValue::get(V4Ty),
2138c4a3523SMatt Arsenault                                                 {0, 1, 2},
2148c4a3523SMatt Arsenault                                                 Arg.getName() + ".load");
2158c4a3523SMatt Arsenault       Arg.replaceAllUsesWith(Shuf);
2168c4a3523SMatt Arsenault     } else {
2178c4a3523SMatt Arsenault       Load->setName(Arg.getName() + ".load");
2188c4a3523SMatt Arsenault       Arg.replaceAllUsesWith(Load);
2198c4a3523SMatt Arsenault     }
2208c4a3523SMatt Arsenault   }
2218c4a3523SMatt Arsenault 
222f5be3ad7SMatt Arsenault   KernArgSegment->addAttribute(
223f5be3ad7SMatt Arsenault       AttributeList::ReturnIndex,
224f5be3ad7SMatt Arsenault       Attribute::getWithAlignment(Ctx, std::max(KernArgBaseAlign, MaxAlign)));
225f5be3ad7SMatt Arsenault 
2268c4a3523SMatt Arsenault   return true;
2278c4a3523SMatt Arsenault }
2288c4a3523SMatt Arsenault 
2298c4a3523SMatt Arsenault INITIALIZE_PASS_BEGIN(AMDGPULowerKernelArguments, DEBUG_TYPE,
2308c4a3523SMatt Arsenault                       "AMDGPU Lower Kernel Arguments", false, false)
2318c4a3523SMatt Arsenault INITIALIZE_PASS_END(AMDGPULowerKernelArguments, DEBUG_TYPE, "AMDGPU Lower Kernel Arguments",
2328c4a3523SMatt Arsenault                     false, false)
2338c4a3523SMatt Arsenault 
2348c4a3523SMatt Arsenault char AMDGPULowerKernelArguments::ID = 0;
2358c4a3523SMatt Arsenault 
2368c4a3523SMatt Arsenault FunctionPass *llvm::createAMDGPULowerKernelArgumentsPass() {
2378c4a3523SMatt Arsenault   return new AMDGPULowerKernelArguments();
2388c4a3523SMatt Arsenault }
239