18c4a3523SMatt Arsenault //===-- AMDGPULowerKernelArguments.cpp ------------------------------------------===// 28c4a3523SMatt Arsenault // 38c4a3523SMatt Arsenault // The LLVM Compiler Infrastructure 48c4a3523SMatt Arsenault // 58c4a3523SMatt Arsenault // This file is distributed under the University of Illinois Open Source 68c4a3523SMatt Arsenault // License. See LICENSE.TXT for details. 78c4a3523SMatt Arsenault // 88c4a3523SMatt Arsenault //===----------------------------------------------------------------------===// 98c4a3523SMatt Arsenault // 108c4a3523SMatt Arsenault /// \file This pass replaces accesses to kernel arguments with loads from 118c4a3523SMatt Arsenault /// offsets from the kernarg base pointer. 128c4a3523SMatt Arsenault // 138c4a3523SMatt Arsenault //===----------------------------------------------------------------------===// 148c4a3523SMatt Arsenault 158c4a3523SMatt Arsenault #include "AMDGPU.h" 168c4a3523SMatt Arsenault #include "AMDGPUSubtarget.h" 178c4a3523SMatt Arsenault #include "AMDGPUTargetMachine.h" 188c4a3523SMatt Arsenault #include "llvm/ADT/StringRef.h" 198c4a3523SMatt Arsenault #include "llvm/Analysis/Loads.h" 208c4a3523SMatt Arsenault #include "llvm/CodeGen/Passes.h" 218c4a3523SMatt Arsenault #include "llvm/CodeGen/TargetPassConfig.h" 228c4a3523SMatt Arsenault #include "llvm/IR/Attributes.h" 238c4a3523SMatt Arsenault #include "llvm/IR/BasicBlock.h" 248c4a3523SMatt Arsenault #include "llvm/IR/Constants.h" 258c4a3523SMatt Arsenault #include "llvm/IR/DerivedTypes.h" 268c4a3523SMatt Arsenault #include "llvm/IR/Function.h" 278c4a3523SMatt Arsenault #include "llvm/IR/IRBuilder.h" 288c4a3523SMatt Arsenault #include "llvm/IR/InstrTypes.h" 298c4a3523SMatt Arsenault #include "llvm/IR/Instruction.h" 308c4a3523SMatt Arsenault #include "llvm/IR/Instructions.h" 318c4a3523SMatt Arsenault #include "llvm/IR/LLVMContext.h" 328c4a3523SMatt Arsenault #include "llvm/IR/MDBuilder.h" 338c4a3523SMatt Arsenault #include "llvm/IR/Metadata.h" 
#include "llvm/IR/Operator.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"

#define DEBUG_TYPE "amdgpu-lower-kernel-arguments"

using namespace llvm;

namespace {

/// Legacy-PM function pass that rewrites direct uses of AMDGPU kernel
/// arguments into explicit loads from the kernarg segment pointer (see
/// runOnFunction below). Registered via INITIALIZE_PASS_* at the bottom of
/// this file and created through createAMDGPULowerKernelArgumentsPass().
class AMDGPULowerKernelArguments : public FunctionPass{
public:
  static char ID;

  AMDGPULowerKernelArguments() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // TargetPassConfig is needed to reach the TargetMachine/subtarget.
    AU.addRequired<TargetPassConfig>();
    // Only straight-line IR in the entry block is added; no analyses are
    // invalidated.
    AU.setPreservesAll();
 }
};

} // end anonymous namespace

/// Replace every used argument of an AMDGPU_KERNEL function with an
/// invariant load from the amdgcn.kernarg.segment.ptr intrinsic at the
/// argument's byte offset, inserted at the top of the entry block.
///
/// \returns true (IR changed) unless \p F is not an AMDGPU kernel, has no
/// arguments, or its kernarg segment is empty.
bool AMDGPULowerKernelArguments::runOnFunction(Function &F) {
  CallingConv::ID CC = F.getCallingConv();
  // Only amdgpu_kernel functions receive their arguments via the kernarg
  // segment; anything else is left untouched.
  if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty())
    return false;

  auto &TPC = getAnalysis<TargetPassConfig>();

  const TargetMachine &TM = TPC.getTM<TargetMachine>();
  const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
  LLVMContext &Ctx = F.getParent()->getContext();
  const DataLayout &DL = F.getParent()->getDataLayout();
  BasicBlock &EntryBlock = *F.begin();
  // All new instructions are inserted before the first instruction of the
  // entry block, so argument loads dominate every use they replace.
  IRBuilder<> Builder(&*EntryBlock.begin());

  const unsigned KernArgBaseAlign = 16; // FIXME: Increase if necessary
  // Subtarget-dependent byte offset of the first explicit argument within
  // the kernarg segment (implicit/hidden arguments may precede it).
  const uint64_t BaseOffset = ST.getExplicitKernelArgOffset(F);

  unsigned MaxAlign;  // Set by getKernArgSegmentSize below.
  // FIXME: Alignment is broken with explicit arg offset.
  const uint64_t TotalKernArgSize = ST.getKernArgSegmentSize(F, MaxAlign);
  if (TotalKernArgSize == 0)
    return false;

  // Single base pointer to the kernarg segment; every per-argument load is
  // a GEP off this call's result.
  CallInst *KernArgSegment =
      Builder.CreateIntrinsic(Intrinsic::amdgcn_kernarg_segment_ptr, {}, {},
                              nullptr, F.getName() + ".kernarg.segment");

  // The segment pointer is known non-null and dereferenceable for the whole
  // argument area.
  KernArgSegment->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
  KernArgSegment->addAttribute(AttributeList::ReturnIndex,
    Attribute::getWithDereferenceableBytes(Ctx, TotalKernArgSize));

  unsigned AS = KernArgSegment->getType()->getPointerAddressSpace();
  // Running byte offset of the next explicit argument, relative to the first
  // explicit argument (BaseOffset is added per-argument below).
  uint64_t ExplicitArgOffset = 0;

  for (Argument &Arg : F.args()) {
    Type *ArgTy = Arg.getType();
    unsigned Align = DL.getABITypeAlignment(ArgTy);
    unsigned Size = DL.getTypeSizeInBits(ArgTy);
    unsigned AllocSize = DL.getTypeAllocSize(ArgTy);

    // Offset of this argument within the segment; advance the cursor past it
    // regardless of whether the argument is actually used, so later
    // arguments still land at their ABI offsets.
    uint64_t EltOffset = alignTo(ExplicitArgOffset, Align) + BaseOffset;
    ExplicitArgOffset = alignTo(ExplicitArgOffset, Align) + AllocSize;

    if (Arg.use_empty())
      continue;

    if (PointerType *PT = dyn_cast<PointerType>(ArgTy)) {
      // FIXME: Hack. We rely on AssertZext to be able to fold DS addressing
      // modes on SI to know the high bits are 0 so pointer adds don't wrap. We
      // can't represent this with range metadata because it's only allowed for
      // integer types.
      if (PT->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS &&
          ST.getGeneration() == AMDGPUSubtarget::SOUTHERN_ISLANDS)
        continue;

      // FIXME: We can replace this with equivalent alias.scope/noalias
      // metadata, but this appears to be a lot of work.
      if (Arg.hasNoAliasAttr())
        continue;
    }

    VectorType *VT = dyn_cast<VectorType>(ArgTy);
    bool IsV3 = VT && VT->getNumElements() == 3;
    // Sub-dword scalars are loaded via a widened, 4-byte-aligned i32 load
    // plus shift/trunc (see below); aggregates are excluded.
    bool DoShiftOpt = Size < 32 && !ArgTy->isAggregateType();

    VectorType *V4Ty = nullptr;

    int64_t AlignDownOffset = alignDown(EltOffset, 4);
    // Bytes between the rounded-down load address and the real argument
    // offset; becomes the bit-shift amount (*8) in the DoShiftOpt path.
    int64_t OffsetDiff = EltOffset - AlignDownOffset;
    // Known alignment of the load: gcd of the byte offset actually loaded
    // from and the assumed segment base alignment.
    unsigned AdjustedAlign = MinAlign(DoShiftOpt ? AlignDownOffset : EltOffset,
                                      KernArgBaseAlign);

    Value *ArgPtr;
    if (DoShiftOpt) { // FIXME: Handle aggregate types
      // Since we don't have sub-dword scalar loads, avoid doing an extload by
      // loading earlier than the argument address, and extracting the relevant
      // bits.
      //
      // Additionally widen any sub-dword load to i32 even if suitably aligned,
      // so that CSE between different argument loads works easily.

      ArgPtr = Builder.CreateConstInBoundsGEP1_64(
        KernArgSegment,
        AlignDownOffset,
        Arg.getName() + ".kernarg.offset.align.down");
      ArgPtr = Builder.CreateBitCast(ArgPtr,
                                     Builder.getInt32Ty()->getPointerTo(AS),
                                     ArgPtr->getName() + ".cast");
    } else {
      // Direct load at the argument's own offset, typed as the argument.
      ArgPtr = Builder.CreateConstInBoundsGEP1_64(
        KernArgSegment,
        EltOffset,
        Arg.getName() + ".kernarg.offset");
      ArgPtr = Builder.CreateBitCast(ArgPtr, ArgTy->getPointerTo(AS),
                                     ArgPtr->getName() + ".cast");
    }

    if (IsV3 && Size >= 32) {
      V4Ty = VectorType::get(VT->getVectorElementType(), 4);
      // Use the hack that clang uses to avoid SelectionDAG ruining v3 loads
      ArgPtr = Builder.CreateBitCast(ArgPtr, V4Ty->getPointerTo(AS));
    }

    LoadInst *Load = Builder.CreateAlignedLoad(ArgPtr, AdjustedAlign);
    // Kernel arguments never change during execution of the kernel.
    Load->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(Ctx, {}));

    MDBuilder MDB(Ctx);

    // Transfer pointer-argument attributes onto the load as metadata so the
    // information survives the argument's replacement.
    if (isa<PointerType>(ArgTy)) {
      if (Arg.hasNonNullAttr())
        Load->setMetadata(LLVMContext::MD_nonnull, MDNode::get(Ctx, {}));

      uint64_t DerefBytes = Arg.getDereferenceableBytes();
      if (DerefBytes != 0) {
        Load->setMetadata(
          LLVMContext::MD_dereferenceable,
          MDNode::get(Ctx,
                      MDB.createConstant(
                        ConstantInt::get(Builder.getInt64Ty(), DerefBytes))));
      }

      uint64_t DerefOrNullBytes = Arg.getDereferenceableOrNullBytes();
      if (DerefOrNullBytes != 0) {
        Load->setMetadata(
          LLVMContext::MD_dereferenceable_or_null,
          MDNode::get(Ctx,
                      MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
                                                          DerefOrNullBytes))));
      }

      unsigned ParamAlign = Arg.getParamAlignment();
      if (ParamAlign != 0) {
        Load->setMetadata(
          LLVMContext::MD_align,
          MDNode::get(Ctx,
                      MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
                                                          ParamAlign))));
      }
    }

    // TODO: Convert noalias arg to !noalias

    if (DoShiftOpt) {
      // Shift the wanted bytes down to bit 0 of the widened i32 load, then
      // truncate to the argument's bit width and bitcast back to its type.
      Value *ExtractBits = OffsetDiff == 0 ?
        Load : Builder.CreateLShr(Load, OffsetDiff * 8);

      IntegerType *ArgIntTy = Builder.getIntNTy(Size);
      Value *Trunc = Builder.CreateTrunc(ExtractBits, ArgIntTy);
      Value *NewVal = Builder.CreateBitCast(Trunc, ArgTy,
                                            Arg.getName() + ".load");
      Arg.replaceAllUsesWith(NewVal);
    } else if (IsV3) {
      // Shuffle the first three lanes of the widened v4 load back to v3.
      Value *Shuf = Builder.CreateShuffleVector(Load, UndefValue::get(V4Ty),
                                                {0, 1, 2},
                                                Arg.getName() + ".load");
      Arg.replaceAllUsesWith(Shuf);
    } else {
      Load->setName(Arg.getName() + ".load");
      Arg.replaceAllUsesWith(Load);
    }
  }

  // Record the segment's alignment: at least the assumed base alignment,
  // raised to the largest argument alignment reported above.
  KernArgSegment->addAttribute(
    AttributeList::ReturnIndex,
    Attribute::getWithAlignment(Ctx, std::max(KernArgBaseAlign, MaxAlign)));

  return true;
}

INITIALIZE_PASS_BEGIN(AMDGPULowerKernelArguments, DEBUG_TYPE,
                      "AMDGPU Lower Kernel Arguments", false, false)
INITIALIZE_PASS_END(AMDGPULowerKernelArguments, DEBUG_TYPE, "AMDGPU Lower Kernel Arguments",
                    false, false)

char AMDGPULowerKernelArguments::ID = 0;

/// Factory used by the AMDGPU target's pass pipeline setup.
FunctionPass *llvm::createAMDGPULowerKernelArgumentsPass() {
  return new AMDGPULowerKernelArguments();
}