//===-- AMDGPULowerKernelArguments.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file This pass replaces accesses to kernel arguments with loads from
/// offsets from the kernarg base pointer.
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPUSubtarget.h"
#include "AMDGPUTargetMachine.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"

#define DEBUG_TYPE "amdgpu-lower-kernel-arguments"

using namespace llvm;

namespace {

/// Legacy-PM function pass that rewrites uses of AMDGPU kernel arguments into
/// explicit loads from the kernarg segment pointer. It requires
/// TargetPassConfig to reach the TargetMachine/subtarget, and preserves all
/// analyses (it only mutates IR inside the entry block of kernels).
class AMDGPULowerKernelArguments : public FunctionPass{
public:
  static char ID;

  AMDGPULowerKernelArguments() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
    AU.setPreservesAll();
  }
};

} // end anonymous namespace

/// For each used argument of an AMDGPU_KERNEL function, emit a load from the
/// amdgcn.kernarg.segment.ptr intrinsic at the argument's ABI offset and
/// replace all uses of the argument with that load. Returns true iff the
/// function was modified.
bool AMDGPULowerKernelArguments::runOnFunction(Function &F) {
  // Only kernel entry points have a kernarg segment; skip everything else,
  // and skip kernels with no arguments at all.
  CallingConv::ID CC = F.getCallingConv();
  if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty())
    return false;

  auto &TPC = getAnalysis<TargetPassConfig>();

  const TargetMachine &TM = TPC.getTM<TargetMachine>();
  const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
  LLVMContext &Ctx = F.getParent()->getContext();
  const DataLayout &DL = F.getParent()->getDataLayout();
  BasicBlock &EntryBlock = *F.begin();
  // All argument loads are inserted at the very top of the entry block, in
  // argument order.
  IRBuilder<> Builder(&*EntryBlock.begin());

  const Align KernArgBaseAlign(16); // FIXME: Increase if necessary
  // Target-specific offset of the first explicit argument within the segment
  // (implicit arguments may precede it).
  const uint64_t BaseOffset = ST.getExplicitKernelArgOffset(F);

  Align MaxAlign;
  // FIXME: Alignment is broken with explicit arg offset.
  const uint64_t TotalKernArgSize = ST.getKernArgSegmentSize(F, MaxAlign);
  if (TotalKernArgSize == 0)
    return false;

  CallInst *KernArgSegment =
      Builder.CreateIntrinsic(Intrinsic::amdgcn_kernarg_segment_ptr, {}, {},
                              nullptr, F.getName() + ".kernarg.segment");

  // The segment pointer is known non-null and dereferenceable for the whole
  // argument area; record that so later passes can exploit it.
  KernArgSegment->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
  KernArgSegment->addAttribute(AttributeList::ReturnIndex,
    Attribute::getWithDereferenceableBytes(Ctx, TotalKernArgSize));

  unsigned AS = KernArgSegment->getType()->getPointerAddressSpace();
  // Running byte offset of the next explicit argument, relative to the first
  // explicit argument (BaseOffset is added per-argument below).
  uint64_t ExplicitArgOffset = 0;

  for (Argument &Arg : F.args()) {
    Type *ArgTy = Arg.getType();
    unsigned ABITypeAlign = DL.getABITypeAlignment(ArgTy);
    unsigned Size = DL.getTypeSizeInBits(ArgTy);
    unsigned AllocSize = DL.getTypeAllocSize(ArgTy);

    // This argument lives at the ABI-aligned running offset; advance the
    // running offset past it. Note both lines intentionally align the *old*
    // ExplicitArgOffset.
    uint64_t EltOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + BaseOffset;
    ExplicitArgOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + AllocSize;

    // Offsets must still be accumulated for unused arguments, so this check
    // comes after the offset bookkeeping above.
    if (Arg.use_empty())
      continue;

    if (PointerType *PT = dyn_cast<PointerType>(ArgTy)) {
      // FIXME: Hack. We rely on AssertZext to be able to fold DS addressing
      // modes on SI to know the high bits are 0 so pointer adds don't wrap. We
      // can't represent this with range metadata because it's only allowed for
      // integer types.
      if ((PT->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS ||
           PT->getAddressSpace() == AMDGPUAS::REGION_ADDRESS) &&
          !ST.hasUsableDSOffset())
        continue;

      // FIXME: We can replace this with equivalent alias.scope/noalias
      // metadata, but this appears to be a lot of work.
      if (Arg.hasNoAliasAttr())
        continue;
    }

    VectorType *VT = dyn_cast<VectorType>(ArgTy);
    bool IsV3 = VT && VT->getNumElements() == 3;
    bool DoShiftOpt = Size < 32 && !ArgTy->isAggregateType();

    VectorType *V4Ty = nullptr;

    // For the sub-dword case we load the containing dword and shift/truncate;
    // AlignDownOffset is the dword-aligned address, OffsetDiff the byte shift.
    int64_t AlignDownOffset = alignDown(EltOffset, 4);
    int64_t OffsetDiff = EltOffset - AlignDownOffset;
    Align AdjustedAlign = commonAlignment(
        KernArgBaseAlign, DoShiftOpt ? AlignDownOffset : EltOffset);

    Value *ArgPtr;
    Type *AdjustedArgTy;
    if (DoShiftOpt) { // FIXME: Handle aggregate types
      // Since we don't have sub-dword scalar loads, avoid doing an extload by
      // loading earlier than the argument address, and extracting the relevant
      // bits.
      //
      // Additionally widen any sub-dword load to i32 even if suitably aligned,
      // so that CSE between different argument loads works easily.
      ArgPtr = Builder.CreateConstInBoundsGEP1_64(
          Builder.getInt8Ty(), KernArgSegment, AlignDownOffset,
          Arg.getName() + ".kernarg.offset.align.down");
      AdjustedArgTy = Builder.getInt32Ty();
    } else {
      ArgPtr = Builder.CreateConstInBoundsGEP1_64(
          Builder.getInt8Ty(), KernArgSegment, EltOffset,
          Arg.getName() + ".kernarg.offset");
      AdjustedArgTy = ArgTy;
    }

    if (IsV3 && Size >= 32) {
      V4Ty = VectorType::get(VT->getVectorElementType(), 4);
      // Use the hack that clang uses to avoid SelectionDAG ruining v3 loads
      AdjustedArgTy = V4Ty;
    }

    ArgPtr = Builder.CreateBitCast(ArgPtr, AdjustedArgTy->getPointerTo(AS),
                                   ArgPtr->getName() + ".cast");
    LoadInst *Load =
        Builder.CreateAlignedLoad(AdjustedArgTy, ArgPtr, AdjustedAlign.value());
    // Kernarg memory never changes during kernel execution, so every load from
    // it is invariant.
    Load->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(Ctx, {}));

    MDBuilder MDB(Ctx);

    // Transfer pointer-argument attributes onto the load as equivalent
    // metadata so the information survives the argument's disappearance.
    if (isa<PointerType>(ArgTy)) {
      if (Arg.hasNonNullAttr())
        Load->setMetadata(LLVMContext::MD_nonnull, MDNode::get(Ctx, {}));

      uint64_t DerefBytes = Arg.getDereferenceableBytes();
      if (DerefBytes != 0) {
        Load->setMetadata(
          LLVMContext::MD_dereferenceable,
          MDNode::get(Ctx,
                      MDB.createConstant(
                        ConstantInt::get(Builder.getInt64Ty(), DerefBytes))));
      }

      uint64_t DerefOrNullBytes = Arg.getDereferenceableOrNullBytes();
      if (DerefOrNullBytes != 0) {
        Load->setMetadata(
          LLVMContext::MD_dereferenceable_or_null,
          MDNode::get(Ctx,
                      MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
                                                          DerefOrNullBytes))));
      }

      unsigned ParamAlign = Arg.getParamAlignment();
      if (ParamAlign != 0) {
        Load->setMetadata(
          LLVMContext::MD_align,
          MDNode::get(Ctx,
                      MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
                                                          ParamAlign))));
      }
    }

    // TODO: Convert noalias arg to !noalias

    if (DoShiftOpt) {
      // Extract the argument's bits from the widened i32 load: shift the
      // containing dword right by the byte offset, then truncate and bitcast
      // back to the original argument type.
      Value *ExtractBits = OffsetDiff == 0 ?
        Load : Builder.CreateLShr(Load, OffsetDiff * 8);

      IntegerType *ArgIntTy = Builder.getIntNTy(Size);
      Value *Trunc = Builder.CreateTrunc(ExtractBits, ArgIntTy);
      Value *NewVal = Builder.CreateBitCast(Trunc, ArgTy,
                                            Arg.getName() + ".load");
      Arg.replaceAllUsesWith(NewVal);
    } else if (IsV3) {
      // Shrink the widened v4 load back to the v3 value the argument expects.
      Value *Shuf = Builder.CreateShuffleVector(Load, UndefValue::get(V4Ty),
                                                {0, 1, 2},
                                                Arg.getName() + ".load");
      Arg.replaceAllUsesWith(Shuf);
    } else {
      Load->setName(Arg.getName() + ".load");
      Arg.replaceAllUsesWith(Load);
    }
  }

  // Record the final segment alignment: at least the base alignment, raised to
  // the largest alignment any argument required.
  KernArgSegment->addAttribute(
      AttributeList::ReturnIndex,
      Attribute::getWithAlignment(Ctx, std::max(KernArgBaseAlign, MaxAlign)));

  return true;
}

INITIALIZE_PASS_BEGIN(AMDGPULowerKernelArguments, DEBUG_TYPE,
                      "AMDGPU Lower Kernel Arguments", false, false)
INITIALIZE_PASS_END(AMDGPULowerKernelArguments, DEBUG_TYPE, "AMDGPU Lower Kernel Arguments",
                    false, false)

char AMDGPULowerKernelArguments::ID = 0;

FunctionPass *llvm::createAMDGPULowerKernelArgumentsPass() {
  return new AMDGPULowerKernelArguments();
}