//===-- AMDGPUReplaceLDSUseWithPointer.cpp --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces all the uses of LDS within non-kernel functions by
// corresponding pointer counterparts.
//
// The main motivation behind this pass is to *prevent* the subsequent LDS
// lowering pass from directly packing LDS (assuming large LDS) into a struct
// type, which would otherwise cause a huge memory allocation for the struct
// instance within every kernel.
//
// A brief sketch of the algorithm implemented in this pass:
//
// 1. Collect all the LDS defined in the module which qualify for pointer
//    replacement; call this the LDSGlobals set.
//
// 2. Collect all the reachable callees for each kernel defined in the module;
//    call this the KernelToCallees map.
//
// 3. FOR (each global GV from LDSGlobals set) DO
//      LDSUsedNonKernels = Collect all non-kernel functions which use GV.
//      FOR (each kernel K in KernelToCallees map) DO
//         ReachableCallees = KernelToCallees[K]
//         ReachableAndLDSUsedCallees =
//            SetIntersect(LDSUsedNonKernels, ReachableCallees)
//         IF (ReachableAndLDSUsedCallees is not empty) THEN
//           Pointer = Create a pointer to point-to GV if not created.
//           Initialize Pointer to point-to GV within kernel K.
//         ENDIF
//      ENDFOR
//      Replace all uses of GV within non kernel functions by Pointer.
37 // ENFOR 38 // 39 // LLVM IR example: 40 // 41 // Input IR: 42 // 43 // @lds = internal addrspace(3) global [4 x i32] undef, align 16 44 // 45 // define internal void @f0() { 46 // entry: 47 // %gep = getelementptr inbounds [4 x i32], [4 x i32] addrspace(3)* @lds, 48 // i32 0, i32 0 49 // ret void 50 // } 51 // 52 // define protected amdgpu_kernel void @k0() { 53 // entry: 54 // call void @f0() 55 // ret void 56 // } 57 // 58 // Output IR: 59 // 60 // @lds = internal addrspace(3) global [4 x i32] undef, align 16 61 // @lds.ptr = internal unnamed_addr addrspace(3) global i16 undef, align 2 62 // 63 // define internal void @f0() { 64 // entry: 65 // %0 = load i16, i16 addrspace(3)* @lds.ptr, align 2 66 // %1 = getelementptr i8, i8 addrspace(3)* null, i16 %0 67 // %2 = bitcast i8 addrspace(3)* %1 to [4 x i32] addrspace(3)* 68 // %gep = getelementptr inbounds [4 x i32], [4 x i32] addrspace(3)* %2, 69 // i32 0, i32 0 70 // ret void 71 // } 72 // 73 // define protected amdgpu_kernel void @k0() { 74 // entry: 75 // store i16 ptrtoint ([4 x i32] addrspace(3)* @lds to i16), 76 // i16 addrspace(3)* @lds.ptr, align 2 77 // call void @f0() 78 // ret void 79 // } 80 // 81 //===----------------------------------------------------------------------===// 82 83 #include "AMDGPU.h" 84 #include "GCNSubtarget.h" 85 #include "Utils/AMDGPUBaseInfo.h" 86 #include "Utils/AMDGPULDSUtils.h" 87 #include "llvm/ADT/DenseMap.h" 88 #include "llvm/ADT/STLExtras.h" 89 #include "llvm/ADT/SetOperations.h" 90 #include "llvm/CodeGen/TargetPassConfig.h" 91 #include "llvm/IR/Constants.h" 92 #include "llvm/IR/DerivedTypes.h" 93 #include "llvm/IR/IRBuilder.h" 94 #include "llvm/IR/InlineAsm.h" 95 #include "llvm/IR/Instructions.h" 96 #include "llvm/IR/IntrinsicsAMDGPU.h" 97 #include "llvm/IR/ReplaceConstant.h" 98 #include "llvm/InitializePasses.h" 99 #include "llvm/Pass.h" 100 #include "llvm/Support/Debug.h" 101 #include "llvm/Target/TargetMachine.h" 102 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 
103 #include "llvm/Transforms/Utils/ModuleUtils.h" 104 #include <algorithm> 105 #include <vector> 106 107 #define DEBUG_TYPE "amdgpu-replace-lds-use-with-pointer" 108 109 using namespace llvm; 110 111 namespace { 112 113 class ReplaceLDSUseImpl { 114 Module &M; 115 LLVMContext &Ctx; 116 const DataLayout &DL; 117 Constant *LDSMemBaseAddr; 118 119 DenseMap<GlobalVariable *, GlobalVariable *> LDSToPointer; 120 DenseMap<GlobalVariable *, SmallPtrSet<Function *, 8>> LDSToNonKernels; 121 DenseMap<Function *, SmallPtrSet<Function *, 8>> KernelToCallees; 122 DenseMap<Function *, SmallPtrSet<GlobalVariable *, 8>> KernelToLDSPointers; 123 DenseMap<Function *, BasicBlock *> KernelToInitBB; 124 DenseMap<Function *, DenseMap<GlobalVariable *, Value *>> 125 FunctionToLDSToReplaceInst; 126 127 // Collect LDS which requires their uses to be replaced by pointer. 128 std::vector<GlobalVariable *> collectLDSRequiringPointerReplace() { 129 // Collect LDS which requires module lowering. 130 std::vector<GlobalVariable *> LDSGlobals = AMDGPU::findVariablesToLower(M); 131 132 // Remove LDS which don't qualify for replacement. 133 LDSGlobals.erase(std::remove_if(LDSGlobals.begin(), LDSGlobals.end(), 134 [&](GlobalVariable *GV) { 135 return shouldIgnorePointerReplacement(GV); 136 }), 137 LDSGlobals.end()); 138 139 return LDSGlobals; 140 } 141 142 // Returns true if uses of given LDS global within non-kernel functions should 143 // be keep as it is without pointer replacement. 144 bool shouldIgnorePointerReplacement(GlobalVariable *GV) { 145 // LDS whose size is very small and doesn`t exceed pointer size is not worth 146 // replacing. 147 if (DL.getTypeAllocSize(GV->getValueType()) <= 2) 148 return true; 149 150 // LDS which is not used from non-kernel function scope or it is used from 151 // global scope does not qualify for replacement. 
152 LDSToNonKernels[GV] = AMDGPU::collectNonKernelAccessorsOfLDS(GV); 153 return LDSToNonKernels[GV].empty(); 154 155 // FIXME: When GV is used within all (or within most of the kernels), then 156 // it does not make sense to create a pointer for it. 157 } 158 159 // Insert new global LDS pointer which points to LDS. 160 GlobalVariable *createLDSPointer(GlobalVariable *GV) { 161 // LDS pointer which points to LDS is already created? return it. 162 auto PointerEntry = LDSToPointer.insert(std::make_pair(GV, nullptr)); 163 if (!PointerEntry.second) 164 return PointerEntry.first->second; 165 166 // We need to create new LDS pointer which points to LDS. 167 // 168 // Each CU owns at max 64K of LDS memory, so LDS address ranges from 0 to 169 // 2^16 - 1. Hence 16 bit pointer is enough to hold the LDS address. 170 auto *I16Ty = Type::getInt16Ty(Ctx); 171 GlobalVariable *LDSPointer = new GlobalVariable( 172 M, I16Ty, false, GlobalValue::InternalLinkage, UndefValue::get(I16Ty), 173 GV->getName() + Twine(".ptr"), nullptr, GlobalVariable::NotThreadLocal, 174 AMDGPUAS::LOCAL_ADDRESS); 175 176 LDSPointer->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); 177 LDSPointer->setAlignment(AMDGPU::getAlign(DL, LDSPointer)); 178 179 // Mark that an associated LDS pointer is created for LDS. 180 LDSToPointer[GV] = LDSPointer; 181 182 return LDSPointer; 183 } 184 185 // Split entry basic block in such a way that only lane 0 of each wave does 186 // the LDS pointer initialization, and return newly created basic block. 187 BasicBlock *activateLaneZero(Function *K) { 188 // If the entry basic block of kernel K is already splitted, then return 189 // newly created basic block. 190 auto BasicBlockEntry = KernelToInitBB.insert(std::make_pair(K, nullptr)); 191 if (!BasicBlockEntry.second) 192 return BasicBlockEntry.first->second; 193 194 // Split entry basic block of kernel K just after alloca. 195 // 196 // Find the split point just after alloca. 
197 auto &EBB = K->getEntryBlock(); 198 auto *EI = &(*(EBB.getFirstInsertionPt())); 199 BasicBlock::reverse_iterator RIT(EBB.getTerminator()); 200 while (!isa<AllocaInst>(*RIT) && (&*RIT != EI)) 201 ++RIT; 202 if (isa<AllocaInst>(*RIT)) 203 --RIT; 204 205 // Split entry basic block. 206 IRBuilder<> Builder(&*RIT); 207 Value *Mbcnt = 208 Builder.CreateIntrinsic(Intrinsic::amdgcn_mbcnt_lo, {}, 209 {Builder.getInt32(-1), Builder.getInt32(0)}); 210 Value *Cond = Builder.CreateICmpEQ(Mbcnt, Builder.getInt32(0)); 211 Instruction *WB = cast<Instruction>( 212 Builder.CreateIntrinsic(Intrinsic::amdgcn_wave_barrier, {}, {})); 213 214 BasicBlock *NBB = SplitBlockAndInsertIfThen(Cond, WB, false)->getParent(); 215 216 // Mark that the entry basic block of kernel K is splitted. 217 KernelToInitBB[K] = NBB; 218 219 return NBB; 220 } 221 222 // Within given kernel, initialize given LDS pointer to point to given LDS. 223 void initializeLDSPointer(Function *K, GlobalVariable *GV, 224 GlobalVariable *LDSPointer) { 225 // If LDS pointer is already initialized within K, then nothing to do. 226 auto PointerEntry = KernelToLDSPointers.insert( 227 std::make_pair(K, SmallPtrSet<GlobalVariable *, 8>())); 228 if (!PointerEntry.second) 229 if (PointerEntry.first->second.contains(LDSPointer)) 230 return; 231 232 // Insert instructions at EI which initialize LDS pointer to point-to LDS 233 // within kernel K. 234 // 235 // That is, convert pointer type of GV to i16, and then store this converted 236 // i16 value within LDSPointer which is of type i16*. 237 auto *EI = &(*(activateLaneZero(K)->getFirstInsertionPt())); 238 IRBuilder<> Builder(EI); 239 Builder.CreateStore(Builder.CreatePtrToInt(GV, Type::getInt16Ty(Ctx)), 240 LDSPointer); 241 242 // Mark that LDS pointer is initialized within kernel K. 243 KernelToLDSPointers[K].insert(LDSPointer); 244 } 245 246 // We have created an LDS pointer for LDS, and initialized it to point-to LDS 247 // within all relevent kernels. 
Now replace all the uses of LDS within 248 // non-kernel functions by LDS pointer. 249 void replaceLDSUseByPointer(GlobalVariable *GV, GlobalVariable *LDSPointer) { 250 SmallVector<User *, 8> LDSUsers(GV->users()); 251 for (auto *U : LDSUsers) { 252 // When `U` is a constant expression, it is possible that same constant 253 // expression exists within multiple instructions, and within multiple 254 // non-kernel functions. Collect all those non-kernel functions and all 255 // those instructions within which `U` exist. 256 auto FunctionToInsts = 257 AMDGPU::getFunctionToInstsMap(U, false /*=CollectKernelInsts*/); 258 259 for (auto FI = FunctionToInsts.begin(), FE = FunctionToInsts.end(); 260 FI != FE; ++FI) { 261 Function *F = FI->first; 262 auto &Insts = FI->second; 263 for (auto *I : Insts) { 264 // If `U` is a constant expression, then we need to break the 265 // associated instruction into a set of separate instructions by 266 // converting constant expressions into instructions. 267 SmallPtrSet<Instruction *, 8> UserInsts; 268 269 if (U == I) { 270 // `U` is an instruction, conversion from constant expression to 271 // set of instructions is *not* required. 272 UserInsts.insert(I); 273 } else { 274 // `U` is a constant expression, convert it into corresponding set 275 // of instructions. 276 auto *CE = cast<ConstantExpr>(U); 277 convertConstantExprsToInstructions(I, CE, &UserInsts); 278 } 279 280 // Go through all the user instrutions, if LDS exist within them as an 281 // operand, then replace it by replace instruction. 282 for (auto *II : UserInsts) { 283 auto *ReplaceInst = getReplacementInst(F, GV, LDSPointer); 284 II->replaceUsesOfWith(GV, ReplaceInst); 285 } 286 } 287 } 288 } 289 } 290 291 // Create a set of replacement instructions which together replace LDS within 292 // non-kernel function F by accessing LDS indirectly using LDS pointer. 
293 Value *getReplacementInst(Function *F, GlobalVariable *GV, 294 GlobalVariable *LDSPointer) { 295 // If the instruction which replaces LDS within F is already created, then 296 // return it. 297 auto LDSEntry = FunctionToLDSToReplaceInst.insert( 298 std::make_pair(F, DenseMap<GlobalVariable *, Value *>())); 299 if (!LDSEntry.second) { 300 auto ReplaceInstEntry = 301 LDSEntry.first->second.insert(std::make_pair(GV, nullptr)); 302 if (!ReplaceInstEntry.second) 303 return ReplaceInstEntry.first->second; 304 } 305 306 // Get the instruction insertion point within the beginning of the entry 307 // block of current non-kernel function. 308 auto *EI = &(*(F->getEntryBlock().getFirstInsertionPt())); 309 IRBuilder<> Builder(EI); 310 311 // Insert required set of instructions which replace LDS within F. 312 auto *V = Builder.CreateBitCast( 313 Builder.CreateGEP( 314 Builder.getInt8Ty(), LDSMemBaseAddr, 315 Builder.CreateLoad(LDSPointer->getValueType(), LDSPointer)), 316 GV->getType()); 317 318 // Mark that the replacement instruction which replace LDS within F is 319 // created. 320 FunctionToLDSToReplaceInst[F][GV] = V; 321 322 return V; 323 } 324 325 public: 326 ReplaceLDSUseImpl(Module &M) 327 : M(M), Ctx(M.getContext()), DL(M.getDataLayout()) { 328 LDSMemBaseAddr = Constant::getIntegerValue( 329 PointerType::get(Type::getInt8Ty(M.getContext()), 330 AMDGPUAS::LOCAL_ADDRESS), 331 APInt(32, 0)); 332 } 333 334 // Entry-point function which interface ReplaceLDSUseImpl with outside of the 335 // class. 336 bool replaceLDSUse(); 337 338 private: 339 // For a given LDS from collected LDS globals set, replace its non-kernel 340 // function scope uses by pointer. 341 bool replaceLDSUse(GlobalVariable *GV); 342 }; 343 344 // For given LDS from collected LDS globals set, replace its non-kernel function 345 // scope uses by pointer. 346 bool ReplaceLDSUseImpl::replaceLDSUse(GlobalVariable *GV) { 347 // Holds all those non-kernel functions within which LDS is being accessed. 
348 SmallPtrSet<Function *, 8> &LDSAccessors = LDSToNonKernels[GV]; 349 350 // The LDS pointer which points to LDS and replaces all the uses of LDS. 351 GlobalVariable *LDSPointer = nullptr; 352 353 // Traverse through each kernel K, check and if required, initialize the 354 // LDS pointer to point to LDS within K. 355 for (auto KI = KernelToCallees.begin(), KE = KernelToCallees.end(); KI != KE; 356 ++KI) { 357 Function *K = KI->first; 358 SmallPtrSet<Function *, 8> Callees = KI->second; 359 360 // Compute reachable and LDS used callees for kernel K. 361 set_intersect(Callees, LDSAccessors); 362 363 // None of the LDS accessing non-kernel functions are reachable from 364 // kernel K. Hence, no need to initialize LDS pointer within kernel K. 365 if (Callees.empty()) 366 continue; 367 368 // We have found reachable and LDS used callees for kernel K, and we need to 369 // initialize LDS pointer within kernel K, and we need to replace LDS use 370 // within those callees by LDS pointer. 371 // 372 // But, first check if LDS pointer is already created, if not create one. 373 LDSPointer = createLDSPointer(GV); 374 375 // Initialize LDS pointer to point to LDS within kernel K. 376 initializeLDSPointer(K, GV, LDSPointer); 377 } 378 379 // We have not found reachable and LDS used callees for any of the kernels, 380 // and hence we have not created LDS pointer. 381 if (!LDSPointer) 382 return false; 383 384 // We have created an LDS pointer for LDS, and initialized it to point-to LDS 385 // within all relevent kernels. Now replace all the uses of LDS within 386 // non-kernel functions by LDS pointer. 387 replaceLDSUseByPointer(GV, LDSPointer); 388 389 return true; 390 } 391 392 // Entry-point function which interface ReplaceLDSUseImpl with outside of the 393 // class. 394 bool ReplaceLDSUseImpl::replaceLDSUse() { 395 // Collect LDS which requires their uses to be replaced by pointer. 
396 std::vector<GlobalVariable *> LDSGlobals = 397 collectLDSRequiringPointerReplace(); 398 399 // No LDS to pointer-replace. Nothing to do. 400 if (LDSGlobals.empty()) 401 return false; 402 403 // Collect reachable callee set for each kernel defined in the module. 404 AMDGPU::collectReachableCallees(M, KernelToCallees); 405 406 if (KernelToCallees.empty()) { 407 // Either module does not have any kernel definitions, or none of the kernel 408 // has a call to non-kernel functions, or we could not resolve any of the 409 // call sites to proper non-kernel functions, because of the situations like 410 // inline asm calls. Nothing to replace. 411 return false; 412 } 413 414 // For every LDS from collected LDS globals set, replace its non-kernel 415 // function scope use by pointer. 416 bool Changed = false; 417 for (auto *GV : LDSGlobals) 418 Changed |= replaceLDSUse(GV); 419 420 return Changed; 421 } 422 423 class AMDGPUReplaceLDSUseWithPointer : public ModulePass { 424 public: 425 static char ID; 426 427 AMDGPUReplaceLDSUseWithPointer() : ModulePass(ID) { 428 initializeAMDGPUReplaceLDSUseWithPointerPass( 429 *PassRegistry::getPassRegistry()); 430 } 431 432 bool runOnModule(Module &M) override; 433 434 void getAnalysisUsage(AnalysisUsage &AU) const override { 435 AU.addRequired<TargetPassConfig>(); 436 } 437 }; 438 439 } // namespace 440 441 char AMDGPUReplaceLDSUseWithPointer::ID = 0; 442 char &llvm::AMDGPUReplaceLDSUseWithPointerID = 443 AMDGPUReplaceLDSUseWithPointer::ID; 444 445 INITIALIZE_PASS_BEGIN( 446 AMDGPUReplaceLDSUseWithPointer, DEBUG_TYPE, 447 "Replace within non-kernel function use of LDS with pointer", 448 false /*only look at the cfg*/, false /*analysis pass*/) 449 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) 450 INITIALIZE_PASS_END( 451 AMDGPUReplaceLDSUseWithPointer, DEBUG_TYPE, 452 "Replace within non-kernel function use of LDS with pointer", 453 false /*only look at the cfg*/, false /*analysis pass*/) 454 455 bool 
AMDGPUReplaceLDSUseWithPointer::runOnModule(Module &M) { 456 ReplaceLDSUseImpl LDSUseReplacer{M}; 457 return LDSUseReplacer.replaceLDSUse(); 458 } 459 460 ModulePass *llvm::createAMDGPUReplaceLDSUseWithPointerPass() { 461 return new AMDGPUReplaceLDSUseWithPointer(); 462 } 463 464 PreservedAnalyses 465 AMDGPUReplaceLDSUseWithPointerPass::run(Module &M, ModuleAnalysisManager &AM) { 466 ReplaceLDSUseImpl LDSUseReplacer{M}; 467 LDSUseReplacer.replaceLDSUse(); 468 return PreservedAnalyses::all(); 469 } 470