//===-- AMDGPULowerKernelArguments.cpp ------------------------------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
/// \file This pass replaces accesses to kernel arguments with loads from
/// offsets from the kernarg base pointer.
//
//===----------------------------------------------------------------------===//
144ba319b5SDimitry Andric
154ba319b5SDimitry Andric #include "AMDGPU.h"
164ba319b5SDimitry Andric #include "AMDGPUSubtarget.h"
174ba319b5SDimitry Andric #include "AMDGPUTargetMachine.h"
184ba319b5SDimitry Andric #include "llvm/ADT/StringRef.h"
194ba319b5SDimitry Andric #include "llvm/Analysis/Loads.h"
204ba319b5SDimitry Andric #include "llvm/CodeGen/Passes.h"
214ba319b5SDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h"
224ba319b5SDimitry Andric #include "llvm/IR/Attributes.h"
234ba319b5SDimitry Andric #include "llvm/IR/BasicBlock.h"
244ba319b5SDimitry Andric #include "llvm/IR/Constants.h"
254ba319b5SDimitry Andric #include "llvm/IR/DerivedTypes.h"
264ba319b5SDimitry Andric #include "llvm/IR/Function.h"
274ba319b5SDimitry Andric #include "llvm/IR/IRBuilder.h"
284ba319b5SDimitry Andric #include "llvm/IR/InstrTypes.h"
294ba319b5SDimitry Andric #include "llvm/IR/Instruction.h"
304ba319b5SDimitry Andric #include "llvm/IR/Instructions.h"
314ba319b5SDimitry Andric #include "llvm/IR/LLVMContext.h"
324ba319b5SDimitry Andric #include "llvm/IR/MDBuilder.h"
334ba319b5SDimitry Andric #include "llvm/IR/Metadata.h"
344ba319b5SDimitry Andric #include "llvm/IR/Operator.h"
354ba319b5SDimitry Andric #include "llvm/IR/Type.h"
364ba319b5SDimitry Andric #include "llvm/IR/Value.h"
374ba319b5SDimitry Andric #include "llvm/Pass.h"
384ba319b5SDimitry Andric #include "llvm/Support/Casting.h"
394ba319b5SDimitry Andric
404ba319b5SDimitry Andric #define DEBUG_TYPE "amdgpu-lower-kernel-arguments"
414ba319b5SDimitry Andric
424ba319b5SDimitry Andric using namespace llvm;
434ba319b5SDimitry Andric
444ba319b5SDimitry Andric namespace {
454ba319b5SDimitry Andric
464ba319b5SDimitry Andric class AMDGPULowerKernelArguments : public FunctionPass{
474ba319b5SDimitry Andric public:
484ba319b5SDimitry Andric static char ID;
494ba319b5SDimitry Andric
AMDGPULowerKernelArguments()504ba319b5SDimitry Andric AMDGPULowerKernelArguments() : FunctionPass(ID) {}
514ba319b5SDimitry Andric
524ba319b5SDimitry Andric bool runOnFunction(Function &F) override;
534ba319b5SDimitry Andric
getAnalysisUsage(AnalysisUsage & AU) const544ba319b5SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override {
554ba319b5SDimitry Andric AU.addRequired<TargetPassConfig>();
564ba319b5SDimitry Andric AU.setPreservesAll();
574ba319b5SDimitry Andric }
584ba319b5SDimitry Andric };
594ba319b5SDimitry Andric
604ba319b5SDimitry Andric } // end anonymous namespace
614ba319b5SDimitry Andric
runOnFunction(Function & F)624ba319b5SDimitry Andric bool AMDGPULowerKernelArguments::runOnFunction(Function &F) {
634ba319b5SDimitry Andric CallingConv::ID CC = F.getCallingConv();
644ba319b5SDimitry Andric if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty())
654ba319b5SDimitry Andric return false;
664ba319b5SDimitry Andric
674ba319b5SDimitry Andric auto &TPC = getAnalysis<TargetPassConfig>();
684ba319b5SDimitry Andric
694ba319b5SDimitry Andric const TargetMachine &TM = TPC.getTM<TargetMachine>();
704ba319b5SDimitry Andric const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
714ba319b5SDimitry Andric LLVMContext &Ctx = F.getParent()->getContext();
724ba319b5SDimitry Andric const DataLayout &DL = F.getParent()->getDataLayout();
734ba319b5SDimitry Andric BasicBlock &EntryBlock = *F.begin();
744ba319b5SDimitry Andric IRBuilder<> Builder(&*EntryBlock.begin());
754ba319b5SDimitry Andric
764ba319b5SDimitry Andric const unsigned KernArgBaseAlign = 16; // FIXME: Increase if necessary
774ba319b5SDimitry Andric const uint64_t BaseOffset = ST.getExplicitKernelArgOffset(F);
784ba319b5SDimitry Andric
794ba319b5SDimitry Andric unsigned MaxAlign;
804ba319b5SDimitry Andric // FIXME: Alignment is broken broken with explicit arg offset.;
814ba319b5SDimitry Andric const uint64_t TotalKernArgSize = ST.getKernArgSegmentSize(F, MaxAlign);
824ba319b5SDimitry Andric if (TotalKernArgSize == 0)
834ba319b5SDimitry Andric return false;
844ba319b5SDimitry Andric
854ba319b5SDimitry Andric CallInst *KernArgSegment =
86*b5893f02SDimitry Andric Builder.CreateIntrinsic(Intrinsic::amdgcn_kernarg_segment_ptr, {}, {},
87*b5893f02SDimitry Andric nullptr, F.getName() + ".kernarg.segment");
884ba319b5SDimitry Andric
894ba319b5SDimitry Andric KernArgSegment->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
904ba319b5SDimitry Andric KernArgSegment->addAttribute(AttributeList::ReturnIndex,
914ba319b5SDimitry Andric Attribute::getWithDereferenceableBytes(Ctx, TotalKernArgSize));
924ba319b5SDimitry Andric
934ba319b5SDimitry Andric unsigned AS = KernArgSegment->getType()->getPointerAddressSpace();
944ba319b5SDimitry Andric uint64_t ExplicitArgOffset = 0;
954ba319b5SDimitry Andric
964ba319b5SDimitry Andric for (Argument &Arg : F.args()) {
974ba319b5SDimitry Andric Type *ArgTy = Arg.getType();
984ba319b5SDimitry Andric unsigned Align = DL.getABITypeAlignment(ArgTy);
994ba319b5SDimitry Andric unsigned Size = DL.getTypeSizeInBits(ArgTy);
1004ba319b5SDimitry Andric unsigned AllocSize = DL.getTypeAllocSize(ArgTy);
1014ba319b5SDimitry Andric
1024ba319b5SDimitry Andric uint64_t EltOffset = alignTo(ExplicitArgOffset, Align) + BaseOffset;
1034ba319b5SDimitry Andric ExplicitArgOffset = alignTo(ExplicitArgOffset, Align) + AllocSize;
1044ba319b5SDimitry Andric
1054ba319b5SDimitry Andric if (Arg.use_empty())
1064ba319b5SDimitry Andric continue;
1074ba319b5SDimitry Andric
1084ba319b5SDimitry Andric if (PointerType *PT = dyn_cast<PointerType>(ArgTy)) {
1094ba319b5SDimitry Andric // FIXME: Hack. We rely on AssertZext to be able to fold DS addressing
1104ba319b5SDimitry Andric // modes on SI to know the high bits are 0 so pointer adds don't wrap. We
1114ba319b5SDimitry Andric // can't represent this with range metadata because it's only allowed for
1124ba319b5SDimitry Andric // integer types.
1134ba319b5SDimitry Andric if (PT->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS &&
1144ba319b5SDimitry Andric ST.getGeneration() == AMDGPUSubtarget::SOUTHERN_ISLANDS)
1154ba319b5SDimitry Andric continue;
1164ba319b5SDimitry Andric
1174ba319b5SDimitry Andric // FIXME: We can replace this with equivalent alias.scope/noalias
1184ba319b5SDimitry Andric // metadata, but this appears to be a lot of work.
1194ba319b5SDimitry Andric if (Arg.hasNoAliasAttr())
1204ba319b5SDimitry Andric continue;
1214ba319b5SDimitry Andric }
1224ba319b5SDimitry Andric
1234ba319b5SDimitry Andric VectorType *VT = dyn_cast<VectorType>(ArgTy);
1244ba319b5SDimitry Andric bool IsV3 = VT && VT->getNumElements() == 3;
125*b5893f02SDimitry Andric bool DoShiftOpt = Size < 32 && !ArgTy->isAggregateType();
126*b5893f02SDimitry Andric
1274ba319b5SDimitry Andric VectorType *V4Ty = nullptr;
1284ba319b5SDimitry Andric
1294ba319b5SDimitry Andric int64_t AlignDownOffset = alignDown(EltOffset, 4);
1304ba319b5SDimitry Andric int64_t OffsetDiff = EltOffset - AlignDownOffset;
131*b5893f02SDimitry Andric unsigned AdjustedAlign = MinAlign(DoShiftOpt ? AlignDownOffset : EltOffset,
132*b5893f02SDimitry Andric KernArgBaseAlign);
1334ba319b5SDimitry Andric
1344ba319b5SDimitry Andric Value *ArgPtr;
135*b5893f02SDimitry Andric if (DoShiftOpt) { // FIXME: Handle aggregate types
1364ba319b5SDimitry Andric // Since we don't have sub-dword scalar loads, avoid doing an extload by
1374ba319b5SDimitry Andric // loading earlier than the argument address, and extracting the relevant
1384ba319b5SDimitry Andric // bits.
1394ba319b5SDimitry Andric //
1404ba319b5SDimitry Andric // Additionally widen any sub-dword load to i32 even if suitably aligned,
1414ba319b5SDimitry Andric // so that CSE between different argument loads works easily.
1424ba319b5SDimitry Andric
1434ba319b5SDimitry Andric ArgPtr = Builder.CreateConstInBoundsGEP1_64(
1444ba319b5SDimitry Andric KernArgSegment,
1454ba319b5SDimitry Andric AlignDownOffset,
1464ba319b5SDimitry Andric Arg.getName() + ".kernarg.offset.align.down");
1474ba319b5SDimitry Andric ArgPtr = Builder.CreateBitCast(ArgPtr,
1484ba319b5SDimitry Andric Builder.getInt32Ty()->getPointerTo(AS),
1494ba319b5SDimitry Andric ArgPtr->getName() + ".cast");
1504ba319b5SDimitry Andric } else {
1514ba319b5SDimitry Andric ArgPtr = Builder.CreateConstInBoundsGEP1_64(
1524ba319b5SDimitry Andric KernArgSegment,
153*b5893f02SDimitry Andric EltOffset,
1544ba319b5SDimitry Andric Arg.getName() + ".kernarg.offset");
1554ba319b5SDimitry Andric ArgPtr = Builder.CreateBitCast(ArgPtr, ArgTy->getPointerTo(AS),
1564ba319b5SDimitry Andric ArgPtr->getName() + ".cast");
1574ba319b5SDimitry Andric }
1584ba319b5SDimitry Andric
1594ba319b5SDimitry Andric if (IsV3 && Size >= 32) {
1604ba319b5SDimitry Andric V4Ty = VectorType::get(VT->getVectorElementType(), 4);
1614ba319b5SDimitry Andric // Use the hack that clang uses to avoid SelectionDAG ruining v3 loads
1624ba319b5SDimitry Andric ArgPtr = Builder.CreateBitCast(ArgPtr, V4Ty->getPointerTo(AS));
1634ba319b5SDimitry Andric }
1644ba319b5SDimitry Andric
1654ba319b5SDimitry Andric LoadInst *Load = Builder.CreateAlignedLoad(ArgPtr, AdjustedAlign);
1664ba319b5SDimitry Andric Load->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(Ctx, {}));
1674ba319b5SDimitry Andric
1684ba319b5SDimitry Andric MDBuilder MDB(Ctx);
1694ba319b5SDimitry Andric
1704ba319b5SDimitry Andric if (isa<PointerType>(ArgTy)) {
1714ba319b5SDimitry Andric if (Arg.hasNonNullAttr())
1724ba319b5SDimitry Andric Load->setMetadata(LLVMContext::MD_nonnull, MDNode::get(Ctx, {}));
1734ba319b5SDimitry Andric
1744ba319b5SDimitry Andric uint64_t DerefBytes = Arg.getDereferenceableBytes();
1754ba319b5SDimitry Andric if (DerefBytes != 0) {
1764ba319b5SDimitry Andric Load->setMetadata(
1774ba319b5SDimitry Andric LLVMContext::MD_dereferenceable,
1784ba319b5SDimitry Andric MDNode::get(Ctx,
1794ba319b5SDimitry Andric MDB.createConstant(
1804ba319b5SDimitry Andric ConstantInt::get(Builder.getInt64Ty(), DerefBytes))));
1814ba319b5SDimitry Andric }
1824ba319b5SDimitry Andric
1834ba319b5SDimitry Andric uint64_t DerefOrNullBytes = Arg.getDereferenceableOrNullBytes();
1844ba319b5SDimitry Andric if (DerefOrNullBytes != 0) {
1854ba319b5SDimitry Andric Load->setMetadata(
1864ba319b5SDimitry Andric LLVMContext::MD_dereferenceable_or_null,
1874ba319b5SDimitry Andric MDNode::get(Ctx,
1884ba319b5SDimitry Andric MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
1894ba319b5SDimitry Andric DerefOrNullBytes))));
1904ba319b5SDimitry Andric }
1914ba319b5SDimitry Andric
1924ba319b5SDimitry Andric unsigned ParamAlign = Arg.getParamAlignment();
1934ba319b5SDimitry Andric if (ParamAlign != 0) {
1944ba319b5SDimitry Andric Load->setMetadata(
1954ba319b5SDimitry Andric LLVMContext::MD_align,
1964ba319b5SDimitry Andric MDNode::get(Ctx,
1974ba319b5SDimitry Andric MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
1984ba319b5SDimitry Andric ParamAlign))));
1994ba319b5SDimitry Andric }
2004ba319b5SDimitry Andric }
2014ba319b5SDimitry Andric
2024ba319b5SDimitry Andric // TODO: Convert noalias arg to !noalias
2034ba319b5SDimitry Andric
204*b5893f02SDimitry Andric if (DoShiftOpt) {
2054ba319b5SDimitry Andric Value *ExtractBits = OffsetDiff == 0 ?
2064ba319b5SDimitry Andric Load : Builder.CreateLShr(Load, OffsetDiff * 8);
2074ba319b5SDimitry Andric
2084ba319b5SDimitry Andric IntegerType *ArgIntTy = Builder.getIntNTy(Size);
2094ba319b5SDimitry Andric Value *Trunc = Builder.CreateTrunc(ExtractBits, ArgIntTy);
2104ba319b5SDimitry Andric Value *NewVal = Builder.CreateBitCast(Trunc, ArgTy,
2114ba319b5SDimitry Andric Arg.getName() + ".load");
2124ba319b5SDimitry Andric Arg.replaceAllUsesWith(NewVal);
2134ba319b5SDimitry Andric } else if (IsV3) {
2144ba319b5SDimitry Andric Value *Shuf = Builder.CreateShuffleVector(Load, UndefValue::get(V4Ty),
2154ba319b5SDimitry Andric {0, 1, 2},
2164ba319b5SDimitry Andric Arg.getName() + ".load");
2174ba319b5SDimitry Andric Arg.replaceAllUsesWith(Shuf);
2184ba319b5SDimitry Andric } else {
2194ba319b5SDimitry Andric Load->setName(Arg.getName() + ".load");
2204ba319b5SDimitry Andric Arg.replaceAllUsesWith(Load);
2214ba319b5SDimitry Andric }
2224ba319b5SDimitry Andric }
2234ba319b5SDimitry Andric
2244ba319b5SDimitry Andric KernArgSegment->addAttribute(
2254ba319b5SDimitry Andric AttributeList::ReturnIndex,
2264ba319b5SDimitry Andric Attribute::getWithAlignment(Ctx, std::max(KernArgBaseAlign, MaxAlign)));
2274ba319b5SDimitry Andric
2284ba319b5SDimitry Andric return true;
2294ba319b5SDimitry Andric }
2304ba319b5SDimitry Andric
2314ba319b5SDimitry Andric INITIALIZE_PASS_BEGIN(AMDGPULowerKernelArguments, DEBUG_TYPE,
2324ba319b5SDimitry Andric "AMDGPU Lower Kernel Arguments", false, false)
2334ba319b5SDimitry Andric INITIALIZE_PASS_END(AMDGPULowerKernelArguments, DEBUG_TYPE, "AMDGPU Lower Kernel Arguments",
2344ba319b5SDimitry Andric false, false)
2354ba319b5SDimitry Andric
2364ba319b5SDimitry Andric char AMDGPULowerKernelArguments::ID = 0;
2374ba319b5SDimitry Andric
createAMDGPULowerKernelArgumentsPass()2384ba319b5SDimitry Andric FunctionPass *llvm::createAMDGPULowerKernelArgumentsPass() {
2394ba319b5SDimitry Andric return new AMDGPULowerKernelArguments();
2404ba319b5SDimitry Andric }
241