180fd5fa5Shsmahesha //===-- AMDGPUReplaceLDSUseWithPointer.cpp --------------------------------===// 280fd5fa5Shsmahesha // 380fd5fa5Shsmahesha // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 480fd5fa5Shsmahesha // See https://llvm.org/LICENSE.txt for license information. 580fd5fa5Shsmahesha // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 680fd5fa5Shsmahesha // 780fd5fa5Shsmahesha //===----------------------------------------------------------------------===// 880fd5fa5Shsmahesha // 980fd5fa5Shsmahesha // This pass replaces all the uses of LDS within non-kernel functions by 1080fd5fa5Shsmahesha // corresponding pointer counter-parts. 1180fd5fa5Shsmahesha // 1280fd5fa5Shsmahesha // The main motivation behind this pass is - to *avoid* subsequent LDS lowering 1380fd5fa5Shsmahesha // pass from directly packing LDS (assume large LDS) into a struct type which 1480fd5fa5Shsmahesha // would otherwise cause allocating huge memory for struct instance within every 1580fd5fa5Shsmahesha // kernel. 1680fd5fa5Shsmahesha // 1780fd5fa5Shsmahesha // Brief sketch of the algorithm implemented in this pass is as below: 1880fd5fa5Shsmahesha // 1980fd5fa5Shsmahesha // 1. Collect all the LDS defined in the module which qualify for pointer 2080fd5fa5Shsmahesha // replacement, say it is, LDSGlobals set. 2180fd5fa5Shsmahesha // 2280fd5fa5Shsmahesha // 2. Collect all the reachable callees for each kernel defined in the module, 2380fd5fa5Shsmahesha // say it is, KernelToCallees map. 2480fd5fa5Shsmahesha // 2580fd5fa5Shsmahesha // 3. FOR (each global GV from LDSGlobals set) DO 2680fd5fa5Shsmahesha // LDSUsedNonKernels = Collect all non-kernel functions which use GV. 2780fd5fa5Shsmahesha // FOR (each kernel K in KernelToCallees map) DO 2880fd5fa5Shsmahesha // ReachableCallees = KernelToCallees[K] 2980fd5fa5Shsmahesha // ReachableAndLDSUsedCallees = 3080fd5fa5Shsmahesha // SetIntersect(LDSUsedNonKernels, ReachableCallees) 3180fd5fa5Shsmahesha // IF (ReachableAndLDSUsedCallees is not empty) THEN 3280fd5fa5Shsmahesha // Pointer = Create a pointer to point-to GV if not created. 3380fd5fa5Shsmahesha // Initialize Pointer to point-to GV within kernel K. 3480fd5fa5Shsmahesha // ENDIF 3580fd5fa5Shsmahesha // ENDFOR 3680fd5fa5Shsmahesha // Replace all uses of GV within non kernel functions by Pointer. 3780fd5fa5Shsmahesha // ENFOR 3880fd5fa5Shsmahesha // 3980fd5fa5Shsmahesha // LLVM IR example: 4080fd5fa5Shsmahesha // 4180fd5fa5Shsmahesha // Input IR: 4280fd5fa5Shsmahesha // 4380fd5fa5Shsmahesha // @lds = internal addrspace(3) global [4 x i32] undef, align 16 4480fd5fa5Shsmahesha // 4580fd5fa5Shsmahesha // define internal void @f0() { 4680fd5fa5Shsmahesha // entry: 4780fd5fa5Shsmahesha // %gep = getelementptr inbounds [4 x i32], [4 x i32] addrspace(3)* @lds, 4880fd5fa5Shsmahesha // i32 0, i32 0 4980fd5fa5Shsmahesha // ret void 5080fd5fa5Shsmahesha // } 5180fd5fa5Shsmahesha // 5280fd5fa5Shsmahesha // define protected amdgpu_kernel void @k0() { 5380fd5fa5Shsmahesha // entry: 5480fd5fa5Shsmahesha // call void @f0() 5580fd5fa5Shsmahesha // ret void 5680fd5fa5Shsmahesha // } 5780fd5fa5Shsmahesha // 5880fd5fa5Shsmahesha // Output IR: 5980fd5fa5Shsmahesha // 6080fd5fa5Shsmahesha // @lds = internal addrspace(3) global [4 x i32] undef, align 16 6180fd5fa5Shsmahesha // @lds.ptr = internal unnamed_addr addrspace(3) global i16 undef, align 2 6280fd5fa5Shsmahesha // 6380fd5fa5Shsmahesha // define internal void @f0() { 6480fd5fa5Shsmahesha // entry: 6580fd5fa5Shsmahesha // %0 = load i16, i16 addrspace(3)* @lds.ptr, align 2 6680fd5fa5Shsmahesha // %1 = getelementptr i8, i8 addrspace(3)* null, i16 %0 6780fd5fa5Shsmahesha // %2 = bitcast i8 addrspace(3)* %1 to [4 x i32] addrspace(3)* 6880fd5fa5Shsmahesha // %gep = getelementptr inbounds [4 x i32], [4 x i32] addrspace(3)* %2, 6980fd5fa5Shsmahesha // i32 0, i32 0 7080fd5fa5Shsmahesha // ret void 7180fd5fa5Shsmahesha // } 7280fd5fa5Shsmahesha // 7380fd5fa5Shsmahesha // define protected amdgpu_kernel void @k0() { 7480fd5fa5Shsmahesha // entry: 7580fd5fa5Shsmahesha // store i16 ptrtoint ([4 x i32] addrspace(3)* @lds to i16), 7680fd5fa5Shsmahesha // i16 addrspace(3)* @lds.ptr, align 2 7780fd5fa5Shsmahesha // call void @f0() 7880fd5fa5Shsmahesha // ret void 7980fd5fa5Shsmahesha // } 8080fd5fa5Shsmahesha // 8180fd5fa5Shsmahesha //===----------------------------------------------------------------------===// 8280fd5fa5Shsmahesha 8380fd5fa5Shsmahesha #include "AMDGPU.h" 8480fd5fa5Shsmahesha #include "GCNSubtarget.h" 8580fd5fa5Shsmahesha #include "Utils/AMDGPUBaseInfo.h" 8680fd5fa5Shsmahesha #include "Utils/AMDGPULDSUtils.h" 8780fd5fa5Shsmahesha #include "llvm/ADT/DenseMap.h" 8880fd5fa5Shsmahesha #include "llvm/ADT/STLExtras.h" 8980fd5fa5Shsmahesha #include "llvm/ADT/SetOperations.h" 9080fd5fa5Shsmahesha #include "llvm/CodeGen/TargetPassConfig.h" 9180fd5fa5Shsmahesha #include "llvm/IR/Constants.h" 9280fd5fa5Shsmahesha #include "llvm/IR/DerivedTypes.h" 9380fd5fa5Shsmahesha #include "llvm/IR/IRBuilder.h" 9480fd5fa5Shsmahesha #include "llvm/IR/InlineAsm.h" 9580fd5fa5Shsmahesha #include "llvm/IR/Instructions.h" 9680fd5fa5Shsmahesha #include "llvm/IR/IntrinsicsAMDGPU.h" 9780fd5fa5Shsmahesha #include "llvm/IR/ReplaceConstant.h" 9880fd5fa5Shsmahesha #include "llvm/InitializePasses.h" 9980fd5fa5Shsmahesha #include "llvm/Pass.h" 10080fd5fa5Shsmahesha #include "llvm/Support/Debug.h" 10180fd5fa5Shsmahesha #include "llvm/Target/TargetMachine.h" 10280fd5fa5Shsmahesha #include "llvm/Transforms/Utils/BasicBlockUtils.h" 10380fd5fa5Shsmahesha #include "llvm/Transforms/Utils/ModuleUtils.h" 10480fd5fa5Shsmahesha #include <algorithm> 10580fd5fa5Shsmahesha #include <vector> 10680fd5fa5Shsmahesha 10780fd5fa5Shsmahesha #define DEBUG_TYPE "amdgpu-replace-lds-use-with-pointer" 10880fd5fa5Shsmahesha 10980fd5fa5Shsmahesha using namespace llvm; 11080fd5fa5Shsmahesha 11180fd5fa5Shsmahesha namespace { 11280fd5fa5Shsmahesha 11380fd5fa5Shsmahesha class ReplaceLDSUseImpl { 11480fd5fa5Shsmahesha Module &M; 11580fd5fa5Shsmahesha LLVMContext &Ctx; 11680fd5fa5Shsmahesha const DataLayout &DL; 11780fd5fa5Shsmahesha Constant *LDSMemBaseAddr; 11880fd5fa5Shsmahesha 11980fd5fa5Shsmahesha DenseMap<GlobalVariable *, GlobalVariable *> LDSToPointer; 12080fd5fa5Shsmahesha DenseMap<GlobalVariable *, SmallPtrSet<Function *, 8>> LDSToNonKernels; 12180fd5fa5Shsmahesha DenseMap<Function *, SmallPtrSet<Function *, 8>> KernelToCallees; 12280fd5fa5Shsmahesha DenseMap<Function *, SmallPtrSet<GlobalVariable *, 8>> KernelToLDSPointers; 12380fd5fa5Shsmahesha DenseMap<Function *, BasicBlock *> KernelToInitBB; 12480fd5fa5Shsmahesha DenseMap<Function *, DenseMap<GlobalVariable *, Value *>> 12580fd5fa5Shsmahesha FunctionToLDSToReplaceInst; 12680fd5fa5Shsmahesha 12780fd5fa5Shsmahesha // Collect LDS which requires their uses to be replaced by pointer. 12880fd5fa5Shsmahesha std::vector<GlobalVariable *> collectLDSRequiringPointerReplace() { 12980fd5fa5Shsmahesha // Collect LDS which requires module lowering. 13080fd5fa5Shsmahesha std::vector<GlobalVariable *> LDSGlobals = AMDGPU::findVariablesToLower(M); 13180fd5fa5Shsmahesha 13280fd5fa5Shsmahesha // Remove LDS which don't qualify for replacement. 13380fd5fa5Shsmahesha LDSGlobals.erase(std::remove_if(LDSGlobals.begin(), LDSGlobals.end(), 13480fd5fa5Shsmahesha [&](GlobalVariable *GV) { 13580fd5fa5Shsmahesha return shouldIgnorePointerReplacement(GV); 13680fd5fa5Shsmahesha }), 13780fd5fa5Shsmahesha LDSGlobals.end()); 13880fd5fa5Shsmahesha 13980fd5fa5Shsmahesha return LDSGlobals; 14080fd5fa5Shsmahesha } 14180fd5fa5Shsmahesha 14280fd5fa5Shsmahesha // Returns true if uses of given LDS global within non-kernel functions should 14380fd5fa5Shsmahesha // be keep as it is without pointer replacement. 14480fd5fa5Shsmahesha bool shouldIgnorePointerReplacement(GlobalVariable *GV) { 145*dc6e8dfdSJacob Lambert // LDS whose size is very small and doesn't exceed pointer size is not worth 14680fd5fa5Shsmahesha // replacing. 14780fd5fa5Shsmahesha if (DL.getTypeAllocSize(GV->getValueType()) <= 2) 14880fd5fa5Shsmahesha return true; 14980fd5fa5Shsmahesha 15080fd5fa5Shsmahesha // LDS which is not used from non-kernel function scope or it is used from 15180fd5fa5Shsmahesha // global scope does not qualify for replacement. 15280fd5fa5Shsmahesha LDSToNonKernels[GV] = AMDGPU::collectNonKernelAccessorsOfLDS(GV); 15380fd5fa5Shsmahesha return LDSToNonKernels[GV].empty(); 15480fd5fa5Shsmahesha 15580fd5fa5Shsmahesha // FIXME: When GV is used within all (or within most of the kernels), then 15680fd5fa5Shsmahesha // it does not make sense to create a pointer for it. 15780fd5fa5Shsmahesha } 15880fd5fa5Shsmahesha 15980fd5fa5Shsmahesha // Insert new global LDS pointer which points to LDS. 16080fd5fa5Shsmahesha GlobalVariable *createLDSPointer(GlobalVariable *GV) { 161*dc6e8dfdSJacob Lambert // LDS pointer which points to LDS is already created? Return it. 16280fd5fa5Shsmahesha auto PointerEntry = LDSToPointer.insert(std::make_pair(GV, nullptr)); 16380fd5fa5Shsmahesha if (!PointerEntry.second) 16480fd5fa5Shsmahesha return PointerEntry.first->second; 16580fd5fa5Shsmahesha 16680fd5fa5Shsmahesha // We need to create new LDS pointer which points to LDS. 16780fd5fa5Shsmahesha // 16880fd5fa5Shsmahesha // Each CU owns at max 64K of LDS memory, so LDS address ranges from 0 to 16980fd5fa5Shsmahesha // 2^16 - 1. Hence 16 bit pointer is enough to hold the LDS address. 17080fd5fa5Shsmahesha auto *I16Ty = Type::getInt16Ty(Ctx); 17180fd5fa5Shsmahesha GlobalVariable *LDSPointer = new GlobalVariable( 17280fd5fa5Shsmahesha M, I16Ty, false, GlobalValue::InternalLinkage, UndefValue::get(I16Ty), 17380fd5fa5Shsmahesha GV->getName() + Twine(".ptr"), nullptr, GlobalVariable::NotThreadLocal, 17480fd5fa5Shsmahesha AMDGPUAS::LOCAL_ADDRESS); 17580fd5fa5Shsmahesha 17680fd5fa5Shsmahesha LDSPointer->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); 17780fd5fa5Shsmahesha LDSPointer->setAlignment(AMDGPU::getAlign(DL, LDSPointer)); 17880fd5fa5Shsmahesha 17980fd5fa5Shsmahesha // Mark that an associated LDS pointer is created for LDS. 18080fd5fa5Shsmahesha LDSToPointer[GV] = LDSPointer; 18180fd5fa5Shsmahesha 18280fd5fa5Shsmahesha return LDSPointer; 18380fd5fa5Shsmahesha } 18480fd5fa5Shsmahesha 18580fd5fa5Shsmahesha // Split entry basic block in such a way that only lane 0 of each wave does 18680fd5fa5Shsmahesha // the LDS pointer initialization, and return newly created basic block. 18780fd5fa5Shsmahesha BasicBlock *activateLaneZero(Function *K) { 188*dc6e8dfdSJacob Lambert // If the entry basic block of kernel K is already split, then return 18980fd5fa5Shsmahesha // newly created basic block. 19080fd5fa5Shsmahesha auto BasicBlockEntry = KernelToInitBB.insert(std::make_pair(K, nullptr)); 19180fd5fa5Shsmahesha if (!BasicBlockEntry.second) 19280fd5fa5Shsmahesha return BasicBlockEntry.first->second; 19380fd5fa5Shsmahesha 1940c288140Shsmahesha // Split entry basic block of kernel K. 1950c288140Shsmahesha auto *EI = &(*(K->getEntryBlock().getFirstInsertionPt())); 1960c288140Shsmahesha IRBuilder<> Builder(EI); 19780fd5fa5Shsmahesha 19880fd5fa5Shsmahesha Value *Mbcnt = 19980fd5fa5Shsmahesha Builder.CreateIntrinsic(Intrinsic::amdgcn_mbcnt_lo, {}, 20080fd5fa5Shsmahesha {Builder.getInt32(-1), Builder.getInt32(0)}); 20180fd5fa5Shsmahesha Value *Cond = Builder.CreateICmpEQ(Mbcnt, Builder.getInt32(0)); 20280fd5fa5Shsmahesha Instruction *WB = cast<Instruction>( 20380fd5fa5Shsmahesha Builder.CreateIntrinsic(Intrinsic::amdgcn_wave_barrier, {}, {})); 20480fd5fa5Shsmahesha 20580fd5fa5Shsmahesha BasicBlock *NBB = SplitBlockAndInsertIfThen(Cond, WB, false)->getParent(); 20680fd5fa5Shsmahesha 207*dc6e8dfdSJacob Lambert // Mark that the entry basic block of kernel K is split. 20880fd5fa5Shsmahesha KernelToInitBB[K] = NBB; 20980fd5fa5Shsmahesha 21080fd5fa5Shsmahesha return NBB; 21180fd5fa5Shsmahesha } 21280fd5fa5Shsmahesha 21380fd5fa5Shsmahesha // Within given kernel, initialize given LDS pointer to point to given LDS. 21480fd5fa5Shsmahesha void initializeLDSPointer(Function *K, GlobalVariable *GV, 21580fd5fa5Shsmahesha GlobalVariable *LDSPointer) { 21680fd5fa5Shsmahesha // If LDS pointer is already initialized within K, then nothing to do. 21780fd5fa5Shsmahesha auto PointerEntry = KernelToLDSPointers.insert( 21880fd5fa5Shsmahesha std::make_pair(K, SmallPtrSet<GlobalVariable *, 8>())); 21980fd5fa5Shsmahesha if (!PointerEntry.second) 22080fd5fa5Shsmahesha if (PointerEntry.first->second.contains(LDSPointer)) 22180fd5fa5Shsmahesha return; 22280fd5fa5Shsmahesha 22380fd5fa5Shsmahesha // Insert instructions at EI which initialize LDS pointer to point-to LDS 22480fd5fa5Shsmahesha // within kernel K. 22580fd5fa5Shsmahesha // 22680fd5fa5Shsmahesha // That is, convert pointer type of GV to i16, and then store this converted 22780fd5fa5Shsmahesha // i16 value within LDSPointer which is of type i16*. 22880fd5fa5Shsmahesha auto *EI = &(*(activateLaneZero(K)->getFirstInsertionPt())); 22980fd5fa5Shsmahesha IRBuilder<> Builder(EI); 23080fd5fa5Shsmahesha Builder.CreateStore(Builder.CreatePtrToInt(GV, Type::getInt16Ty(Ctx)), 23180fd5fa5Shsmahesha LDSPointer); 23280fd5fa5Shsmahesha 23380fd5fa5Shsmahesha // Mark that LDS pointer is initialized within kernel K. 23480fd5fa5Shsmahesha KernelToLDSPointers[K].insert(LDSPointer); 23580fd5fa5Shsmahesha } 23680fd5fa5Shsmahesha 23780fd5fa5Shsmahesha // We have created an LDS pointer for LDS, and initialized it to point-to LDS 238*dc6e8dfdSJacob Lambert // within all relevant kernels. Now replace all the uses of LDS within 23980fd5fa5Shsmahesha // non-kernel functions by LDS pointer. 24080fd5fa5Shsmahesha void replaceLDSUseByPointer(GlobalVariable *GV, GlobalVariable *LDSPointer) { 24180fd5fa5Shsmahesha SmallVector<User *, 8> LDSUsers(GV->users()); 24280fd5fa5Shsmahesha for (auto *U : LDSUsers) { 24380fd5fa5Shsmahesha // When `U` is a constant expression, it is possible that same constant 24480fd5fa5Shsmahesha // expression exists within multiple instructions, and within multiple 24580fd5fa5Shsmahesha // non-kernel functions. Collect all those non-kernel functions and all 24680fd5fa5Shsmahesha // those instructions within which `U` exist. 24780fd5fa5Shsmahesha auto FunctionToInsts = 24880fd5fa5Shsmahesha AMDGPU::getFunctionToInstsMap(U, false /*=CollectKernelInsts*/); 24980fd5fa5Shsmahesha 25080fd5fa5Shsmahesha for (auto FI = FunctionToInsts.begin(), FE = FunctionToInsts.end(); 25180fd5fa5Shsmahesha FI != FE; ++FI) { 25280fd5fa5Shsmahesha Function *F = FI->first; 25380fd5fa5Shsmahesha auto &Insts = FI->second; 25480fd5fa5Shsmahesha for (auto *I : Insts) { 25580fd5fa5Shsmahesha // If `U` is a constant expression, then we need to break the 25680fd5fa5Shsmahesha // associated instruction into a set of separate instructions by 25780fd5fa5Shsmahesha // converting constant expressions into instructions. 25880fd5fa5Shsmahesha SmallPtrSet<Instruction *, 8> UserInsts; 25980fd5fa5Shsmahesha 26080fd5fa5Shsmahesha if (U == I) { 26180fd5fa5Shsmahesha // `U` is an instruction, conversion from constant expression to 26280fd5fa5Shsmahesha // set of instructions is *not* required. 26380fd5fa5Shsmahesha UserInsts.insert(I); 26480fd5fa5Shsmahesha } else { 26580fd5fa5Shsmahesha // `U` is a constant expression, convert it into corresponding set 26680fd5fa5Shsmahesha // of instructions. 26780fd5fa5Shsmahesha auto *CE = cast<ConstantExpr>(U); 26880fd5fa5Shsmahesha convertConstantExprsToInstructions(I, CE, &UserInsts); 26980fd5fa5Shsmahesha } 27080fd5fa5Shsmahesha 271*dc6e8dfdSJacob Lambert // Go through all the user instructions, if LDS exist within them as 272*dc6e8dfdSJacob Lambert // an operand, then replace it by replace instruction. 27380fd5fa5Shsmahesha for (auto *II : UserInsts) { 27480fd5fa5Shsmahesha auto *ReplaceInst = getReplacementInst(F, GV, LDSPointer); 27580fd5fa5Shsmahesha II->replaceUsesOfWith(GV, ReplaceInst); 27680fd5fa5Shsmahesha } 27780fd5fa5Shsmahesha } 27880fd5fa5Shsmahesha } 27980fd5fa5Shsmahesha } 28080fd5fa5Shsmahesha } 28180fd5fa5Shsmahesha 28280fd5fa5Shsmahesha // Create a set of replacement instructions which together replace LDS within 28380fd5fa5Shsmahesha // non-kernel function F by accessing LDS indirectly using LDS pointer. 28480fd5fa5Shsmahesha Value *getReplacementInst(Function *F, GlobalVariable *GV, 28580fd5fa5Shsmahesha GlobalVariable *LDSPointer) { 28680fd5fa5Shsmahesha // If the instruction which replaces LDS within F is already created, then 28780fd5fa5Shsmahesha // return it. 28880fd5fa5Shsmahesha auto LDSEntry = FunctionToLDSToReplaceInst.insert( 28980fd5fa5Shsmahesha std::make_pair(F, DenseMap<GlobalVariable *, Value *>())); 29080fd5fa5Shsmahesha if (!LDSEntry.second) { 29180fd5fa5Shsmahesha auto ReplaceInstEntry = 29280fd5fa5Shsmahesha LDSEntry.first->second.insert(std::make_pair(GV, nullptr)); 29380fd5fa5Shsmahesha if (!ReplaceInstEntry.second) 29480fd5fa5Shsmahesha return ReplaceInstEntry.first->second; 29580fd5fa5Shsmahesha } 29680fd5fa5Shsmahesha 29780fd5fa5Shsmahesha // Get the instruction insertion point within the beginning of the entry 29880fd5fa5Shsmahesha // block of current non-kernel function. 29980fd5fa5Shsmahesha auto *EI = &(*(F->getEntryBlock().getFirstInsertionPt())); 30080fd5fa5Shsmahesha IRBuilder<> Builder(EI); 30180fd5fa5Shsmahesha 30280fd5fa5Shsmahesha // Insert required set of instructions which replace LDS within F. 30380fd5fa5Shsmahesha auto *V = Builder.CreateBitCast( 30480fd5fa5Shsmahesha Builder.CreateGEP( 3052c68ecccSNikita Popov Builder.getInt8Ty(), LDSMemBaseAddr, 30680fd5fa5Shsmahesha Builder.CreateLoad(LDSPointer->getValueType(), LDSPointer)), 30780fd5fa5Shsmahesha GV->getType()); 30880fd5fa5Shsmahesha 30980fd5fa5Shsmahesha // Mark that the replacement instruction which replace LDS within F is 31080fd5fa5Shsmahesha // created. 31180fd5fa5Shsmahesha FunctionToLDSToReplaceInst[F][GV] = V; 31280fd5fa5Shsmahesha 31380fd5fa5Shsmahesha return V; 31480fd5fa5Shsmahesha } 31580fd5fa5Shsmahesha 31680fd5fa5Shsmahesha public: 31780fd5fa5Shsmahesha ReplaceLDSUseImpl(Module &M) 31880fd5fa5Shsmahesha : M(M), Ctx(M.getContext()), DL(M.getDataLayout()) { 31980fd5fa5Shsmahesha LDSMemBaseAddr = Constant::getIntegerValue( 32080fd5fa5Shsmahesha PointerType::get(Type::getInt8Ty(M.getContext()), 32180fd5fa5Shsmahesha AMDGPUAS::LOCAL_ADDRESS), 32280fd5fa5Shsmahesha APInt(32, 0)); 32380fd5fa5Shsmahesha } 32480fd5fa5Shsmahesha 32580fd5fa5Shsmahesha // Entry-point function which interface ReplaceLDSUseImpl with outside of the 32680fd5fa5Shsmahesha // class. 32780fd5fa5Shsmahesha bool replaceLDSUse(); 32880fd5fa5Shsmahesha 32980fd5fa5Shsmahesha private: 33080fd5fa5Shsmahesha // For a given LDS from collected LDS globals set, replace its non-kernel 33180fd5fa5Shsmahesha // function scope uses by pointer. 33280fd5fa5Shsmahesha bool replaceLDSUse(GlobalVariable *GV); 33380fd5fa5Shsmahesha }; 33480fd5fa5Shsmahesha 33580fd5fa5Shsmahesha // For given LDS from collected LDS globals set, replace its non-kernel function 33680fd5fa5Shsmahesha // scope uses by pointer. 33780fd5fa5Shsmahesha bool ReplaceLDSUseImpl::replaceLDSUse(GlobalVariable *GV) { 33880fd5fa5Shsmahesha // Holds all those non-kernel functions within which LDS is being accessed. 33980fd5fa5Shsmahesha SmallPtrSet<Function *, 8> &LDSAccessors = LDSToNonKernels[GV]; 34080fd5fa5Shsmahesha 34180fd5fa5Shsmahesha // The LDS pointer which points to LDS and replaces all the uses of LDS. 34280fd5fa5Shsmahesha GlobalVariable *LDSPointer = nullptr; 34380fd5fa5Shsmahesha 34480fd5fa5Shsmahesha // Traverse through each kernel K, check and if required, initialize the 34580fd5fa5Shsmahesha // LDS pointer to point to LDS within K. 34680fd5fa5Shsmahesha for (auto KI = KernelToCallees.begin(), KE = KernelToCallees.end(); KI != KE; 34780fd5fa5Shsmahesha ++KI) { 34880fd5fa5Shsmahesha Function *K = KI->first; 34980fd5fa5Shsmahesha SmallPtrSet<Function *, 8> Callees = KI->second; 35080fd5fa5Shsmahesha 35180fd5fa5Shsmahesha // Compute reachable and LDS used callees for kernel K. 35280fd5fa5Shsmahesha set_intersect(Callees, LDSAccessors); 35380fd5fa5Shsmahesha 35480fd5fa5Shsmahesha // None of the LDS accessing non-kernel functions are reachable from 35580fd5fa5Shsmahesha // kernel K. Hence, no need to initialize LDS pointer within kernel K. 35680fd5fa5Shsmahesha if (Callees.empty()) 35780fd5fa5Shsmahesha continue; 35880fd5fa5Shsmahesha 35980fd5fa5Shsmahesha // We have found reachable and LDS used callees for kernel K, and we need to 36080fd5fa5Shsmahesha // initialize LDS pointer within kernel K, and we need to replace LDS use 36180fd5fa5Shsmahesha // within those callees by LDS pointer. 36280fd5fa5Shsmahesha // 36380fd5fa5Shsmahesha // But, first check if LDS pointer is already created, if not create one. 36480fd5fa5Shsmahesha LDSPointer = createLDSPointer(GV); 36580fd5fa5Shsmahesha 36680fd5fa5Shsmahesha // Initialize LDS pointer to point to LDS within kernel K. 36780fd5fa5Shsmahesha initializeLDSPointer(K, GV, LDSPointer); 36880fd5fa5Shsmahesha } 36980fd5fa5Shsmahesha 37080fd5fa5Shsmahesha // We have not found reachable and LDS used callees for any of the kernels, 37180fd5fa5Shsmahesha // and hence we have not created LDS pointer. 37280fd5fa5Shsmahesha if (!LDSPointer) 37380fd5fa5Shsmahesha return false; 37480fd5fa5Shsmahesha 37580fd5fa5Shsmahesha // We have created an LDS pointer for LDS, and initialized it to point-to LDS 376*dc6e8dfdSJacob Lambert // within all relevant kernels. Now replace all the uses of LDS within 37780fd5fa5Shsmahesha // non-kernel functions by LDS pointer. 37880fd5fa5Shsmahesha replaceLDSUseByPointer(GV, LDSPointer); 37980fd5fa5Shsmahesha 38080fd5fa5Shsmahesha return true; 38180fd5fa5Shsmahesha } 38280fd5fa5Shsmahesha 38380fd5fa5Shsmahesha // Entry-point function which interface ReplaceLDSUseImpl with outside of the 38480fd5fa5Shsmahesha // class. 38580fd5fa5Shsmahesha bool ReplaceLDSUseImpl::replaceLDSUse() { 38680fd5fa5Shsmahesha // Collect LDS which requires their uses to be replaced by pointer. 38780fd5fa5Shsmahesha std::vector<GlobalVariable *> LDSGlobals = 38880fd5fa5Shsmahesha collectLDSRequiringPointerReplace(); 38980fd5fa5Shsmahesha 39080fd5fa5Shsmahesha // No LDS to pointer-replace. Nothing to do. 39180fd5fa5Shsmahesha if (LDSGlobals.empty()) 39280fd5fa5Shsmahesha return false; 39380fd5fa5Shsmahesha 39480fd5fa5Shsmahesha // Collect reachable callee set for each kernel defined in the module. 39580fd5fa5Shsmahesha AMDGPU::collectReachableCallees(M, KernelToCallees); 39680fd5fa5Shsmahesha 39780fd5fa5Shsmahesha if (KernelToCallees.empty()) { 39880fd5fa5Shsmahesha // Either module does not have any kernel definitions, or none of the kernel 39980fd5fa5Shsmahesha // has a call to non-kernel functions, or we could not resolve any of the 40080fd5fa5Shsmahesha // call sites to proper non-kernel functions, because of the situations like 40180fd5fa5Shsmahesha // inline asm calls. Nothing to replace. 40280fd5fa5Shsmahesha return false; 40380fd5fa5Shsmahesha } 40480fd5fa5Shsmahesha 40580fd5fa5Shsmahesha // For every LDS from collected LDS globals set, replace its non-kernel 40680fd5fa5Shsmahesha // function scope use by pointer. 40780fd5fa5Shsmahesha bool Changed = false; 40880fd5fa5Shsmahesha for (auto *GV : LDSGlobals) 40980fd5fa5Shsmahesha Changed |= replaceLDSUse(GV); 41080fd5fa5Shsmahesha 41180fd5fa5Shsmahesha return Changed; 41280fd5fa5Shsmahesha } 41380fd5fa5Shsmahesha 41480fd5fa5Shsmahesha class AMDGPUReplaceLDSUseWithPointer : public ModulePass { 41580fd5fa5Shsmahesha public: 41680fd5fa5Shsmahesha static char ID; 41780fd5fa5Shsmahesha 41880fd5fa5Shsmahesha AMDGPUReplaceLDSUseWithPointer() : ModulePass(ID) { 41980fd5fa5Shsmahesha initializeAMDGPUReplaceLDSUseWithPointerPass( 42080fd5fa5Shsmahesha *PassRegistry::getPassRegistry()); 42180fd5fa5Shsmahesha } 42280fd5fa5Shsmahesha 42380fd5fa5Shsmahesha bool runOnModule(Module &M) override; 42480fd5fa5Shsmahesha 42580fd5fa5Shsmahesha void getAnalysisUsage(AnalysisUsage &AU) const override { 42680fd5fa5Shsmahesha AU.addRequired<TargetPassConfig>(); 42780fd5fa5Shsmahesha } 42880fd5fa5Shsmahesha }; 42980fd5fa5Shsmahesha 43080fd5fa5Shsmahesha } // namespace 43180fd5fa5Shsmahesha 43280fd5fa5Shsmahesha char AMDGPUReplaceLDSUseWithPointer::ID = 0; 43380fd5fa5Shsmahesha char &llvm::AMDGPUReplaceLDSUseWithPointerID = 43480fd5fa5Shsmahesha AMDGPUReplaceLDSUseWithPointer::ID; 43580fd5fa5Shsmahesha 43680fd5fa5Shsmahesha INITIALIZE_PASS_BEGIN( 43780fd5fa5Shsmahesha AMDGPUReplaceLDSUseWithPointer, DEBUG_TYPE, 43880fd5fa5Shsmahesha "Replace within non-kernel function use of LDS with pointer", 43980fd5fa5Shsmahesha false /*only look at the cfg*/, false /*analysis pass*/) 44080fd5fa5Shsmahesha INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) 44180fd5fa5Shsmahesha INITIALIZE_PASS_END( 44280fd5fa5Shsmahesha AMDGPUReplaceLDSUseWithPointer, DEBUG_TYPE, 44380fd5fa5Shsmahesha "Replace within non-kernel function use of LDS with pointer", 44480fd5fa5Shsmahesha false /*only look at the cfg*/, false /*analysis pass*/) 44580fd5fa5Shsmahesha 44680fd5fa5Shsmahesha bool AMDGPUReplaceLDSUseWithPointer::runOnModule(Module &M) { 44780fd5fa5Shsmahesha ReplaceLDSUseImpl LDSUseReplacer{M}; 44880fd5fa5Shsmahesha return LDSUseReplacer.replaceLDSUse(); 44980fd5fa5Shsmahesha } 45080fd5fa5Shsmahesha 45180fd5fa5Shsmahesha ModulePass *llvm::createAMDGPUReplaceLDSUseWithPointerPass() { 45280fd5fa5Shsmahesha return new AMDGPUReplaceLDSUseWithPointer(); 45380fd5fa5Shsmahesha } 45480fd5fa5Shsmahesha 45580fd5fa5Shsmahesha PreservedAnalyses 45680fd5fa5Shsmahesha AMDGPUReplaceLDSUseWithPointerPass::run(Module &M, ModuleAnalysisManager &AM) { 45780fd5fa5Shsmahesha ReplaceLDSUseImpl LDSUseReplacer{M}; 45880fd5fa5Shsmahesha LDSUseReplacer.replaceLDSUse(); 45980fd5fa5Shsmahesha return PreservedAnalyses::all(); 46080fd5fa5Shsmahesha } 461