18c4a3523SMatt Arsenault //===-- AMDGPULowerKernelArguments.cpp ------------------------------------------===// 28c4a3523SMatt Arsenault // 32946cd70SChandler Carruth // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 42946cd70SChandler Carruth // See https://llvm.org/LICENSE.txt for license information. 52946cd70SChandler Carruth // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 68c4a3523SMatt Arsenault // 78c4a3523SMatt Arsenault //===----------------------------------------------------------------------===// 88c4a3523SMatt Arsenault // 98c4a3523SMatt Arsenault /// \file This pass replaces accesses to kernel arguments with loads from 108c4a3523SMatt Arsenault /// offsets from the kernarg base pointer. 118c4a3523SMatt Arsenault // 128c4a3523SMatt Arsenault //===----------------------------------------------------------------------===// 138c4a3523SMatt Arsenault 148c4a3523SMatt Arsenault #include "AMDGPU.h" 15560d7e04Sdfukalov #include "GCNSubtarget.h" 168c4a3523SMatt Arsenault #include "llvm/CodeGen/TargetPassConfig.h" 176a87e9b0Sdfukalov #include "llvm/IR/IntrinsicsAMDGPU.h" 1899142003SNikita Popov #include "llvm/IR/IRBuilder.h" 198c4a3523SMatt Arsenault #include "llvm/IR/MDBuilder.h" 206a87e9b0Sdfukalov #include "llvm/Target/TargetMachine.h" 218c4a3523SMatt Arsenault #define DEBUG_TYPE "amdgpu-lower-kernel-arguments" 228c4a3523SMatt Arsenault 238c4a3523SMatt Arsenault using namespace llvm; 248c4a3523SMatt Arsenault 258c4a3523SMatt Arsenault namespace { 268c4a3523SMatt Arsenault 278c4a3523SMatt Arsenault class AMDGPULowerKernelArguments : public FunctionPass{ 288c4a3523SMatt Arsenault public: 298c4a3523SMatt Arsenault static char ID; 308c4a3523SMatt Arsenault 318c4a3523SMatt Arsenault AMDGPULowerKernelArguments() : FunctionPass(ID) {} 328c4a3523SMatt Arsenault 338c4a3523SMatt Arsenault bool runOnFunction(Function &F) override; 348c4a3523SMatt Arsenault 358c4a3523SMatt Arsenault void getAnalysisUsage(AnalysisUsage &AU) const override { 368c4a3523SMatt Arsenault AU.addRequired<TargetPassConfig>(); 378c4a3523SMatt Arsenault AU.setPreservesAll(); 388c4a3523SMatt Arsenault } 398c4a3523SMatt Arsenault }; 408c4a3523SMatt Arsenault 418c4a3523SMatt Arsenault } // end anonymous namespace 428c4a3523SMatt Arsenault 43074c371aSMatt Arsenault // skip allocas 44074c371aSMatt Arsenault static BasicBlock::iterator getInsertPt(BasicBlock &BB) { 45074c371aSMatt Arsenault BasicBlock::iterator InsPt = BB.getFirstInsertionPt(); 46074c371aSMatt Arsenault for (BasicBlock::iterator E = BB.end(); InsPt != E; ++InsPt) { 47074c371aSMatt Arsenault AllocaInst *AI = dyn_cast<AllocaInst>(&*InsPt); 48074c371aSMatt Arsenault 49074c371aSMatt Arsenault // If this is a dynamic alloca, the value may depend on the loaded kernargs, 50074c371aSMatt Arsenault // so loads will need to be inserted before it. 51074c371aSMatt Arsenault if (!AI || !AI->isStaticAlloca()) 52074c371aSMatt Arsenault break; 53074c371aSMatt Arsenault } 54074c371aSMatt Arsenault 55074c371aSMatt Arsenault return InsPt; 56074c371aSMatt Arsenault } 57074c371aSMatt Arsenault 588c4a3523SMatt Arsenault bool AMDGPULowerKernelArguments::runOnFunction(Function &F) { 598c4a3523SMatt Arsenault CallingConv::ID CC = F.getCallingConv(); 608c4a3523SMatt Arsenault if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty()) 618c4a3523SMatt Arsenault return false; 628c4a3523SMatt Arsenault 638c4a3523SMatt Arsenault auto &TPC = getAnalysis<TargetPassConfig>(); 648c4a3523SMatt Arsenault 658c4a3523SMatt Arsenault const TargetMachine &TM = TPC.getTM<TargetMachine>(); 665bfbae5cSTom Stellard const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F); 678c4a3523SMatt Arsenault LLVMContext &Ctx = F.getParent()->getContext(); 688c4a3523SMatt Arsenault const DataLayout &DL = F.getParent()->getDataLayout(); 698c4a3523SMatt Arsenault BasicBlock &EntryBlock = *F.begin(); 70074c371aSMatt Arsenault IRBuilder<> Builder(&*getInsertPt(EntryBlock)); 718c4a3523SMatt Arsenault 72b65fa483SGuillaume Chatelet const Align KernArgBaseAlign(16); // FIXME: Increase if necessary 738c4a3523SMatt Arsenault const uint64_t BaseOffset = ST.getExplicitKernelArgOffset(F); 748c4a3523SMatt Arsenault 75b65fa483SGuillaume Chatelet Align MaxAlign; 768c4a3523SMatt Arsenault // FIXME: Alignment is broken broken with explicit arg offset.; 774bec7d42SMatt Arsenault const uint64_t TotalKernArgSize = ST.getKernArgSegmentSize(F, MaxAlign); 78513e0c0eSMatt Arsenault if (TotalKernArgSize == 0) 79513e0c0eSMatt Arsenault return false; 808c4a3523SMatt Arsenault 818c4a3523SMatt Arsenault CallInst *KernArgSegment = 8257f5d0a8SNeil Henning Builder.CreateIntrinsic(Intrinsic::amdgcn_kernarg_segment_ptr, {}, {}, 8357f5d0a8SNeil Henning nullptr, F.getName() + ".kernarg.segment"); 848c4a3523SMatt Arsenault 85*3f4d00bcSArthur Eubanks KernArgSegment->addRetAttr(Attribute::NonNull); 86*3f4d00bcSArthur Eubanks KernArgSegment->addRetAttr( 878c4a3523SMatt Arsenault Attribute::getWithDereferenceableBytes(Ctx, TotalKernArgSize)); 888c4a3523SMatt Arsenault 898c4a3523SMatt Arsenault unsigned AS = KernArgSegment->getType()->getPointerAddressSpace(); 90f5be3ad7SMatt Arsenault uint64_t ExplicitArgOffset = 0; 91f5be3ad7SMatt Arsenault 928c4a3523SMatt Arsenault for (Argument &Arg : F.args()) { 931168119cSMatt Arsenault const bool IsByRef = Arg.hasByRefAttr(); 941168119cSMatt Arsenault Type *ArgTy = IsByRef ? Arg.getParamByRefType() : Arg.getType(); 951168119cSMatt Arsenault MaybeAlign ABITypeAlign = IsByRef ? Arg.getParamAlign() : None; 961168119cSMatt Arsenault if (!ABITypeAlign) 971168119cSMatt Arsenault ABITypeAlign = DL.getABITypeAlign(ArgTy); 981168119cSMatt Arsenault 991168119cSMatt Arsenault uint64_t Size = DL.getTypeSizeInBits(ArgTy); 1001168119cSMatt Arsenault uint64_t AllocSize = DL.getTypeAllocSize(ArgTy); 101f5be3ad7SMatt Arsenault 102b65fa483SGuillaume Chatelet uint64_t EltOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + BaseOffset; 103b65fa483SGuillaume Chatelet ExplicitArgOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + AllocSize; 104f5be3ad7SMatt Arsenault 1058c4a3523SMatt Arsenault if (Arg.use_empty()) 1068c4a3523SMatt Arsenault continue; 1078c4a3523SMatt Arsenault 1081168119cSMatt Arsenault // If this is byval, the loads are already explicit in the function. We just 1091168119cSMatt Arsenault // need to rewrite the pointer values. 1101168119cSMatt Arsenault if (IsByRef) { 1111168119cSMatt Arsenault Value *ArgOffsetPtr = Builder.CreateConstInBoundsGEP1_64( 1121168119cSMatt Arsenault Builder.getInt8Ty(), KernArgSegment, EltOffset, 1131168119cSMatt Arsenault Arg.getName() + ".byval.kernarg.offset"); 1141168119cSMatt Arsenault 1151168119cSMatt Arsenault Value *CastOffsetPtr = Builder.CreatePointerBitCastOrAddrSpaceCast( 1161168119cSMatt Arsenault ArgOffsetPtr, Arg.getType()); 1171168119cSMatt Arsenault Arg.replaceAllUsesWith(CastOffsetPtr); 1181168119cSMatt Arsenault continue; 1191168119cSMatt Arsenault } 1201168119cSMatt Arsenault 1218c4a3523SMatt Arsenault if (PointerType *PT = dyn_cast<PointerType>(ArgTy)) { 1228c4a3523SMatt Arsenault // FIXME: Hack. We rely on AssertZext to be able to fold DS addressing 1238c4a3523SMatt Arsenault // modes on SI to know the high bits are 0 so pointer adds don't wrap. We 1248c4a3523SMatt Arsenault // can't represent this with range metadata because it's only allowed for 1258c4a3523SMatt Arsenault // integer types. 126cdd191d9SMatt Arsenault if ((PT->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS || 127cdd191d9SMatt Arsenault PT->getAddressSpace() == AMDGPUAS::REGION_ADDRESS) && 128e4c2e9b0SMatt Arsenault !ST.hasUsableDSOffset()) 1298c4a3523SMatt Arsenault continue; 1308c4a3523SMatt Arsenault 1318c4a3523SMatt Arsenault // FIXME: We can replace this with equivalent alias.scope/noalias 1328c4a3523SMatt Arsenault // metadata, but this appears to be a lot of work. 1338c4a3523SMatt Arsenault if (Arg.hasNoAliasAttr()) 1348c4a3523SMatt Arsenault continue; 1358c4a3523SMatt Arsenault } 1368c4a3523SMatt Arsenault 1373254a001SChristopher Tetreault auto *VT = dyn_cast<FixedVectorType>(ArgTy); 1388c4a3523SMatt Arsenault bool IsV3 = VT && VT->getNumElements() == 3; 139b5613ecfSMatt Arsenault bool DoShiftOpt = Size < 32 && !ArgTy->isAggregateType(); 140b5613ecfSMatt Arsenault 1418c4a3523SMatt Arsenault VectorType *V4Ty = nullptr; 1428c4a3523SMatt Arsenault 1438c4a3523SMatt Arsenault int64_t AlignDownOffset = alignDown(EltOffset, 4); 1448c4a3523SMatt Arsenault int64_t OffsetDiff = EltOffset - AlignDownOffset; 145b65fa483SGuillaume Chatelet Align AdjustedAlign = commonAlignment( 146b65fa483SGuillaume Chatelet KernArgBaseAlign, DoShiftOpt ? AlignDownOffset : EltOffset); 1478c4a3523SMatt Arsenault 1488c4a3523SMatt Arsenault Value *ArgPtr; 14914359ef1SJames Y Knight Type *AdjustedArgTy; 150b5613ecfSMatt Arsenault if (DoShiftOpt) { // FIXME: Handle aggregate types 1518c4a3523SMatt Arsenault // Since we don't have sub-dword scalar loads, avoid doing an extload by 1528c4a3523SMatt Arsenault // loading earlier than the argument address, and extracting the relevant 1538c4a3523SMatt Arsenault // bits. 1548c4a3523SMatt Arsenault // 1558c4a3523SMatt Arsenault // Additionally widen any sub-dword load to i32 even if suitably aligned, 1568c4a3523SMatt Arsenault // so that CSE between different argument loads works easily. 157f5be3ad7SMatt Arsenault ArgPtr = Builder.CreateConstInBoundsGEP1_64( 1587716075aSJames Y Knight Builder.getInt8Ty(), KernArgSegment, AlignDownOffset, 1598c4a3523SMatt Arsenault Arg.getName() + ".kernarg.offset.align.down"); 16014359ef1SJames Y Knight AdjustedArgTy = Builder.getInt32Ty(); 1618c4a3523SMatt Arsenault } else { 162f5be3ad7SMatt Arsenault ArgPtr = Builder.CreateConstInBoundsGEP1_64( 1637716075aSJames Y Knight Builder.getInt8Ty(), KernArgSegment, EltOffset, 1648c4a3523SMatt Arsenault Arg.getName() + ".kernarg.offset"); 16514359ef1SJames Y Knight AdjustedArgTy = ArgTy; 1668c4a3523SMatt Arsenault } 1678c4a3523SMatt Arsenault 1688c4a3523SMatt Arsenault if (IsV3 && Size >= 32) { 169aad93654SChristopher Tetreault V4Ty = FixedVectorType::get(VT->getElementType(), 4); 1708c4a3523SMatt Arsenault // Use the hack that clang uses to avoid SelectionDAG ruining v3 loads 17114359ef1SJames Y Knight AdjustedArgTy = V4Ty; 1728c4a3523SMatt Arsenault } 1738c4a3523SMatt Arsenault 17414359ef1SJames Y Knight ArgPtr = Builder.CreateBitCast(ArgPtr, AdjustedArgTy->getPointerTo(AS), 17514359ef1SJames Y Knight ArgPtr->getName() + ".cast"); 17614359ef1SJames Y Knight LoadInst *Load = 177279fa8e0SGuillaume Chatelet Builder.CreateAlignedLoad(AdjustedArgTy, ArgPtr, AdjustedAlign); 1788c4a3523SMatt Arsenault Load->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(Ctx, {})); 1798c4a3523SMatt Arsenault 1808c4a3523SMatt Arsenault MDBuilder MDB(Ctx); 1818c4a3523SMatt Arsenault 1828c4a3523SMatt Arsenault if (isa<PointerType>(ArgTy)) { 1838c4a3523SMatt Arsenault if (Arg.hasNonNullAttr()) 1848c4a3523SMatt Arsenault Load->setMetadata(LLVMContext::MD_nonnull, MDNode::get(Ctx, {})); 1858c4a3523SMatt Arsenault 1868c4a3523SMatt Arsenault uint64_t DerefBytes = Arg.getDereferenceableBytes(); 1878c4a3523SMatt Arsenault if (DerefBytes != 0) { 1888c4a3523SMatt Arsenault Load->setMetadata( 1898c4a3523SMatt Arsenault LLVMContext::MD_dereferenceable, 1908c4a3523SMatt Arsenault MDNode::get(Ctx, 1918c4a3523SMatt Arsenault MDB.createConstant( 1928c4a3523SMatt Arsenault ConstantInt::get(Builder.getInt64Ty(), DerefBytes)))); 1938c4a3523SMatt Arsenault } 1948c4a3523SMatt Arsenault 1958c4a3523SMatt Arsenault uint64_t DerefOrNullBytes = Arg.getDereferenceableOrNullBytes(); 1968c4a3523SMatt Arsenault if (DerefOrNullBytes != 0) { 1978c4a3523SMatt Arsenault Load->setMetadata( 1988c4a3523SMatt Arsenault LLVMContext::MD_dereferenceable_or_null, 1998c4a3523SMatt Arsenault MDNode::get(Ctx, 2008c4a3523SMatt Arsenault MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(), 2018c4a3523SMatt Arsenault DerefOrNullBytes)))); 2028c4a3523SMatt Arsenault } 2038c4a3523SMatt Arsenault 2048c4a3523SMatt Arsenault unsigned ParamAlign = Arg.getParamAlignment(); 2058c4a3523SMatt Arsenault if (ParamAlign != 0) { 2068c4a3523SMatt Arsenault Load->setMetadata( 2078c4a3523SMatt Arsenault LLVMContext::MD_align, 2088c4a3523SMatt Arsenault MDNode::get(Ctx, 2098c4a3523SMatt Arsenault MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(), 2108c4a3523SMatt Arsenault ParamAlign)))); 2118c4a3523SMatt Arsenault } 2128c4a3523SMatt Arsenault } 2138c4a3523SMatt Arsenault 2148c4a3523SMatt Arsenault // TODO: Convert noalias arg to !noalias 2158c4a3523SMatt Arsenault 216b5613ecfSMatt Arsenault if (DoShiftOpt) { 2178c4a3523SMatt Arsenault Value *ExtractBits = OffsetDiff == 0 ? 2188c4a3523SMatt Arsenault Load : Builder.CreateLShr(Load, OffsetDiff * 8); 2198c4a3523SMatt Arsenault 2208c4a3523SMatt Arsenault IntegerType *ArgIntTy = Builder.getIntNTy(Size); 2218c4a3523SMatt Arsenault Value *Trunc = Builder.CreateTrunc(ExtractBits, ArgIntTy); 2228c4a3523SMatt Arsenault Value *NewVal = Builder.CreateBitCast(Trunc, ArgTy, 2238c4a3523SMatt Arsenault Arg.getName() + ".load"); 2248c4a3523SMatt Arsenault Arg.replaceAllUsesWith(NewVal); 2258c4a3523SMatt Arsenault } else if (IsV3) { 226420d046dSJuneyoung Lee Value *Shuf = Builder.CreateShuffleVector(Load, ArrayRef<int>{0, 1, 2}, 2278c4a3523SMatt Arsenault Arg.getName() + ".load"); 2288c4a3523SMatt Arsenault Arg.replaceAllUsesWith(Shuf); 2298c4a3523SMatt Arsenault } else { 2308c4a3523SMatt Arsenault Load->setName(Arg.getName() + ".load"); 2318c4a3523SMatt Arsenault Arg.replaceAllUsesWith(Load); 2328c4a3523SMatt Arsenault } 2338c4a3523SMatt Arsenault } 2348c4a3523SMatt Arsenault 235*3f4d00bcSArthur Eubanks KernArgSegment->addRetAttr( 236f5be3ad7SMatt Arsenault Attribute::getWithAlignment(Ctx, std::max(KernArgBaseAlign, MaxAlign))); 237f5be3ad7SMatt Arsenault 2388c4a3523SMatt Arsenault return true; 2398c4a3523SMatt Arsenault } 2408c4a3523SMatt Arsenault 2418c4a3523SMatt Arsenault INITIALIZE_PASS_BEGIN(AMDGPULowerKernelArguments, DEBUG_TYPE, 2428c4a3523SMatt Arsenault "AMDGPU Lower Kernel Arguments", false, false) 2438c4a3523SMatt Arsenault INITIALIZE_PASS_END(AMDGPULowerKernelArguments, DEBUG_TYPE, "AMDGPU Lower Kernel Arguments", 2448c4a3523SMatt Arsenault false, false) 2458c4a3523SMatt Arsenault 2468c4a3523SMatt Arsenault char AMDGPULowerKernelArguments::ID = 0; 2478c4a3523SMatt Arsenault 2488c4a3523SMatt Arsenault FunctionPass *llvm::createAMDGPULowerKernelArgumentsPass() { 2498c4a3523SMatt Arsenault return new AMDGPULowerKernelArguments(); 2508c4a3523SMatt Arsenault } 251