18c4a3523SMatt Arsenault //===-- AMDGPULowerKernelArguments.cpp ------------------------------------------===//
28c4a3523SMatt Arsenault //
38c4a3523SMatt Arsenault //                     The LLVM Compiler Infrastructure
48c4a3523SMatt Arsenault //
58c4a3523SMatt Arsenault // This file is distributed under the University of Illinois Open Source
68c4a3523SMatt Arsenault // License. See LICENSE.TXT for details.
78c4a3523SMatt Arsenault //
88c4a3523SMatt Arsenault //===----------------------------------------------------------------------===//
98c4a3523SMatt Arsenault //
108c4a3523SMatt Arsenault /// \file This pass replaces accesses to kernel arguments with loads from
118c4a3523SMatt Arsenault /// offsets from the kernarg base pointer.
128c4a3523SMatt Arsenault //
138c4a3523SMatt Arsenault //===----------------------------------------------------------------------===//
148c4a3523SMatt Arsenault 
158c4a3523SMatt Arsenault #include "AMDGPU.h"
168c4a3523SMatt Arsenault #include "AMDGPUSubtarget.h"
178c4a3523SMatt Arsenault #include "AMDGPUTargetMachine.h"
188c4a3523SMatt Arsenault #include "llvm/ADT/StringRef.h"
198c4a3523SMatt Arsenault #include "llvm/Analysis/Loads.h"
208c4a3523SMatt Arsenault #include "llvm/CodeGen/Passes.h"
218c4a3523SMatt Arsenault #include "llvm/CodeGen/TargetPassConfig.h"
228c4a3523SMatt Arsenault #include "llvm/IR/Attributes.h"
238c4a3523SMatt Arsenault #include "llvm/IR/BasicBlock.h"
248c4a3523SMatt Arsenault #include "llvm/IR/Constants.h"
258c4a3523SMatt Arsenault #include "llvm/IR/DerivedTypes.h"
268c4a3523SMatt Arsenault #include "llvm/IR/Function.h"
278c4a3523SMatt Arsenault #include "llvm/IR/IRBuilder.h"
288c4a3523SMatt Arsenault #include "llvm/IR/InstrTypes.h"
298c4a3523SMatt Arsenault #include "llvm/IR/Instruction.h"
308c4a3523SMatt Arsenault #include "llvm/IR/Instructions.h"
318c4a3523SMatt Arsenault #include "llvm/IR/LLVMContext.h"
328c4a3523SMatt Arsenault #include "llvm/IR/MDBuilder.h"
338c4a3523SMatt Arsenault #include "llvm/IR/Metadata.h"
348c4a3523SMatt Arsenault #include "llvm/IR/Operator.h"
358c4a3523SMatt Arsenault #include "llvm/IR/Type.h"
368c4a3523SMatt Arsenault #include "llvm/IR/Value.h"
378c4a3523SMatt Arsenault #include "llvm/Pass.h"
388c4a3523SMatt Arsenault #include "llvm/Support/Casting.h"
398c4a3523SMatt Arsenault 
408c4a3523SMatt Arsenault #define DEBUG_TYPE "amdgpu-lower-kernel-arguments"
418c4a3523SMatt Arsenault 
428c4a3523SMatt Arsenault using namespace llvm;
438c4a3523SMatt Arsenault 
448c4a3523SMatt Arsenault namespace {
458c4a3523SMatt Arsenault 
468c4a3523SMatt Arsenault class AMDGPULowerKernelArguments : public FunctionPass{
478c4a3523SMatt Arsenault public:
488c4a3523SMatt Arsenault   static char ID;
498c4a3523SMatt Arsenault 
508c4a3523SMatt Arsenault   AMDGPULowerKernelArguments() : FunctionPass(ID) {}
518c4a3523SMatt Arsenault 
528c4a3523SMatt Arsenault   bool runOnFunction(Function &F) override;
538c4a3523SMatt Arsenault 
548c4a3523SMatt Arsenault   void getAnalysisUsage(AnalysisUsage &AU) const override {
558c4a3523SMatt Arsenault     AU.addRequired<TargetPassConfig>();
568c4a3523SMatt Arsenault     AU.setPreservesAll();
578c4a3523SMatt Arsenault  }
588c4a3523SMatt Arsenault };
598c4a3523SMatt Arsenault 
608c4a3523SMatt Arsenault } // end anonymous namespace
618c4a3523SMatt Arsenault 
628c4a3523SMatt Arsenault bool AMDGPULowerKernelArguments::runOnFunction(Function &F) {
638c4a3523SMatt Arsenault   CallingConv::ID CC = F.getCallingConv();
648c4a3523SMatt Arsenault   if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty())
658c4a3523SMatt Arsenault     return false;
668c4a3523SMatt Arsenault 
678c4a3523SMatt Arsenault   auto &TPC = getAnalysis<TargetPassConfig>();
688c4a3523SMatt Arsenault 
698c4a3523SMatt Arsenault   const TargetMachine &TM = TPC.getTM<TargetMachine>();
705bfbae5cSTom Stellard   const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
718c4a3523SMatt Arsenault   LLVMContext &Ctx = F.getParent()->getContext();
728c4a3523SMatt Arsenault   const DataLayout &DL = F.getParent()->getDataLayout();
738c4a3523SMatt Arsenault   BasicBlock &EntryBlock = *F.begin();
748c4a3523SMatt Arsenault   IRBuilder<> Builder(&*EntryBlock.begin());
758c4a3523SMatt Arsenault 
76f5be3ad7SMatt Arsenault   const unsigned KernArgBaseAlign = 16; // FIXME: Increase if necessary
778c4a3523SMatt Arsenault   const uint64_t BaseOffset = ST.getExplicitKernelArgOffset(F);
788c4a3523SMatt Arsenault 
794bec7d42SMatt Arsenault   unsigned MaxAlign;
808c4a3523SMatt Arsenault   // FIXME: Alignment is broken broken with explicit arg offset.;
814bec7d42SMatt Arsenault   const uint64_t TotalKernArgSize = ST.getKernArgSegmentSize(F, MaxAlign);
82513e0c0eSMatt Arsenault   if (TotalKernArgSize == 0)
83513e0c0eSMatt Arsenault     return false;
848c4a3523SMatt Arsenault 
858c4a3523SMatt Arsenault   CallInst *KernArgSegment =
8657f5d0a8SNeil Henning       Builder.CreateIntrinsic(Intrinsic::amdgcn_kernarg_segment_ptr, {}, {},
8757f5d0a8SNeil Henning                               nullptr, F.getName() + ".kernarg.segment");
888c4a3523SMatt Arsenault 
898c4a3523SMatt Arsenault   KernArgSegment->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
908c4a3523SMatt Arsenault   KernArgSegment->addAttribute(AttributeList::ReturnIndex,
918c4a3523SMatt Arsenault     Attribute::getWithDereferenceableBytes(Ctx, TotalKernArgSize));
928c4a3523SMatt Arsenault 
938c4a3523SMatt Arsenault   unsigned AS = KernArgSegment->getType()->getPointerAddressSpace();
94f5be3ad7SMatt Arsenault   uint64_t ExplicitArgOffset = 0;
95f5be3ad7SMatt Arsenault 
968c4a3523SMatt Arsenault   for (Argument &Arg : F.args()) {
97f5be3ad7SMatt Arsenault     Type *ArgTy = Arg.getType();
98f5be3ad7SMatt Arsenault     unsigned Align = DL.getABITypeAlignment(ArgTy);
99f5be3ad7SMatt Arsenault     unsigned Size = DL.getTypeSizeInBits(ArgTy);
100f5be3ad7SMatt Arsenault     unsigned AllocSize = DL.getTypeAllocSize(ArgTy);
101f5be3ad7SMatt Arsenault 
102f5be3ad7SMatt Arsenault     uint64_t EltOffset = alignTo(ExplicitArgOffset, Align) + BaseOffset;
103f5be3ad7SMatt Arsenault     ExplicitArgOffset = alignTo(ExplicitArgOffset, Align) + AllocSize;
104f5be3ad7SMatt Arsenault 
1058c4a3523SMatt Arsenault     if (Arg.use_empty())
1068c4a3523SMatt Arsenault       continue;
1078c4a3523SMatt Arsenault 
1088c4a3523SMatt Arsenault     if (PointerType *PT = dyn_cast<PointerType>(ArgTy)) {
1098c4a3523SMatt Arsenault       // FIXME: Hack. We rely on AssertZext to be able to fold DS addressing
1108c4a3523SMatt Arsenault       // modes on SI to know the high bits are 0 so pointer adds don't wrap. We
1118c4a3523SMatt Arsenault       // can't represent this with range metadata because it's only allowed for
1128c4a3523SMatt Arsenault       // integer types.
1138c4a3523SMatt Arsenault       if (PT->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS &&
1148c4a3523SMatt Arsenault           ST.getGeneration() == AMDGPUSubtarget::SOUTHERN_ISLANDS)
1158c4a3523SMatt Arsenault         continue;
1168c4a3523SMatt Arsenault 
1178c4a3523SMatt Arsenault       // FIXME: We can replace this with equivalent alias.scope/noalias
1188c4a3523SMatt Arsenault       // metadata, but this appears to be a lot of work.
1198c4a3523SMatt Arsenault       if (Arg.hasNoAliasAttr())
1208c4a3523SMatt Arsenault         continue;
1218c4a3523SMatt Arsenault     }
1228c4a3523SMatt Arsenault 
1238c4a3523SMatt Arsenault     VectorType *VT = dyn_cast<VectorType>(ArgTy);
1248c4a3523SMatt Arsenault     bool IsV3 = VT && VT->getNumElements() == 3;
125*b5613ecfSMatt Arsenault     bool DoShiftOpt = Size < 32 && !ArgTy->isAggregateType();
126*b5613ecfSMatt Arsenault 
1278c4a3523SMatt Arsenault     VectorType *V4Ty = nullptr;
1288c4a3523SMatt Arsenault 
1298c4a3523SMatt Arsenault     int64_t AlignDownOffset = alignDown(EltOffset, 4);
1308c4a3523SMatt Arsenault     int64_t OffsetDiff = EltOffset - AlignDownOffset;
131*b5613ecfSMatt Arsenault     unsigned AdjustedAlign = MinAlign(DoShiftOpt ? AlignDownOffset : EltOffset,
132*b5613ecfSMatt Arsenault                                       KernArgBaseAlign);
1338c4a3523SMatt Arsenault 
1348c4a3523SMatt Arsenault     Value *ArgPtr;
135*b5613ecfSMatt Arsenault     if (DoShiftOpt) { // FIXME: Handle aggregate types
1368c4a3523SMatt Arsenault       // Since we don't have sub-dword scalar loads, avoid doing an extload by
1378c4a3523SMatt Arsenault       // loading earlier than the argument address, and extracting the relevant
1388c4a3523SMatt Arsenault       // bits.
1398c4a3523SMatt Arsenault       //
1408c4a3523SMatt Arsenault       // Additionally widen any sub-dword load to i32 even if suitably aligned,
1418c4a3523SMatt Arsenault       // so that CSE between different argument loads works easily.
1428c4a3523SMatt Arsenault 
143f5be3ad7SMatt Arsenault       ArgPtr = Builder.CreateConstInBoundsGEP1_64(
144f5be3ad7SMatt Arsenault         KernArgSegment,
145f5be3ad7SMatt Arsenault         AlignDownOffset,
1468c4a3523SMatt Arsenault         Arg.getName() + ".kernarg.offset.align.down");
147f5be3ad7SMatt Arsenault       ArgPtr = Builder.CreateBitCast(ArgPtr,
148f5be3ad7SMatt Arsenault                                      Builder.getInt32Ty()->getPointerTo(AS),
149f5be3ad7SMatt Arsenault                                      ArgPtr->getName() + ".cast");
1508c4a3523SMatt Arsenault     } else {
151f5be3ad7SMatt Arsenault       ArgPtr = Builder.CreateConstInBoundsGEP1_64(
152f5be3ad7SMatt Arsenault         KernArgSegment,
153*b5613ecfSMatt Arsenault         EltOffset,
1548c4a3523SMatt Arsenault         Arg.getName() + ".kernarg.offset");
155f5be3ad7SMatt Arsenault       ArgPtr = Builder.CreateBitCast(ArgPtr, ArgTy->getPointerTo(AS),
156f5be3ad7SMatt Arsenault                                      ArgPtr->getName() + ".cast");
1578c4a3523SMatt Arsenault     }
1588c4a3523SMatt Arsenault 
1598c4a3523SMatt Arsenault     if (IsV3 && Size >= 32) {
1608c4a3523SMatt Arsenault       V4Ty = VectorType::get(VT->getVectorElementType(), 4);
1618c4a3523SMatt Arsenault       // Use the hack that clang uses to avoid SelectionDAG ruining v3 loads
1628c4a3523SMatt Arsenault       ArgPtr = Builder.CreateBitCast(ArgPtr, V4Ty->getPointerTo(AS));
1638c4a3523SMatt Arsenault     }
1648c4a3523SMatt Arsenault 
1658c4a3523SMatt Arsenault     LoadInst *Load = Builder.CreateAlignedLoad(ArgPtr, AdjustedAlign);
1668c4a3523SMatt Arsenault     Load->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(Ctx, {}));
1678c4a3523SMatt Arsenault 
1688c4a3523SMatt Arsenault     MDBuilder MDB(Ctx);
1698c4a3523SMatt Arsenault 
1708c4a3523SMatt Arsenault     if (isa<PointerType>(ArgTy)) {
1718c4a3523SMatt Arsenault       if (Arg.hasNonNullAttr())
1728c4a3523SMatt Arsenault         Load->setMetadata(LLVMContext::MD_nonnull, MDNode::get(Ctx, {}));
1738c4a3523SMatt Arsenault 
1748c4a3523SMatt Arsenault       uint64_t DerefBytes = Arg.getDereferenceableBytes();
1758c4a3523SMatt Arsenault       if (DerefBytes != 0) {
1768c4a3523SMatt Arsenault         Load->setMetadata(
1778c4a3523SMatt Arsenault           LLVMContext::MD_dereferenceable,
1788c4a3523SMatt Arsenault           MDNode::get(Ctx,
1798c4a3523SMatt Arsenault                       MDB.createConstant(
1808c4a3523SMatt Arsenault                         ConstantInt::get(Builder.getInt64Ty(), DerefBytes))));
1818c4a3523SMatt Arsenault       }
1828c4a3523SMatt Arsenault 
1838c4a3523SMatt Arsenault       uint64_t DerefOrNullBytes = Arg.getDereferenceableOrNullBytes();
1848c4a3523SMatt Arsenault       if (DerefOrNullBytes != 0) {
1858c4a3523SMatt Arsenault         Load->setMetadata(
1868c4a3523SMatt Arsenault           LLVMContext::MD_dereferenceable_or_null,
1878c4a3523SMatt Arsenault           MDNode::get(Ctx,
1888c4a3523SMatt Arsenault                       MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
1898c4a3523SMatt Arsenault                                                           DerefOrNullBytes))));
1908c4a3523SMatt Arsenault       }
1918c4a3523SMatt Arsenault 
1928c4a3523SMatt Arsenault       unsigned ParamAlign = Arg.getParamAlignment();
1938c4a3523SMatt Arsenault       if (ParamAlign != 0) {
1948c4a3523SMatt Arsenault         Load->setMetadata(
1958c4a3523SMatt Arsenault           LLVMContext::MD_align,
1968c4a3523SMatt Arsenault           MDNode::get(Ctx,
1978c4a3523SMatt Arsenault                       MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
1988c4a3523SMatt Arsenault                                                           ParamAlign))));
1998c4a3523SMatt Arsenault       }
2008c4a3523SMatt Arsenault     }
2018c4a3523SMatt Arsenault 
2028c4a3523SMatt Arsenault     // TODO: Convert noalias arg to !noalias
2038c4a3523SMatt Arsenault 
204*b5613ecfSMatt Arsenault     if (DoShiftOpt) {
2058c4a3523SMatt Arsenault       Value *ExtractBits = OffsetDiff == 0 ?
2068c4a3523SMatt Arsenault         Load : Builder.CreateLShr(Load, OffsetDiff * 8);
2078c4a3523SMatt Arsenault 
2088c4a3523SMatt Arsenault       IntegerType *ArgIntTy = Builder.getIntNTy(Size);
2098c4a3523SMatt Arsenault       Value *Trunc = Builder.CreateTrunc(ExtractBits, ArgIntTy);
2108c4a3523SMatt Arsenault       Value *NewVal = Builder.CreateBitCast(Trunc, ArgTy,
2118c4a3523SMatt Arsenault                                             Arg.getName() + ".load");
2128c4a3523SMatt Arsenault       Arg.replaceAllUsesWith(NewVal);
2138c4a3523SMatt Arsenault     } else if (IsV3) {
2148c4a3523SMatt Arsenault       Value *Shuf = Builder.CreateShuffleVector(Load, UndefValue::get(V4Ty),
2158c4a3523SMatt Arsenault                                                 {0, 1, 2},
2168c4a3523SMatt Arsenault                                                 Arg.getName() + ".load");
2178c4a3523SMatt Arsenault       Arg.replaceAllUsesWith(Shuf);
2188c4a3523SMatt Arsenault     } else {
2198c4a3523SMatt Arsenault       Load->setName(Arg.getName() + ".load");
2208c4a3523SMatt Arsenault       Arg.replaceAllUsesWith(Load);
2218c4a3523SMatt Arsenault     }
2228c4a3523SMatt Arsenault   }
2238c4a3523SMatt Arsenault 
224f5be3ad7SMatt Arsenault   KernArgSegment->addAttribute(
225f5be3ad7SMatt Arsenault     AttributeList::ReturnIndex,
226f5be3ad7SMatt Arsenault     Attribute::getWithAlignment(Ctx, std::max(KernArgBaseAlign, MaxAlign)));
227f5be3ad7SMatt Arsenault 
2288c4a3523SMatt Arsenault   return true;
2298c4a3523SMatt Arsenault }
2308c4a3523SMatt Arsenault 
2318c4a3523SMatt Arsenault INITIALIZE_PASS_BEGIN(AMDGPULowerKernelArguments, DEBUG_TYPE,
2328c4a3523SMatt Arsenault                       "AMDGPU Lower Kernel Arguments", false, false)
2338c4a3523SMatt Arsenault INITIALIZE_PASS_END(AMDGPULowerKernelArguments, DEBUG_TYPE, "AMDGPU Lower Kernel Arguments",
2348c4a3523SMatt Arsenault                     false, false)
2358c4a3523SMatt Arsenault 
2368c4a3523SMatt Arsenault char AMDGPULowerKernelArguments::ID = 0;
2378c4a3523SMatt Arsenault 
2388c4a3523SMatt Arsenault FunctionPass *llvm::createAMDGPULowerKernelArgumentsPass() {
2398c4a3523SMatt Arsenault   return new AMDGPULowerKernelArguments();
2408c4a3523SMatt Arsenault }
241