//===-- AMDGPUReplaceLDSUseWithPointer.cpp --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces all the uses of LDS within non-kernel functions by
// corresponding pointer counterparts.
//
// The main motivation behind this pass is to *avoid* having the subsequent LDS
// lowering pass directly pack LDS variables (which may be large) into a struct
// type, which would otherwise cause a huge struct instance to be allocated
// within every kernel.
//
// A brief sketch of the algorithm implemented in this pass:
//
// 1. Collect all the LDS globals defined in the module which qualify for
//    pointer replacement; call this the LDSGlobals set.
//
// 2. Collect all the reachable callees for each kernel defined in the module;
//    call this the KernelToCallees map.
//
// 3. FOR (each global GV from LDSGlobals set) DO
//      LDSUsedNonKernels = Collect all non-kernel functions which use GV.
//      FOR (each kernel K in KernelToCallees map) DO
//         ReachableCallees = KernelToCallees[K]
//         ReachableAndLDSUsedCallees =
//            SetIntersect(LDSUsedNonKernels, ReachableCallees)
//         IF (ReachableAndLDSUsedCallees is not empty) THEN
//           Pointer = Create a pointer to point-to GV if not created.
//           Initialize Pointer to point-to GV within kernel K.
//         ENDIF
//      ENDFOR
//      Replace all uses of GV within non-kernel functions by Pointer.
//    ENDFOR
//
// LLVM IR example:
//
//    Input IR:
//
//    @lds = internal addrspace(3) global [4 x i32] undef, align 16
//
//    define internal void @f0() {
//    entry:
//      %gep = getelementptr inbounds [4 x i32], [4 x i32] addrspace(3)* @lds,
//             i32 0, i32 0
//      ret void
//    }
//
//    define protected amdgpu_kernel void @k0() {
//    entry:
//      call void @f0()
//      ret void
//    }
//
//    Output IR:
//
//    @lds = internal addrspace(3) global [4 x i32] undef, align 16
//    @lds.ptr = internal unnamed_addr addrspace(3) global i16 undef, align 2
//
//    define internal void @f0() {
//    entry:
//      %0 = load i16, i16 addrspace(3)* @lds.ptr, align 2
//      %1 = getelementptr i8, i8 addrspace(3)* null, i16 %0
//      %2 = bitcast i8 addrspace(3)* %1 to [4 x i32] addrspace(3)*
//      %gep = getelementptr inbounds [4 x i32], [4 x i32] addrspace(3)* %2,
//             i32 0, i32 0
//      ret void
//    }
//
//    define protected amdgpu_kernel void @k0() {
//    entry:
//      store i16 ptrtoint ([4 x i32] addrspace(3)* @lds to i16),
//            i16 addrspace(3)* @lds.ptr, align 2
//      call void @f0()
//      ret void
//    }
//
//    Note: for brevity, the kernel above omits the lane-0 guard that the pass
//    places around the store (an llvm.amdgcn.mbcnt.lo check followed by an
//    llvm.amdgcn.wave.barrier), so that only lane 0 of each wave performs it.
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "GCNSubtarget.h"
#include "Utils/AMDGPUBaseInfo.h"
#include "Utils/AMDGPULDSUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
103 #include "llvm/Transforms/Utils/ModuleUtils.h" 104 #include <algorithm> 105 #include <vector> 106 107 #define DEBUG_TYPE "amdgpu-replace-lds-use-with-pointer" 108 109 using namespace llvm; 110 111 namespace { 112 113 class ReplaceLDSUseImpl { 114 Module &M; 115 LLVMContext &Ctx; 116 const DataLayout &DL; 117 Constant *LDSMemBaseAddr; 118 119 DenseMap<GlobalVariable *, GlobalVariable *> LDSToPointer; 120 DenseMap<GlobalVariable *, SmallPtrSet<Function *, 8>> LDSToNonKernels; 121 DenseMap<Function *, SmallPtrSet<Function *, 8>> KernelToCallees; 122 DenseMap<Function *, SmallPtrSet<GlobalVariable *, 8>> KernelToLDSPointers; 123 DenseMap<Function *, BasicBlock *> KernelToInitBB; 124 DenseMap<Function *, DenseMap<GlobalVariable *, Value *>> 125 FunctionToLDSToReplaceInst; 126 127 // Collect LDS which requires their uses to be replaced by pointer. 128 std::vector<GlobalVariable *> collectLDSRequiringPointerReplace() { 129 // Collect LDS which requires module lowering. 130 std::vector<GlobalVariable *> LDSGlobals = AMDGPU::findVariablesToLower(M); 131 132 // Remove LDS which don't qualify for replacement. 133 llvm::erase_if(LDSGlobals, [&](GlobalVariable *GV) { 134 return shouldIgnorePointerReplacement(GV); 135 }); 136 137 return LDSGlobals; 138 } 139 140 // Returns true if uses of given LDS global within non-kernel functions should 141 // be keep as it is without pointer replacement. 142 bool shouldIgnorePointerReplacement(GlobalVariable *GV) { 143 // LDS whose size is very small and doesn't exceed pointer size is not worth 144 // replacing. 145 if (DL.getTypeAllocSize(GV->getValueType()) <= 2) 146 return true; 147 148 // LDS which is not used from non-kernel function scope or it is used from 149 // global scope does not qualify for replacement. 
150 LDSToNonKernels[GV] = AMDGPU::collectNonKernelAccessorsOfLDS(GV); 151 return LDSToNonKernels[GV].empty(); 152 153 // FIXME: When GV is used within all (or within most of the kernels), then 154 // it does not make sense to create a pointer for it. 155 } 156 157 // Insert new global LDS pointer which points to LDS. 158 GlobalVariable *createLDSPointer(GlobalVariable *GV) { 159 // LDS pointer which points to LDS is already created? Return it. 160 auto PointerEntry = LDSToPointer.insert(std::make_pair(GV, nullptr)); 161 if (!PointerEntry.second) 162 return PointerEntry.first->second; 163 164 // We need to create new LDS pointer which points to LDS. 165 // 166 // Each CU owns at max 64K of LDS memory, so LDS address ranges from 0 to 167 // 2^16 - 1. Hence 16 bit pointer is enough to hold the LDS address. 168 auto *I16Ty = Type::getInt16Ty(Ctx); 169 GlobalVariable *LDSPointer = new GlobalVariable( 170 M, I16Ty, false, GlobalValue::InternalLinkage, UndefValue::get(I16Ty), 171 GV->getName() + Twine(".ptr"), nullptr, GlobalVariable::NotThreadLocal, 172 AMDGPUAS::LOCAL_ADDRESS); 173 174 LDSPointer->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); 175 LDSPointer->setAlignment(AMDGPU::getAlign(DL, LDSPointer)); 176 177 // Mark that an associated LDS pointer is created for LDS. 178 LDSToPointer[GV] = LDSPointer; 179 180 return LDSPointer; 181 } 182 183 // Split entry basic block in such a way that only lane 0 of each wave does 184 // the LDS pointer initialization, and return newly created basic block. 185 BasicBlock *activateLaneZero(Function *K) { 186 // If the entry basic block of kernel K is already split, then return 187 // newly created basic block. 188 auto BasicBlockEntry = KernelToInitBB.insert(std::make_pair(K, nullptr)); 189 if (!BasicBlockEntry.second) 190 return BasicBlockEntry.first->second; 191 192 // Split entry basic block of kernel K. 
193 auto *EI = &(*(K->getEntryBlock().getFirstInsertionPt())); 194 IRBuilder<> Builder(EI); 195 196 Value *Mbcnt = 197 Builder.CreateIntrinsic(Intrinsic::amdgcn_mbcnt_lo, {}, 198 {Builder.getInt32(-1), Builder.getInt32(0)}); 199 Value *Cond = Builder.CreateICmpEQ(Mbcnt, Builder.getInt32(0)); 200 Instruction *WB = cast<Instruction>( 201 Builder.CreateIntrinsic(Intrinsic::amdgcn_wave_barrier, {}, {})); 202 203 BasicBlock *NBB = SplitBlockAndInsertIfThen(Cond, WB, false)->getParent(); 204 205 // Mark that the entry basic block of kernel K is split. 206 KernelToInitBB[K] = NBB; 207 208 return NBB; 209 } 210 211 // Within given kernel, initialize given LDS pointer to point to given LDS. 212 void initializeLDSPointer(Function *K, GlobalVariable *GV, 213 GlobalVariable *LDSPointer) { 214 // If LDS pointer is already initialized within K, then nothing to do. 215 auto PointerEntry = KernelToLDSPointers.insert( 216 std::make_pair(K, SmallPtrSet<GlobalVariable *, 8>())); 217 if (!PointerEntry.second) 218 if (PointerEntry.first->second.contains(LDSPointer)) 219 return; 220 221 // Insert instructions at EI which initialize LDS pointer to point-to LDS 222 // within kernel K. 223 // 224 // That is, convert pointer type of GV to i16, and then store this converted 225 // i16 value within LDSPointer which is of type i16*. 226 auto *EI = &(*(activateLaneZero(K)->getFirstInsertionPt())); 227 IRBuilder<> Builder(EI); 228 Builder.CreateStore(Builder.CreatePtrToInt(GV, Type::getInt16Ty(Ctx)), 229 LDSPointer); 230 231 // Mark that LDS pointer is initialized within kernel K. 232 KernelToLDSPointers[K].insert(LDSPointer); 233 } 234 235 // We have created an LDS pointer for LDS, and initialized it to point-to LDS 236 // within all relevant kernels. Now replace all the uses of LDS within 237 // non-kernel functions by LDS pointer. 
238 void replaceLDSUseByPointer(GlobalVariable *GV, GlobalVariable *LDSPointer) { 239 SmallVector<User *, 8> LDSUsers(GV->users()); 240 for (auto *U : LDSUsers) { 241 // When `U` is a constant expression, it is possible that same constant 242 // expression exists within multiple instructions, and within multiple 243 // non-kernel functions. Collect all those non-kernel functions and all 244 // those instructions within which `U` exist. 245 auto FunctionToInsts = 246 AMDGPU::getFunctionToInstsMap(U, false /*=CollectKernelInsts*/); 247 248 for (auto FI = FunctionToInsts.begin(), FE = FunctionToInsts.end(); 249 FI != FE; ++FI) { 250 Function *F = FI->first; 251 auto &Insts = FI->second; 252 for (auto *I : Insts) { 253 // If `U` is a constant expression, then we need to break the 254 // associated instruction into a set of separate instructions by 255 // converting constant expressions into instructions. 256 SmallPtrSet<Instruction *, 8> UserInsts; 257 258 if (U == I) { 259 // `U` is an instruction, conversion from constant expression to 260 // set of instructions is *not* required. 261 UserInsts.insert(I); 262 } else { 263 // `U` is a constant expression, convert it into corresponding set 264 // of instructions. 265 auto *CE = cast<ConstantExpr>(U); 266 convertConstantExprsToInstructions(I, CE, &UserInsts); 267 } 268 269 // Go through all the user instructions, if LDS exist within them as 270 // an operand, then replace it by replace instruction. 271 for (auto *II : UserInsts) { 272 auto *ReplaceInst = getReplacementInst(F, GV, LDSPointer); 273 II->replaceUsesOfWith(GV, ReplaceInst); 274 } 275 } 276 } 277 } 278 } 279 280 // Create a set of replacement instructions which together replace LDS within 281 // non-kernel function F by accessing LDS indirectly using LDS pointer. 282 Value *getReplacementInst(Function *F, GlobalVariable *GV, 283 GlobalVariable *LDSPointer) { 284 // If the instruction which replaces LDS within F is already created, then 285 // return it. 
286 auto LDSEntry = FunctionToLDSToReplaceInst.insert( 287 std::make_pair(F, DenseMap<GlobalVariable *, Value *>())); 288 if (!LDSEntry.second) { 289 auto ReplaceInstEntry = 290 LDSEntry.first->second.insert(std::make_pair(GV, nullptr)); 291 if (!ReplaceInstEntry.second) 292 return ReplaceInstEntry.first->second; 293 } 294 295 // Get the instruction insertion point within the beginning of the entry 296 // block of current non-kernel function. 297 auto *EI = &(*(F->getEntryBlock().getFirstInsertionPt())); 298 IRBuilder<> Builder(EI); 299 300 // Insert required set of instructions which replace LDS within F. 301 auto *V = Builder.CreateBitCast( 302 Builder.CreateGEP( 303 Builder.getInt8Ty(), LDSMemBaseAddr, 304 Builder.CreateLoad(LDSPointer->getValueType(), LDSPointer)), 305 GV->getType()); 306 307 // Mark that the replacement instruction which replace LDS within F is 308 // created. 309 FunctionToLDSToReplaceInst[F][GV] = V; 310 311 return V; 312 } 313 314 public: 315 ReplaceLDSUseImpl(Module &M) 316 : M(M), Ctx(M.getContext()), DL(M.getDataLayout()) { 317 LDSMemBaseAddr = Constant::getIntegerValue( 318 PointerType::get(Type::getInt8Ty(M.getContext()), 319 AMDGPUAS::LOCAL_ADDRESS), 320 APInt(32, 0)); 321 } 322 323 // Entry-point function which interface ReplaceLDSUseImpl with outside of the 324 // class. 325 bool replaceLDSUse(); 326 327 private: 328 // For a given LDS from collected LDS globals set, replace its non-kernel 329 // function scope uses by pointer. 330 bool replaceLDSUse(GlobalVariable *GV); 331 }; 332 333 // For given LDS from collected LDS globals set, replace its non-kernel function 334 // scope uses by pointer. 335 bool ReplaceLDSUseImpl::replaceLDSUse(GlobalVariable *GV) { 336 // Holds all those non-kernel functions within which LDS is being accessed. 337 SmallPtrSet<Function *, 8> &LDSAccessors = LDSToNonKernels[GV]; 338 339 // The LDS pointer which points to LDS and replaces all the uses of LDS. 
340 GlobalVariable *LDSPointer = nullptr; 341 342 // Traverse through each kernel K, check and if required, initialize the 343 // LDS pointer to point to LDS within K. 344 for (auto KI = KernelToCallees.begin(), KE = KernelToCallees.end(); KI != KE; 345 ++KI) { 346 Function *K = KI->first; 347 SmallPtrSet<Function *, 8> Callees = KI->second; 348 349 // Compute reachable and LDS used callees for kernel K. 350 set_intersect(Callees, LDSAccessors); 351 352 // None of the LDS accessing non-kernel functions are reachable from 353 // kernel K. Hence, no need to initialize LDS pointer within kernel K. 354 if (Callees.empty()) 355 continue; 356 357 // We have found reachable and LDS used callees for kernel K, and we need to 358 // initialize LDS pointer within kernel K, and we need to replace LDS use 359 // within those callees by LDS pointer. 360 // 361 // But, first check if LDS pointer is already created, if not create one. 362 LDSPointer = createLDSPointer(GV); 363 364 // Initialize LDS pointer to point to LDS within kernel K. 365 initializeLDSPointer(K, GV, LDSPointer); 366 } 367 368 // We have not found reachable and LDS used callees for any of the kernels, 369 // and hence we have not created LDS pointer. 370 if (!LDSPointer) 371 return false; 372 373 // We have created an LDS pointer for LDS, and initialized it to point-to LDS 374 // within all relevant kernels. Now replace all the uses of LDS within 375 // non-kernel functions by LDS pointer. 376 replaceLDSUseByPointer(GV, LDSPointer); 377 378 return true; 379 } 380 381 // Entry-point function which interface ReplaceLDSUseImpl with outside of the 382 // class. 383 bool ReplaceLDSUseImpl::replaceLDSUse() { 384 // Collect LDS which requires their uses to be replaced by pointer. 385 std::vector<GlobalVariable *> LDSGlobals = 386 collectLDSRequiringPointerReplace(); 387 388 // No LDS to pointer-replace. Nothing to do. 
389 if (LDSGlobals.empty()) 390 return false; 391 392 // Collect reachable callee set for each kernel defined in the module. 393 AMDGPU::collectReachableCallees(M, KernelToCallees); 394 395 if (KernelToCallees.empty()) { 396 // Either module does not have any kernel definitions, or none of the kernel 397 // has a call to non-kernel functions, or we could not resolve any of the 398 // call sites to proper non-kernel functions, because of the situations like 399 // inline asm calls. Nothing to replace. 400 return false; 401 } 402 403 // For every LDS from collected LDS globals set, replace its non-kernel 404 // function scope use by pointer. 405 bool Changed = false; 406 for (auto *GV : LDSGlobals) 407 Changed |= replaceLDSUse(GV); 408 409 return Changed; 410 } 411 412 class AMDGPUReplaceLDSUseWithPointer : public ModulePass { 413 public: 414 static char ID; 415 416 AMDGPUReplaceLDSUseWithPointer() : ModulePass(ID) { 417 initializeAMDGPUReplaceLDSUseWithPointerPass( 418 *PassRegistry::getPassRegistry()); 419 } 420 421 bool runOnModule(Module &M) override; 422 423 void getAnalysisUsage(AnalysisUsage &AU) const override { 424 AU.addRequired<TargetPassConfig>(); 425 } 426 }; 427 428 } // namespace 429 430 char AMDGPUReplaceLDSUseWithPointer::ID = 0; 431 char &llvm::AMDGPUReplaceLDSUseWithPointerID = 432 AMDGPUReplaceLDSUseWithPointer::ID; 433 434 INITIALIZE_PASS_BEGIN( 435 AMDGPUReplaceLDSUseWithPointer, DEBUG_TYPE, 436 "Replace within non-kernel function use of LDS with pointer", 437 false /*only look at the cfg*/, false /*analysis pass*/) 438 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) 439 INITIALIZE_PASS_END( 440 AMDGPUReplaceLDSUseWithPointer, DEBUG_TYPE, 441 "Replace within non-kernel function use of LDS with pointer", 442 false /*only look at the cfg*/, false /*analysis pass*/) 443 444 bool AMDGPUReplaceLDSUseWithPointer::runOnModule(Module &M) { 445 ReplaceLDSUseImpl LDSUseReplacer{M}; 446 return LDSUseReplacer.replaceLDSUse(); 447 } 448 449 ModulePass 
*llvm::createAMDGPUReplaceLDSUseWithPointerPass() { 450 return new AMDGPUReplaceLDSUseWithPointer(); 451 } 452 453 PreservedAnalyses 454 AMDGPUReplaceLDSUseWithPointerPass::run(Module &M, ModuleAnalysisManager &AM) { 455 ReplaceLDSUseImpl LDSUseReplacer{M}; 456 LDSUseReplacer.replaceLDSUse(); 457 return PreservedAnalyses::all(); 458 } 459