//===-- AMDGPULowerKernelArguments.cpp -------------------------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
/// \file This pass replaces accesses to kernel arguments with loads from
/// offsets from the kernarg base pointer.
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPUSubtarget.h"
#include "AMDGPUTargetMachine.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/DivergenceAnalysis.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"

#define DEBUG_TYPE "amdgpu-lower-kernel-arguments"

using namespace llvm;

namespace {

class AMDGPULowerKernelArguments : public FunctionPass {
public:
  static char ID;

  AMDGPULowerKernelArguments() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
    AU.setPreservesAll();
  }
};

} // end anonymous namespace
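
// Rough illustration of the rewrite this pass performs (hand-written sketch,
// not copied from a test; value names, pointer types and offsets are only
// representative):
//
//   define amdgpu_kernel void @k(i32 %x, i32 addrspace(1)* %out) { ... }
//
// becomes, at the top of the entry block:
//
//   %k.kernarg.segment = call ... @llvm.amdgcn.kernarg.segment.ptr()
//   %x.kernarg.offset  = getelementptr inbounds ... %k.kernarg.segment, <off>
//   %x.load            = load i32, ... %x.kernarg.offset, !invariant.load !0
//
// with all uses of %x replaced by %x.load. The exact pointer casts, offsets
// and attached metadata depend on the subtarget and the argument attributes
// handled below.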
bool AMDGPULowerKernelArguments::runOnFunction(Function &F) {
  CallingConv::ID CC = F.getCallingConv();
  if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty())
    return false;

  auto &TPC = getAnalysis<TargetPassConfig>();

  const TargetMachine &TM = TPC.getTM<TargetMachine>();
  const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
  LLVMContext &Ctx = F.getParent()->getContext();
  const DataLayout &DL = F.getParent()->getDataLayout();
  BasicBlock &EntryBlock = *F.begin();
  IRBuilder<> Builder(&*EntryBlock.begin());

  const unsigned KernArgBaseAlign = 16; // FIXME: Increase if necessary
  const uint64_t BaseOffset = ST.getExplicitKernelArgOffset(F);

  // FIXME: Alignment is broken with explicit arg offset.
  const uint64_t TotalKernArgSize = ST.getKernArgSegmentSize(F);
  if (TotalKernArgSize == 0)
    return false;

  CallInst *KernArgSegment =
      Builder.CreateIntrinsic(Intrinsic::amdgcn_kernarg_segment_ptr, nullptr,
                              F.getName() + ".kernarg.segment");

  KernArgSegment->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
  KernArgSegment->addAttribute(AttributeList::ReturnIndex,
    Attribute::getWithDereferenceableBytes(Ctx, TotalKernArgSize));

  unsigned AS = KernArgSegment->getType()->getPointerAddressSpace();
  unsigned MaxAlign = 1;
  uint64_t ExplicitArgOffset = 0;

  for (Argument &Arg : F.args()) {
    Type *ArgTy = Arg.getType();
    unsigned Align = DL.getABITypeAlignment(ArgTy);
    MaxAlign = std::max(Align, MaxAlign);
    unsigned Size = DL.getTypeSizeInBits(ArgTy);
    unsigned AllocSize = DL.getTypeAllocSize(ArgTy);

    // Clover seems to always pad i8/i16 to i32, but doesn't properly align
    // them?
    // Make sure the struct elements have correct size and alignment for ext
    // args. These seem to be padded up to 4-bytes but not correctly aligned.
    bool IsExtArg = AllocSize < 32 && (Arg.hasZExtAttr() || Arg.hasSExtAttr()) &&
                    !ST.isAmdHsaOS();
    if (IsExtArg)
      AllocSize = 4;
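
    // EltOffset is this argument's byte offset from the kernarg base pointer,
    // including the target's offset to the start of the explicit argument
    // area (BaseOffset). ExplicitArgOffset only tracks the running size of
    // the explicit area and determines where the next argument is placed.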
    uint64_t EltOffset = alignTo(ExplicitArgOffset, Align) + BaseOffset;
    ExplicitArgOffset = alignTo(ExplicitArgOffset, Align) + AllocSize;

    if (Arg.use_empty())
      continue;

    if (PointerType *PT = dyn_cast<PointerType>(ArgTy)) {
      // FIXME: Hack. We rely on AssertZext to be able to fold DS addressing
      // modes on SI to know the high bits are 0 so pointer adds don't wrap. We
      // can't represent this with range metadata because it's only allowed for
      // integer types.
      if (PT->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS &&
          ST.getGeneration() == AMDGPUSubtarget::SOUTHERN_ISLANDS)
        continue;

      // FIXME: We can replace this with equivalent alias.scope/noalias
      // metadata, but this appears to be a lot of work.
      if (Arg.hasNoAliasAttr())
        continue;
    }

    VectorType *VT = dyn_cast<VectorType>(ArgTy);
    bool IsV3 = VT && VT->getNumElements() == 3;
    VectorType *V4Ty = nullptr;

    int64_t AlignDownOffset = alignDown(EltOffset, 4);
    int64_t OffsetDiff = EltOffset - AlignDownOffset;
    unsigned AdjustedAlign = MinAlign(KernArgBaseAlign, AlignDownOffset);

    Value *ArgPtr;
    if (Size < 32 && !ArgTy->isAggregateType()) { // FIXME: Handle aggregate types
      // Since we don't have sub-dword scalar loads, avoid doing an extload by
      // loading earlier than the argument address, and extracting the relevant
      // bits.
      //
      // Additionally widen any sub-dword load to i32 even if suitably aligned,
      // so that CSE between different argument loads works easily.

      ArgPtr = Builder.CreateConstInBoundsGEP1_64(
        KernArgSegment,
        AlignDownOffset,
        Arg.getName() + ".kernarg.offset.align.down");
      ArgPtr = Builder.CreateBitCast(ArgPtr,
                                     Builder.getInt32Ty()->getPointerTo(AS),
                                     ArgPtr->getName() + ".cast");
    } else {
      ArgPtr = Builder.CreateConstInBoundsGEP1_64(
        KernArgSegment,
        AlignDownOffset,
        Arg.getName() + ".kernarg.offset");
      ArgPtr = Builder.CreateBitCast(ArgPtr, ArgTy->getPointerTo(AS),
                                     ArgPtr->getName() + ".cast");
    }

    assert((!IsExtArg || !IsV3) && "incompatible situation");

    if (IsV3 && Size >= 32) {
      V4Ty = VectorType::get(VT->getVectorElementType(), 4);
      // Use the hack that clang uses to avoid SelectionDAG ruining v3 loads
      ArgPtr = Builder.CreateBitCast(ArgPtr, V4Ty->getPointerTo(AS));
    }
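
    // Kernarg memory is effectively constant for the duration of the dispatch,
    // so the load can be marked !invariant.load. For pointer arguments, the
    // parameter attributes (nonnull, dereferenceable, dereferenceable_or_null,
    // align) are carried over onto the load as metadata below.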
    LoadInst *Load = Builder.CreateAlignedLoad(ArgPtr, AdjustedAlign);
    Load->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(Ctx, {}));

    MDBuilder MDB(Ctx);

    if (isa<PointerType>(ArgTy)) {
      if (Arg.hasNonNullAttr())
        Load->setMetadata(LLVMContext::MD_nonnull, MDNode::get(Ctx, {}));

      uint64_t DerefBytes = Arg.getDereferenceableBytes();
      if (DerefBytes != 0) {
        Load->setMetadata(
          LLVMContext::MD_dereferenceable,
          MDNode::get(Ctx,
                      MDB.createConstant(
                        ConstantInt::get(Builder.getInt64Ty(), DerefBytes))));
      }

      uint64_t DerefOrNullBytes = Arg.getDereferenceableOrNullBytes();
      if (DerefOrNullBytes != 0) {
        Load->setMetadata(
          LLVMContext::MD_dereferenceable_or_null,
          MDNode::get(Ctx,
                      MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
                                                          DerefOrNullBytes))));
      }

      unsigned ParamAlign = Arg.getParamAlignment();
      if (ParamAlign != 0) {
        Load->setMetadata(
          LLVMContext::MD_align,
          MDNode::get(Ctx,
                      MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
                                                          ParamAlign))));
      }
    }

    // TODO: Convert noalias arg to !noalias
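
    // Sub-dword arguments were loaded as an aligned-down i32 above; recover
    // the declared value by shifting out the bytes below the argument's
    // offset within that dword (OffsetDiff), truncating to the argument's bit
    // width, and bitcasting back to the original type. For zext/sext
    // arguments, !range metadata records the extended value range. 3-element
    // vectors are re-narrowed from the widened 4-element load with a shuffle.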
    if (Size < 32 && !ArgTy->isAggregateType()) {
      if (IsExtArg && OffsetDiff == 0) {
        Type *I32Ty = Builder.getInt32Ty();
        bool IsSext = Arg.hasSExtAttr();
        Metadata *LowAndHigh[] = {
          ConstantAsMetadata::get(
            ConstantInt::get(I32Ty, IsSext ? minIntN(Size) : 0)),
          ConstantAsMetadata::get(
            ConstantInt::get(I32Ty,
                             IsSext ? maxIntN(Size) + 1 : maxUIntN(Size) + 1))
        };

        Load->setMetadata(LLVMContext::MD_range, MDNode::get(Ctx, LowAndHigh));
      }

      Value *ExtractBits = OffsetDiff == 0 ?
        Load : Builder.CreateLShr(Load, OffsetDiff * 8);

      IntegerType *ArgIntTy = Builder.getIntNTy(Size);
      Value *Trunc = Builder.CreateTrunc(ExtractBits, ArgIntTy);
      Value *NewVal = Builder.CreateBitCast(Trunc, ArgTy,
                                            Arg.getName() + ".load");
      Arg.replaceAllUsesWith(NewVal);
    } else if (IsV3) {
      Value *Shuf = Builder.CreateShuffleVector(Load, UndefValue::get(V4Ty),
                                                {0, 1, 2},
                                                Arg.getName() + ".load");
      Arg.replaceAllUsesWith(Shuf);
    } else {
      Load->setName(Arg.getName() + ".load");
      Arg.replaceAllUsesWith(Load);
    }
  }

  KernArgSegment->addAttribute(
    AttributeList::ReturnIndex,
    Attribute::getWithAlignment(Ctx, std::max(KernArgBaseAlign, MaxAlign)));

  return true;
}

INITIALIZE_PASS_BEGIN(AMDGPULowerKernelArguments, DEBUG_TYPE,
                      "AMDGPU Lower Kernel Arguments", false, false)
INITIALIZE_PASS_END(AMDGPULowerKernelArguments, DEBUG_TYPE,
                    "AMDGPU Lower Kernel Arguments", false, false)

char AMDGPULowerKernelArguments::ID = 0;

FunctionPass *llvm::createAMDGPULowerKernelArgumentsPass() {
  return new AMDGPULowerKernelArguments();
}
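
// Note on use: createAMDGPULowerKernelArgumentsPass() is the factory declared
// in AMDGPU.h. The AMDGPU target machine is expected to schedule this pass on
// kernel IR late in the pipeline, before instruction selection, so the loads
// it emits are what the backend actually selects; the exact pipeline position
// is decided by the target machine, not by this file.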