18c4a3523SMatt Arsenault //===-- AMDGPULowerKernelArguments.cpp ------------------------------------------===//
28c4a3523SMatt Arsenault //
38c4a3523SMatt Arsenault //                     The LLVM Compiler Infrastructure
48c4a3523SMatt Arsenault //
58c4a3523SMatt Arsenault // This file is distributed under the University of Illinois Open Source
68c4a3523SMatt Arsenault // License. See LICENSE.TXT for details.
78c4a3523SMatt Arsenault //
88c4a3523SMatt Arsenault //===----------------------------------------------------------------------===//
98c4a3523SMatt Arsenault //
108c4a3523SMatt Arsenault /// \file This pass replaces accesses to kernel arguments with loads from
118c4a3523SMatt Arsenault /// offsets from the kernarg base pointer.
128c4a3523SMatt Arsenault //
138c4a3523SMatt Arsenault //===----------------------------------------------------------------------===//
148c4a3523SMatt Arsenault 
158c4a3523SMatt Arsenault #include "AMDGPU.h"
168c4a3523SMatt Arsenault #include "AMDGPUSubtarget.h"
178c4a3523SMatt Arsenault #include "AMDGPUTargetMachine.h"
188c4a3523SMatt Arsenault #include "llvm/ADT/StringRef.h"
198c4a3523SMatt Arsenault #include "llvm/Analysis/DivergenceAnalysis.h"
208c4a3523SMatt Arsenault #include "llvm/Analysis/Loads.h"
218c4a3523SMatt Arsenault #include "llvm/CodeGen/Passes.h"
228c4a3523SMatt Arsenault #include "llvm/CodeGen/TargetPassConfig.h"
238c4a3523SMatt Arsenault #include "llvm/IR/Attributes.h"
248c4a3523SMatt Arsenault #include "llvm/IR/BasicBlock.h"
258c4a3523SMatt Arsenault #include "llvm/IR/Constants.h"
268c4a3523SMatt Arsenault #include "llvm/IR/DerivedTypes.h"
278c4a3523SMatt Arsenault #include "llvm/IR/Function.h"
288c4a3523SMatt Arsenault #include "llvm/IR/IRBuilder.h"
298c4a3523SMatt Arsenault #include "llvm/IR/InstrTypes.h"
308c4a3523SMatt Arsenault #include "llvm/IR/Instruction.h"
318c4a3523SMatt Arsenault #include "llvm/IR/Instructions.h"
328c4a3523SMatt Arsenault #include "llvm/IR/LLVMContext.h"
338c4a3523SMatt Arsenault #include "llvm/IR/MDBuilder.h"
348c4a3523SMatt Arsenault #include "llvm/IR/Metadata.h"
358c4a3523SMatt Arsenault #include "llvm/IR/Operator.h"
368c4a3523SMatt Arsenault #include "llvm/IR/Type.h"
378c4a3523SMatt Arsenault #include "llvm/IR/Value.h"
388c4a3523SMatt Arsenault #include "llvm/Pass.h"
398c4a3523SMatt Arsenault #include "llvm/Support/Casting.h"
408c4a3523SMatt Arsenault 
418c4a3523SMatt Arsenault #define DEBUG_TYPE "amdgpu-lower-kernel-arguments"
428c4a3523SMatt Arsenault 
438c4a3523SMatt Arsenault using namespace llvm;
448c4a3523SMatt Arsenault 
458c4a3523SMatt Arsenault namespace {
468c4a3523SMatt Arsenault 
478c4a3523SMatt Arsenault class AMDGPULowerKernelArguments : public FunctionPass{
488c4a3523SMatt Arsenault public:
498c4a3523SMatt Arsenault   static char ID;
508c4a3523SMatt Arsenault 
518c4a3523SMatt Arsenault   AMDGPULowerKernelArguments() : FunctionPass(ID) {}
528c4a3523SMatt Arsenault 
538c4a3523SMatt Arsenault   bool runOnFunction(Function &F) override;
548c4a3523SMatt Arsenault 
558c4a3523SMatt Arsenault   void getAnalysisUsage(AnalysisUsage &AU) const override {
568c4a3523SMatt Arsenault     AU.addRequired<TargetPassConfig>();
578c4a3523SMatt Arsenault     AU.setPreservesAll();
588c4a3523SMatt Arsenault  }
598c4a3523SMatt Arsenault };
608c4a3523SMatt Arsenault 
618c4a3523SMatt Arsenault } // end anonymous namespace
628c4a3523SMatt Arsenault 
638c4a3523SMatt Arsenault bool AMDGPULowerKernelArguments::runOnFunction(Function &F) {
648c4a3523SMatt Arsenault   CallingConv::ID CC = F.getCallingConv();
658c4a3523SMatt Arsenault   if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty())
668c4a3523SMatt Arsenault     return false;
678c4a3523SMatt Arsenault 
688c4a3523SMatt Arsenault   auto &TPC = getAnalysis<TargetPassConfig>();
698c4a3523SMatt Arsenault 
708c4a3523SMatt Arsenault   const TargetMachine &TM = TPC.getTM<TargetMachine>();
71*5bfbae5cSTom Stellard   const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
728c4a3523SMatt Arsenault   LLVMContext &Ctx = F.getParent()->getContext();
738c4a3523SMatt Arsenault   const DataLayout &DL = F.getParent()->getDataLayout();
748c4a3523SMatt Arsenault   BasicBlock &EntryBlock = *F.begin();
758c4a3523SMatt Arsenault   IRBuilder<> Builder(&*EntryBlock.begin());
768c4a3523SMatt Arsenault 
77f5be3ad7SMatt Arsenault   const unsigned KernArgBaseAlign = 16; // FIXME: Increase if necessary
788c4a3523SMatt Arsenault   const uint64_t BaseOffset = ST.getExplicitKernelArgOffset(F);
798c4a3523SMatt Arsenault 
808c4a3523SMatt Arsenault   // FIXME: Alignment is broken broken with explicit arg offset.;
81f5be3ad7SMatt Arsenault   const uint64_t TotalKernArgSize = ST.getKernArgSegmentSize(F);
82513e0c0eSMatt Arsenault   if (TotalKernArgSize == 0)
83513e0c0eSMatt Arsenault     return false;
848c4a3523SMatt Arsenault 
858c4a3523SMatt Arsenault   CallInst *KernArgSegment =
868c4a3523SMatt Arsenault     Builder.CreateIntrinsic(Intrinsic::amdgcn_kernarg_segment_ptr, nullptr,
878c4a3523SMatt Arsenault                             F.getName() + ".kernarg.segment");
888c4a3523SMatt Arsenault 
898c4a3523SMatt Arsenault   KernArgSegment->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
908c4a3523SMatt Arsenault   KernArgSegment->addAttribute(AttributeList::ReturnIndex,
918c4a3523SMatt Arsenault     Attribute::getWithDereferenceableBytes(Ctx, TotalKernArgSize));
928c4a3523SMatt Arsenault 
938c4a3523SMatt Arsenault   unsigned AS = KernArgSegment->getType()->getPointerAddressSpace();
94f5be3ad7SMatt Arsenault   unsigned MaxAlign = 1;
95f5be3ad7SMatt Arsenault   uint64_t ExplicitArgOffset = 0;
96f5be3ad7SMatt Arsenault 
978c4a3523SMatt Arsenault   for (Argument &Arg : F.args()) {
98f5be3ad7SMatt Arsenault     Type *ArgTy = Arg.getType();
99f5be3ad7SMatt Arsenault     unsigned Align = DL.getABITypeAlignment(ArgTy);
100f5be3ad7SMatt Arsenault     MaxAlign = std::max(Align, MaxAlign);
101f5be3ad7SMatt Arsenault     unsigned Size = DL.getTypeSizeInBits(ArgTy);
102f5be3ad7SMatt Arsenault     unsigned AllocSize = DL.getTypeAllocSize(ArgTy);
103f5be3ad7SMatt Arsenault 
104f5be3ad7SMatt Arsenault 
105f5be3ad7SMatt Arsenault     // Clover seems to always pad i8/i16 to i32, but doesn't properly align
106f5be3ad7SMatt Arsenault     // them?
107f5be3ad7SMatt Arsenault     // Make sure the struct elements have correct size and alignment for ext
108f5be3ad7SMatt Arsenault     // args. These seem to be padded up to 4-bytes but not correctly aligned.
109f5be3ad7SMatt Arsenault     bool IsExtArg = AllocSize < 32 && (Arg.hasZExtAttr() || Arg.hasSExtAttr()) &&
110f5be3ad7SMatt Arsenault                     !ST.isAmdHsaOS();
111f5be3ad7SMatt Arsenault     if (IsExtArg)
112f5be3ad7SMatt Arsenault       AllocSize = 4;
113f5be3ad7SMatt Arsenault 
114f5be3ad7SMatt Arsenault     uint64_t EltOffset = alignTo(ExplicitArgOffset, Align) + BaseOffset;
115f5be3ad7SMatt Arsenault     ExplicitArgOffset = alignTo(ExplicitArgOffset, Align) + AllocSize;
116f5be3ad7SMatt Arsenault 
1178c4a3523SMatt Arsenault     if (Arg.use_empty())
1188c4a3523SMatt Arsenault       continue;
1198c4a3523SMatt Arsenault 
1208c4a3523SMatt Arsenault     if (PointerType *PT = dyn_cast<PointerType>(ArgTy)) {
1218c4a3523SMatt Arsenault       // FIXME: Hack. We rely on AssertZext to be able to fold DS addressing
1228c4a3523SMatt Arsenault       // modes on SI to know the high bits are 0 so pointer adds don't wrap. We
1238c4a3523SMatt Arsenault       // can't represent this with range metadata because it's only allowed for
1248c4a3523SMatt Arsenault       // integer types.
1258c4a3523SMatt Arsenault       if (PT->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS &&
1268c4a3523SMatt Arsenault           ST.getGeneration() == AMDGPUSubtarget::SOUTHERN_ISLANDS)
1278c4a3523SMatt Arsenault         continue;
1288c4a3523SMatt Arsenault 
1298c4a3523SMatt Arsenault       // FIXME: We can replace this with equivalent alias.scope/noalias
1308c4a3523SMatt Arsenault       // metadata, but this appears to be a lot of work.
1318c4a3523SMatt Arsenault       if (Arg.hasNoAliasAttr())
1328c4a3523SMatt Arsenault         continue;
1338c4a3523SMatt Arsenault     }
1348c4a3523SMatt Arsenault 
1358c4a3523SMatt Arsenault     VectorType *VT = dyn_cast<VectorType>(ArgTy);
1368c4a3523SMatt Arsenault     bool IsV3 = VT && VT->getNumElements() == 3;
1378c4a3523SMatt Arsenault     VectorType *V4Ty = nullptr;
1388c4a3523SMatt Arsenault 
1398c4a3523SMatt Arsenault     int64_t AlignDownOffset = alignDown(EltOffset, 4);
1408c4a3523SMatt Arsenault     int64_t OffsetDiff = EltOffset - AlignDownOffset;
1418c4a3523SMatt Arsenault     unsigned AdjustedAlign = MinAlign(KernArgBaseAlign, AlignDownOffset);
1428c4a3523SMatt Arsenault 
1438c4a3523SMatt Arsenault     Value *ArgPtr;
144513e0c0eSMatt Arsenault     if (Size < 32 && !ArgTy->isAggregateType()) { // FIXME: Handle aggregate types
1458c4a3523SMatt Arsenault       // Since we don't have sub-dword scalar loads, avoid doing an extload by
1468c4a3523SMatt Arsenault       // loading earlier than the argument address, and extracting the relevant
1478c4a3523SMatt Arsenault       // bits.
1488c4a3523SMatt Arsenault       //
1498c4a3523SMatt Arsenault       // Additionally widen any sub-dword load to i32 even if suitably aligned,
1508c4a3523SMatt Arsenault       // so that CSE between different argument loads works easily.
1518c4a3523SMatt Arsenault 
152f5be3ad7SMatt Arsenault       ArgPtr = Builder.CreateConstInBoundsGEP1_64(
153f5be3ad7SMatt Arsenault         KernArgSegment,
154f5be3ad7SMatt Arsenault         AlignDownOffset,
1558c4a3523SMatt Arsenault         Arg.getName() + ".kernarg.offset.align.down");
156f5be3ad7SMatt Arsenault       ArgPtr = Builder.CreateBitCast(ArgPtr,
157f5be3ad7SMatt Arsenault                                      Builder.getInt32Ty()->getPointerTo(AS),
158f5be3ad7SMatt Arsenault                                      ArgPtr->getName() + ".cast");
1598c4a3523SMatt Arsenault     } else {
160f5be3ad7SMatt Arsenault       ArgPtr = Builder.CreateConstInBoundsGEP1_64(
161f5be3ad7SMatt Arsenault         KernArgSegment,
162f5be3ad7SMatt Arsenault         AlignDownOffset,
1638c4a3523SMatt Arsenault         Arg.getName() + ".kernarg.offset");
164f5be3ad7SMatt Arsenault       ArgPtr = Builder.CreateBitCast(ArgPtr, ArgTy->getPointerTo(AS),
165f5be3ad7SMatt Arsenault                                      ArgPtr->getName() + ".cast");
1668c4a3523SMatt Arsenault     }
1678c4a3523SMatt Arsenault 
1688c4a3523SMatt Arsenault     assert((!IsExtArg || !IsV3) && "incompatible situation");
1698c4a3523SMatt Arsenault 
1708c4a3523SMatt Arsenault     if (IsV3 && Size >= 32) {
1718c4a3523SMatt Arsenault       V4Ty = VectorType::get(VT->getVectorElementType(), 4);
1728c4a3523SMatt Arsenault       // Use the hack that clang uses to avoid SelectionDAG ruining v3 loads
1738c4a3523SMatt Arsenault       ArgPtr = Builder.CreateBitCast(ArgPtr, V4Ty->getPointerTo(AS));
1748c4a3523SMatt Arsenault     }
1758c4a3523SMatt Arsenault 
1768c4a3523SMatt Arsenault     LoadInst *Load = Builder.CreateAlignedLoad(ArgPtr, AdjustedAlign);
1778c4a3523SMatt Arsenault     Load->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(Ctx, {}));
1788c4a3523SMatt Arsenault 
1798c4a3523SMatt Arsenault     MDBuilder MDB(Ctx);
1808c4a3523SMatt Arsenault 
1818c4a3523SMatt Arsenault     if (isa<PointerType>(ArgTy)) {
1828c4a3523SMatt Arsenault       if (Arg.hasNonNullAttr())
1838c4a3523SMatt Arsenault         Load->setMetadata(LLVMContext::MD_nonnull, MDNode::get(Ctx, {}));
1848c4a3523SMatt Arsenault 
1858c4a3523SMatt Arsenault       uint64_t DerefBytes = Arg.getDereferenceableBytes();
1868c4a3523SMatt Arsenault       if (DerefBytes != 0) {
1878c4a3523SMatt Arsenault         Load->setMetadata(
1888c4a3523SMatt Arsenault           LLVMContext::MD_dereferenceable,
1898c4a3523SMatt Arsenault           MDNode::get(Ctx,
1908c4a3523SMatt Arsenault                       MDB.createConstant(
1918c4a3523SMatt Arsenault                         ConstantInt::get(Builder.getInt64Ty(), DerefBytes))));
1928c4a3523SMatt Arsenault       }
1938c4a3523SMatt Arsenault 
1948c4a3523SMatt Arsenault       uint64_t DerefOrNullBytes = Arg.getDereferenceableOrNullBytes();
1958c4a3523SMatt Arsenault       if (DerefOrNullBytes != 0) {
1968c4a3523SMatt Arsenault         Load->setMetadata(
1978c4a3523SMatt Arsenault           LLVMContext::MD_dereferenceable_or_null,
1988c4a3523SMatt Arsenault           MDNode::get(Ctx,
1998c4a3523SMatt Arsenault                       MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
2008c4a3523SMatt Arsenault                                                           DerefOrNullBytes))));
2018c4a3523SMatt Arsenault       }
2028c4a3523SMatt Arsenault 
2038c4a3523SMatt Arsenault       unsigned ParamAlign = Arg.getParamAlignment();
2048c4a3523SMatt Arsenault       if (ParamAlign != 0) {
2058c4a3523SMatt Arsenault         Load->setMetadata(
2068c4a3523SMatt Arsenault           LLVMContext::MD_align,
2078c4a3523SMatt Arsenault           MDNode::get(Ctx,
2088c4a3523SMatt Arsenault                       MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
2098c4a3523SMatt Arsenault                                                           ParamAlign))));
2108c4a3523SMatt Arsenault       }
2118c4a3523SMatt Arsenault     }
2128c4a3523SMatt Arsenault 
2138c4a3523SMatt Arsenault     // TODO: Convert noalias arg to !noalias
2148c4a3523SMatt Arsenault 
215513e0c0eSMatt Arsenault     if (Size < 32 && !ArgTy->isAggregateType()) {
2168c4a3523SMatt Arsenault       if (IsExtArg && OffsetDiff == 0) {
2178c4a3523SMatt Arsenault         Type *I32Ty = Builder.getInt32Ty();
2188c4a3523SMatt Arsenault         bool IsSext = Arg.hasSExtAttr();
2198c4a3523SMatt Arsenault         Metadata *LowAndHigh[] = {
2208c4a3523SMatt Arsenault           ConstantAsMetadata::get(
2218c4a3523SMatt Arsenault             ConstantInt::get(I32Ty, IsSext ? minIntN(Size) : 0)),
2228c4a3523SMatt Arsenault           ConstantAsMetadata::get(
2238c4a3523SMatt Arsenault             ConstantInt::get(I32Ty,
2248c4a3523SMatt Arsenault                              IsSext ? maxIntN(Size) + 1 : maxUIntN(Size) + 1))
2258c4a3523SMatt Arsenault         };
2268c4a3523SMatt Arsenault 
2278c4a3523SMatt Arsenault         Load->setMetadata(LLVMContext::MD_range, MDNode::get(Ctx, LowAndHigh));
2288c4a3523SMatt Arsenault       }
2298c4a3523SMatt Arsenault 
2308c4a3523SMatt Arsenault       Value *ExtractBits = OffsetDiff == 0 ?
2318c4a3523SMatt Arsenault         Load : Builder.CreateLShr(Load, OffsetDiff * 8);
2328c4a3523SMatt Arsenault 
2338c4a3523SMatt Arsenault       IntegerType *ArgIntTy = Builder.getIntNTy(Size);
2348c4a3523SMatt Arsenault       Value *Trunc = Builder.CreateTrunc(ExtractBits, ArgIntTy);
2358c4a3523SMatt Arsenault       Value *NewVal = Builder.CreateBitCast(Trunc, ArgTy,
2368c4a3523SMatt Arsenault                                             Arg.getName() + ".load");
2378c4a3523SMatt Arsenault       Arg.replaceAllUsesWith(NewVal);
2388c4a3523SMatt Arsenault     } else if (IsV3) {
2398c4a3523SMatt Arsenault       Value *Shuf = Builder.CreateShuffleVector(Load, UndefValue::get(V4Ty),
2408c4a3523SMatt Arsenault                                                 {0, 1, 2},
2418c4a3523SMatt Arsenault                                                 Arg.getName() + ".load");
2428c4a3523SMatt Arsenault       Arg.replaceAllUsesWith(Shuf);
2438c4a3523SMatt Arsenault     } else {
2448c4a3523SMatt Arsenault       Load->setName(Arg.getName() + ".load");
2458c4a3523SMatt Arsenault       Arg.replaceAllUsesWith(Load);
2468c4a3523SMatt Arsenault     }
2478c4a3523SMatt Arsenault   }
2488c4a3523SMatt Arsenault 
249f5be3ad7SMatt Arsenault   KernArgSegment->addAttribute(
250f5be3ad7SMatt Arsenault     AttributeList::ReturnIndex,
251f5be3ad7SMatt Arsenault     Attribute::getWithAlignment(Ctx, std::max(KernArgBaseAlign, MaxAlign)));
252f5be3ad7SMatt Arsenault 
2538c4a3523SMatt Arsenault   return true;
2548c4a3523SMatt Arsenault }
2558c4a3523SMatt Arsenault 
2568c4a3523SMatt Arsenault INITIALIZE_PASS_BEGIN(AMDGPULowerKernelArguments, DEBUG_TYPE,
2578c4a3523SMatt Arsenault                       "AMDGPU Lower Kernel Arguments", false, false)
2588c4a3523SMatt Arsenault INITIALIZE_PASS_END(AMDGPULowerKernelArguments, DEBUG_TYPE, "AMDGPU Lower Kernel Arguments",
2598c4a3523SMatt Arsenault                     false, false)
2608c4a3523SMatt Arsenault 
2618c4a3523SMatt Arsenault char AMDGPULowerKernelArguments::ID = 0;
2628c4a3523SMatt Arsenault 
2638c4a3523SMatt Arsenault FunctionPass *llvm::createAMDGPULowerKernelArgumentsPass() {
2648c4a3523SMatt Arsenault   return new AMDGPULowerKernelArguments();
2658c4a3523SMatt Arsenault }
266