1 //===-- AMDGPUReplaceLDSUseWithPointer.cpp --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This pass replaces all the uses of LDS within non-kernel functions with
// their corresponding pointer counterparts.
11 //
// The main motivation behind this pass is to *avoid* having the subsequent LDS
// lowering pass directly pack LDS (assume large LDS) into a struct type, which
// would otherwise cause a huge amount of memory to be allocated for the struct
// instance within every kernel.
16 //
17 // Brief sketch of the algorithm implemented in this pass is as below:
18 //
19 //   1. Collect all the LDS defined in the module which qualify for pointer
20 //      replacement, say it is, LDSGlobals set.
21 //
22 //   2. Collect all the reachable callees for each kernel defined in the module,
23 //      say it is, KernelToCallees map.
24 //
25 //   3. FOR (each global GV from LDSGlobals set) DO
26 //        LDSUsedNonKernels = Collect all non-kernel functions which use GV.
27 //        FOR (each kernel K in KernelToCallees map) DO
28 //           ReachableCallees = KernelToCallees[K]
29 //           ReachableAndLDSUsedCallees =
30 //              SetIntersect(LDSUsedNonKernels, ReachableCallees)
31 //           IF (ReachableAndLDSUsedCallees is not empty) THEN
32 //             Pointer = Create a pointer to point-to GV if not created.
33 //             Initialize Pointer to point-to GV within kernel K.
34 //           ENDIF
35 //        ENDFOR
36 //        Replace all uses of GV within non kernel functions by Pointer.
//      ENDFOR
38 //
39 // LLVM IR example:
40 //
41 //    Input IR:
42 //
43 //    @lds = internal addrspace(3) global [4 x i32] undef, align 16
44 //
45 //    define internal void @f0() {
46 //    entry:
47 //      %gep = getelementptr inbounds [4 x i32], [4 x i32] addrspace(3)* @lds,
48 //             i32 0, i32 0
49 //      ret void
50 //    }
51 //
52 //    define protected amdgpu_kernel void @k0() {
53 //    entry:
54 //      call void @f0()
55 //      ret void
56 //    }
57 //
58 //    Output IR:
59 //
60 //    @lds = internal addrspace(3) global [4 x i32] undef, align 16
61 //    @lds.ptr = internal unnamed_addr addrspace(3) global i16 undef, align 2
62 //
63 //    define internal void @f0() {
64 //    entry:
65 //      %0 = load i16, i16 addrspace(3)* @lds.ptr, align 2
66 //      %1 = getelementptr i8, i8 addrspace(3)* null, i16 %0
67 //      %2 = bitcast i8 addrspace(3)* %1 to [4 x i32] addrspace(3)*
68 //      %gep = getelementptr inbounds [4 x i32], [4 x i32] addrspace(3)* %2,
69 //             i32 0, i32 0
70 //      ret void
71 //    }
72 //
73 //    define protected amdgpu_kernel void @k0() {
74 //    entry:
75 //      store i16 ptrtoint ([4 x i32] addrspace(3)* @lds to i16),
76 //            i16 addrspace(3)* @lds.ptr, align 2
77 //      call void @f0()
78 //      ret void
79 //    }
80 //
81 //===----------------------------------------------------------------------===//
82 
83 #include "AMDGPU.h"
84 #include "GCNSubtarget.h"
85 #include "Utils/AMDGPUBaseInfo.h"
86 #include "Utils/AMDGPULDSUtils.h"
87 #include "llvm/ADT/DenseMap.h"
88 #include "llvm/ADT/STLExtras.h"
89 #include "llvm/ADT/SetOperations.h"
90 #include "llvm/CodeGen/TargetPassConfig.h"
91 #include "llvm/IR/Constants.h"
92 #include "llvm/IR/DerivedTypes.h"
93 #include "llvm/IR/IRBuilder.h"
94 #include "llvm/IR/InlineAsm.h"
95 #include "llvm/IR/Instructions.h"
96 #include "llvm/IR/IntrinsicsAMDGPU.h"
97 #include "llvm/IR/ReplaceConstant.h"
98 #include "llvm/InitializePasses.h"
99 #include "llvm/Pass.h"
100 #include "llvm/Support/Debug.h"
101 #include "llvm/Target/TargetMachine.h"
102 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
103 #include "llvm/Transforms/Utils/ModuleUtils.h"
104 #include <algorithm>
105 #include <vector>
106 
107 #define DEBUG_TYPE "amdgpu-replace-lds-use-with-pointer"
108 
109 using namespace llvm;
110 
111 namespace {
112 
113 class ReplaceLDSUseImpl {
114   Module &M;
115   LLVMContext &Ctx;
116   const DataLayout &DL;
117   Constant *LDSMemBaseAddr;
118 
119   DenseMap<GlobalVariable *, GlobalVariable *> LDSToPointer;
120   DenseMap<GlobalVariable *, SmallPtrSet<Function *, 8>> LDSToNonKernels;
121   DenseMap<Function *, SmallPtrSet<Function *, 8>> KernelToCallees;
122   DenseMap<Function *, SmallPtrSet<GlobalVariable *, 8>> KernelToLDSPointers;
123   DenseMap<Function *, BasicBlock *> KernelToInitBB;
124   DenseMap<Function *, DenseMap<GlobalVariable *, Value *>>
125       FunctionToLDSToReplaceInst;
126 
127   // Collect LDS which requires their uses to be replaced by pointer.
128   std::vector<GlobalVariable *> collectLDSRequiringPointerReplace() {
129     // Collect LDS which requires module lowering.
130     std::vector<GlobalVariable *> LDSGlobals = AMDGPU::findVariablesToLower(M);
131 
132     // Remove LDS which don't qualify for replacement.
133     LDSGlobals.erase(std::remove_if(LDSGlobals.begin(), LDSGlobals.end(),
134                                     [&](GlobalVariable *GV) {
135                                       return shouldIgnorePointerReplacement(GV);
136                                     }),
137                      LDSGlobals.end());
138 
139     return LDSGlobals;
140   }
141 
142   // Returns true if uses of given LDS global within non-kernel functions should
143   // be keep as it is without pointer replacement.
144   bool shouldIgnorePointerReplacement(GlobalVariable *GV) {
145     // LDS whose size is very small and doesn`t exceed pointer size is not worth
146     // replacing.
147     if (DL.getTypeAllocSize(GV->getValueType()) <= 2)
148       return true;
149 
150     // LDS which is not used from non-kernel function scope or it is used from
151     // global scope does not qualify for replacement.
152     LDSToNonKernels[GV] = AMDGPU::collectNonKernelAccessorsOfLDS(GV);
153     return LDSToNonKernels[GV].empty();
154 
155     // FIXME: When GV is used within all (or within most of the kernels), then
156     // it does not make sense to create a pointer for it.
157   }
158 
159   // Insert new global LDS pointer which points to LDS.
160   GlobalVariable *createLDSPointer(GlobalVariable *GV) {
161     // LDS pointer which points to LDS is already created? return it.
162     auto PointerEntry = LDSToPointer.insert(std::make_pair(GV, nullptr));
163     if (!PointerEntry.second)
164       return PointerEntry.first->second;
165 
166     // We need to create new LDS pointer which points to LDS.
167     //
168     // Each CU owns at max 64K of LDS memory, so LDS address ranges from 0 to
169     // 2^16 - 1. Hence 16 bit pointer is enough to hold the LDS address.
170     auto *I16Ty = Type::getInt16Ty(Ctx);
171     GlobalVariable *LDSPointer = new GlobalVariable(
172         M, I16Ty, false, GlobalValue::InternalLinkage, UndefValue::get(I16Ty),
173         GV->getName() + Twine(".ptr"), nullptr, GlobalVariable::NotThreadLocal,
174         AMDGPUAS::LOCAL_ADDRESS);
175 
176     LDSPointer->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
177     LDSPointer->setAlignment(AMDGPU::getAlign(DL, LDSPointer));
178 
179     // Mark that an associated LDS pointer is created for LDS.
180     LDSToPointer[GV] = LDSPointer;
181 
182     return LDSPointer;
183   }
184 
185   // Split entry basic block in such a way that only lane 0 of each wave does
186   // the LDS pointer initialization, and return newly created basic block.
187   BasicBlock *activateLaneZero(Function *K) {
188     // If the entry basic block of kernel K is already splitted, then return
189     // newly created basic block.
190     auto BasicBlockEntry = KernelToInitBB.insert(std::make_pair(K, nullptr));
191     if (!BasicBlockEntry.second)
192       return BasicBlockEntry.first->second;
193 
194     // Split entry basic block of kernel K just after alloca.
195     //
196     // Find the split point just after alloca.
197     auto &EBB = K->getEntryBlock();
198     auto *EI = &(*(EBB.getFirstInsertionPt()));
199     BasicBlock::reverse_iterator RIT(EBB.getTerminator());
200     while (!isa<AllocaInst>(*RIT) && (&*RIT != EI))
201       ++RIT;
202     if (isa<AllocaInst>(*RIT))
203       --RIT;
204 
205     // Split entry basic block.
206     IRBuilder<> Builder(&*RIT);
207     Value *Mbcnt =
208         Builder.CreateIntrinsic(Intrinsic::amdgcn_mbcnt_lo, {},
209                                 {Builder.getInt32(-1), Builder.getInt32(0)});
210     Value *Cond = Builder.CreateICmpEQ(Mbcnt, Builder.getInt32(0));
211     Instruction *WB = cast<Instruction>(
212         Builder.CreateIntrinsic(Intrinsic::amdgcn_wave_barrier, {}, {}));
213 
214     BasicBlock *NBB = SplitBlockAndInsertIfThen(Cond, WB, false)->getParent();
215 
216     // Mark that the entry basic block of kernel K is splitted.
217     KernelToInitBB[K] = NBB;
218 
219     return NBB;
220   }
221 
222   // Within given kernel, initialize given LDS pointer to point to given LDS.
223   void initializeLDSPointer(Function *K, GlobalVariable *GV,
224                             GlobalVariable *LDSPointer) {
225     // If LDS pointer is already initialized within K, then nothing to do.
226     auto PointerEntry = KernelToLDSPointers.insert(
227         std::make_pair(K, SmallPtrSet<GlobalVariable *, 8>()));
228     if (!PointerEntry.second)
229       if (PointerEntry.first->second.contains(LDSPointer))
230         return;
231 
232     // Insert instructions at EI which initialize LDS pointer to point-to LDS
233     // within kernel K.
234     //
235     // That is, convert pointer type of GV to i16, and then store this converted
236     // i16 value within LDSPointer which is of type i16*.
237     auto *EI = &(*(activateLaneZero(K)->getFirstInsertionPt()));
238     IRBuilder<> Builder(EI);
239     Builder.CreateStore(Builder.CreatePtrToInt(GV, Type::getInt16Ty(Ctx)),
240                         LDSPointer);
241 
242     // Mark that LDS pointer is initialized within kernel K.
243     KernelToLDSPointers[K].insert(LDSPointer);
244   }
245 
246   // We have created an LDS pointer for LDS, and initialized it to point-to LDS
247   // within all relevent kernels. Now replace all the uses of LDS within
248   // non-kernel functions by LDS pointer.
249   void replaceLDSUseByPointer(GlobalVariable *GV, GlobalVariable *LDSPointer) {
250     SmallVector<User *, 8> LDSUsers(GV->users());
251     for (auto *U : LDSUsers) {
252       // When `U` is a constant expression, it is possible that same constant
253       // expression exists within multiple instructions, and within multiple
254       // non-kernel functions. Collect all those non-kernel functions and all
255       // those instructions within which `U` exist.
256       auto FunctionToInsts =
257           AMDGPU::getFunctionToInstsMap(U, false /*=CollectKernelInsts*/);
258 
259       for (auto FI = FunctionToInsts.begin(), FE = FunctionToInsts.end();
260            FI != FE; ++FI) {
261         Function *F = FI->first;
262         auto &Insts = FI->second;
263         for (auto *I : Insts) {
264           // If `U` is a constant expression, then we need to break the
265           // associated instruction into a set of separate instructions by
266           // converting constant expressions into instructions.
267           SmallPtrSet<Instruction *, 8> UserInsts;
268 
269           if (U == I) {
270             // `U` is an instruction, conversion from constant expression to
271             // set of instructions is *not* required.
272             UserInsts.insert(I);
273           } else {
274             // `U` is a constant expression, convert it into corresponding set
275             // of instructions.
276             auto *CE = cast<ConstantExpr>(U);
277             convertConstantExprsToInstructions(I, CE, &UserInsts);
278           }
279 
280           // Go through all the user instrutions, if LDS exist within them as an
281           // operand, then replace it by replace instruction.
282           for (auto *II : UserInsts) {
283             auto *ReplaceInst = getReplacementInst(F, GV, LDSPointer);
284             II->replaceUsesOfWith(GV, ReplaceInst);
285           }
286         }
287       }
288     }
289   }
290 
291   // Create a set of replacement instructions which together replace LDS within
292   // non-kernel function F by accessing LDS indirectly using LDS pointer.
293   Value *getReplacementInst(Function *F, GlobalVariable *GV,
294                             GlobalVariable *LDSPointer) {
295     // If the instruction which replaces LDS within F is already created, then
296     // return it.
297     auto LDSEntry = FunctionToLDSToReplaceInst.insert(
298         std::make_pair(F, DenseMap<GlobalVariable *, Value *>()));
299     if (!LDSEntry.second) {
300       auto ReplaceInstEntry =
301           LDSEntry.first->second.insert(std::make_pair(GV, nullptr));
302       if (!ReplaceInstEntry.second)
303         return ReplaceInstEntry.first->second;
304     }
305 
306     // Get the instruction insertion point within the beginning of the entry
307     // block of current non-kernel function.
308     auto *EI = &(*(F->getEntryBlock().getFirstInsertionPt()));
309     IRBuilder<> Builder(EI);
310 
311     // Insert required set of instructions which replace LDS within F.
312     auto *V = Builder.CreateBitCast(
313         Builder.CreateGEP(
314             Builder.getInt8Ty(), LDSMemBaseAddr,
315             Builder.CreateLoad(LDSPointer->getValueType(), LDSPointer)),
316         GV->getType());
317 
318     // Mark that the replacement instruction which replace LDS within F is
319     // created.
320     FunctionToLDSToReplaceInst[F][GV] = V;
321 
322     return V;
323   }
324 
325 public:
326   ReplaceLDSUseImpl(Module &M)
327       : M(M), Ctx(M.getContext()), DL(M.getDataLayout()) {
328     LDSMemBaseAddr = Constant::getIntegerValue(
329         PointerType::get(Type::getInt8Ty(M.getContext()),
330                          AMDGPUAS::LOCAL_ADDRESS),
331         APInt(32, 0));
332   }
333 
334   // Entry-point function which interface ReplaceLDSUseImpl with outside of the
335   // class.
336   bool replaceLDSUse();
337 
338 private:
339   // For a given LDS from collected LDS globals set, replace its non-kernel
340   // function scope uses by pointer.
341   bool replaceLDSUse(GlobalVariable *GV);
342 };
343 
344 // For given LDS from collected LDS globals set, replace its non-kernel function
345 // scope uses by pointer.
346 bool ReplaceLDSUseImpl::replaceLDSUse(GlobalVariable *GV) {
347   // Holds all those non-kernel functions within which LDS is being accessed.
348   SmallPtrSet<Function *, 8> &LDSAccessors = LDSToNonKernels[GV];
349 
350   // The LDS pointer which points to LDS and replaces all the uses of LDS.
351   GlobalVariable *LDSPointer = nullptr;
352 
353   // Traverse through each kernel K, check and if required, initialize the
354   // LDS pointer to point to LDS within K.
355   for (auto KI = KernelToCallees.begin(), KE = KernelToCallees.end(); KI != KE;
356        ++KI) {
357     Function *K = KI->first;
358     SmallPtrSet<Function *, 8> Callees = KI->second;
359 
360     // Compute reachable and LDS used callees for kernel K.
361     set_intersect(Callees, LDSAccessors);
362 
363     // None of the LDS accessing non-kernel functions are reachable from
364     // kernel K. Hence, no need to initialize LDS pointer within kernel K.
365     if (Callees.empty())
366       continue;
367 
368     // We have found reachable and LDS used callees for kernel K, and we need to
369     // initialize LDS pointer within kernel K, and we need to replace LDS use
370     // within those callees by LDS pointer.
371     //
372     // But, first check if LDS pointer is already created, if not create one.
373     LDSPointer = createLDSPointer(GV);
374 
375     // Initialize LDS pointer to point to LDS within kernel K.
376     initializeLDSPointer(K, GV, LDSPointer);
377   }
378 
379   // We have not found reachable and LDS used callees for any of the kernels,
380   // and hence we have not created LDS pointer.
381   if (!LDSPointer)
382     return false;
383 
384   // We have created an LDS pointer for LDS, and initialized it to point-to LDS
385   // within all relevent kernels. Now replace all the uses of LDS within
386   // non-kernel functions by LDS pointer.
387   replaceLDSUseByPointer(GV, LDSPointer);
388 
389   return true;
390 }
391 
392 // Entry-point function which interface ReplaceLDSUseImpl with outside of the
393 // class.
394 bool ReplaceLDSUseImpl::replaceLDSUse() {
395   // Collect LDS which requires their uses to be replaced by pointer.
396   std::vector<GlobalVariable *> LDSGlobals =
397       collectLDSRequiringPointerReplace();
398 
399   // No LDS to pointer-replace. Nothing to do.
400   if (LDSGlobals.empty())
401     return false;
402 
403   // Collect reachable callee set for each kernel defined in the module.
404   AMDGPU::collectReachableCallees(M, KernelToCallees);
405 
406   if (KernelToCallees.empty()) {
407     // Either module does not have any kernel definitions, or none of the kernel
408     // has a call to non-kernel functions, or we could not resolve any of the
409     // call sites to proper non-kernel functions, because of the situations like
410     // inline asm calls. Nothing to replace.
411     return false;
412   }
413 
414   // For every LDS from collected LDS globals set, replace its non-kernel
415   // function scope use by pointer.
416   bool Changed = false;
417   for (auto *GV : LDSGlobals)
418     Changed |= replaceLDSUse(GV);
419 
420   return Changed;
421 }
422 
423 class AMDGPUReplaceLDSUseWithPointer : public ModulePass {
424 public:
425   static char ID;
426 
427   AMDGPUReplaceLDSUseWithPointer() : ModulePass(ID) {
428     initializeAMDGPUReplaceLDSUseWithPointerPass(
429         *PassRegistry::getPassRegistry());
430   }
431 
432   bool runOnModule(Module &M) override;
433 
434   void getAnalysisUsage(AnalysisUsage &AU) const override {
435     AU.addRequired<TargetPassConfig>();
436   }
437 };
438 
439 } // namespace
440 
441 char AMDGPUReplaceLDSUseWithPointer::ID = 0;
442 char &llvm::AMDGPUReplaceLDSUseWithPointerID =
443     AMDGPUReplaceLDSUseWithPointer::ID;
444 
445 INITIALIZE_PASS_BEGIN(
446     AMDGPUReplaceLDSUseWithPointer, DEBUG_TYPE,
447     "Replace within non-kernel function use of LDS with pointer",
448     false /*only look at the cfg*/, false /*analysis pass*/)
449 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
450 INITIALIZE_PASS_END(
451     AMDGPUReplaceLDSUseWithPointer, DEBUG_TYPE,
452     "Replace within non-kernel function use of LDS with pointer",
453     false /*only look at the cfg*/, false /*analysis pass*/)
454 
455 bool AMDGPUReplaceLDSUseWithPointer::runOnModule(Module &M) {
456   ReplaceLDSUseImpl LDSUseReplacer{M};
457   return LDSUseReplacer.replaceLDSUse();
458 }
459 
460 ModulePass *llvm::createAMDGPUReplaceLDSUseWithPointerPass() {
461   return new AMDGPUReplaceLDSUseWithPointer();
462 }
463 
464 PreservedAnalyses
465 AMDGPUReplaceLDSUseWithPointerPass::run(Module &M, ModuleAnalysisManager &AM) {
466   ReplaceLDSUseImpl LDSUseReplacer{M};
467   LDSUseReplacer.replaceLDSUse();
468   return PreservedAnalyses::all();
469 }
470