1*9cf995beSStanislav Mekhanoshin //===-- AMDGPUPromoteKernelArguments.cpp ----------------------------------===//
2*9cf995beSStanislav Mekhanoshin //
3*9cf995beSStanislav Mekhanoshin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*9cf995beSStanislav Mekhanoshin // See https://llvm.org/LICENSE.txt for license information.
5*9cf995beSStanislav Mekhanoshin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*9cf995beSStanislav Mekhanoshin //
7*9cf995beSStanislav Mekhanoshin //===----------------------------------------------------------------------===//
8*9cf995beSStanislav Mekhanoshin //
9*9cf995beSStanislav Mekhanoshin /// \file This pass recursively promotes generic pointer arguments of a kernel
10*9cf995beSStanislav Mekhanoshin /// into the global address space.
11*9cf995beSStanislav Mekhanoshin ///
12*9cf995beSStanislav Mekhanoshin /// The pass walks kernel's pointer arguments, then loads from them. If a loaded
13*9cf995beSStanislav Mekhanoshin /// value is a pointer and loaded pointer is unmodified in the kernel before the
14*9cf995beSStanislav Mekhanoshin /// load, then promote loaded pointer to global. Then recursively continue.
15*9cf995beSStanislav Mekhanoshin //
16*9cf995beSStanislav Mekhanoshin //===----------------------------------------------------------------------===//
17*9cf995beSStanislav Mekhanoshin 
18*9cf995beSStanislav Mekhanoshin #include "AMDGPU.h"
19*9cf995beSStanislav Mekhanoshin #include "llvm/ADT/SmallVector.h"
20*9cf995beSStanislav Mekhanoshin #include "llvm/Analysis/MemorySSA.h"
21*9cf995beSStanislav Mekhanoshin #include "llvm/IR/IRBuilder.h"
22*9cf995beSStanislav Mekhanoshin #include "llvm/InitializePasses.h"
23*9cf995beSStanislav Mekhanoshin 
24*9cf995beSStanislav Mekhanoshin #define DEBUG_TYPE "amdgpu-promote-kernel-arguments"
25*9cf995beSStanislav Mekhanoshin 
26*9cf995beSStanislav Mekhanoshin using namespace llvm;
27*9cf995beSStanislav Mekhanoshin 
28*9cf995beSStanislav Mekhanoshin namespace {
29*9cf995beSStanislav Mekhanoshin 
30*9cf995beSStanislav Mekhanoshin class AMDGPUPromoteKernelArguments : public FunctionPass {
31*9cf995beSStanislav Mekhanoshin   MemorySSA *MSSA;
32*9cf995beSStanislav Mekhanoshin 
33*9cf995beSStanislav Mekhanoshin   Instruction *ArgCastInsertPt;
34*9cf995beSStanislav Mekhanoshin 
35*9cf995beSStanislav Mekhanoshin   SmallVector<Value *> Ptrs;
36*9cf995beSStanislav Mekhanoshin 
37*9cf995beSStanislav Mekhanoshin   void enqueueUsers(Value *Ptr);
38*9cf995beSStanislav Mekhanoshin 
39*9cf995beSStanislav Mekhanoshin   bool promotePointer(Value *Ptr);
40*9cf995beSStanislav Mekhanoshin 
41*9cf995beSStanislav Mekhanoshin public:
42*9cf995beSStanislav Mekhanoshin   static char ID;
43*9cf995beSStanislav Mekhanoshin 
44*9cf995beSStanislav Mekhanoshin   AMDGPUPromoteKernelArguments() : FunctionPass(ID) {}
45*9cf995beSStanislav Mekhanoshin 
46*9cf995beSStanislav Mekhanoshin   bool run(Function &F, MemorySSA &MSSA);
47*9cf995beSStanislav Mekhanoshin 
48*9cf995beSStanislav Mekhanoshin   bool runOnFunction(Function &F) override;
49*9cf995beSStanislav Mekhanoshin 
50*9cf995beSStanislav Mekhanoshin   void getAnalysisUsage(AnalysisUsage &AU) const override {
51*9cf995beSStanislav Mekhanoshin     AU.addRequired<MemorySSAWrapperPass>();
52*9cf995beSStanislav Mekhanoshin     AU.setPreservesAll();
53*9cf995beSStanislav Mekhanoshin   }
54*9cf995beSStanislav Mekhanoshin };
55*9cf995beSStanislav Mekhanoshin 
56*9cf995beSStanislav Mekhanoshin } // end anonymous namespace
57*9cf995beSStanislav Mekhanoshin 
58*9cf995beSStanislav Mekhanoshin void AMDGPUPromoteKernelArguments::enqueueUsers(Value *Ptr) {
59*9cf995beSStanislav Mekhanoshin   SmallVector<User *> PtrUsers(Ptr->users());
60*9cf995beSStanislav Mekhanoshin 
61*9cf995beSStanislav Mekhanoshin   while (!PtrUsers.empty()) {
62*9cf995beSStanislav Mekhanoshin     Instruction *U = dyn_cast<Instruction>(PtrUsers.pop_back_val());
63*9cf995beSStanislav Mekhanoshin     if (!U)
64*9cf995beSStanislav Mekhanoshin       continue;
65*9cf995beSStanislav Mekhanoshin 
66*9cf995beSStanislav Mekhanoshin     switch (U->getOpcode()) {
67*9cf995beSStanislav Mekhanoshin     default:
68*9cf995beSStanislav Mekhanoshin       break;
69*9cf995beSStanislav Mekhanoshin     case Instruction::Load: {
70*9cf995beSStanislav Mekhanoshin       LoadInst *LD = cast<LoadInst>(U);
71*9cf995beSStanislav Mekhanoshin       PointerType *PT = dyn_cast<PointerType>(LD->getType());
72*9cf995beSStanislav Mekhanoshin       if (!PT ||
73*9cf995beSStanislav Mekhanoshin           (PT->getAddressSpace() != AMDGPUAS::FLAT_ADDRESS &&
74*9cf995beSStanislav Mekhanoshin            PT->getAddressSpace() != AMDGPUAS::GLOBAL_ADDRESS &&
75*9cf995beSStanislav Mekhanoshin            PT->getAddressSpace() != AMDGPUAS::CONSTANT_ADDRESS) ||
76*9cf995beSStanislav Mekhanoshin           LD->getPointerOperand()->stripInBoundsOffsets() != Ptr)
77*9cf995beSStanislav Mekhanoshin         break;
78*9cf995beSStanislav Mekhanoshin       const MemoryAccess *MA = MSSA->getWalker()->getClobberingMemoryAccess(LD);
79*9cf995beSStanislav Mekhanoshin       // TODO: This load poprobably can be promoted to constant address space.
80*9cf995beSStanislav Mekhanoshin       if (MSSA->isLiveOnEntryDef(MA))
81*9cf995beSStanislav Mekhanoshin         Ptrs.push_back(LD);
82*9cf995beSStanislav Mekhanoshin       break;
83*9cf995beSStanislav Mekhanoshin     }
84*9cf995beSStanislav Mekhanoshin     case Instruction::GetElementPtr:
85*9cf995beSStanislav Mekhanoshin     case Instruction::AddrSpaceCast:
86*9cf995beSStanislav Mekhanoshin     case Instruction::BitCast:
87*9cf995beSStanislav Mekhanoshin       if (U->getOperand(0)->stripInBoundsOffsets() == Ptr)
88*9cf995beSStanislav Mekhanoshin         PtrUsers.append(U->user_begin(), U->user_end());
89*9cf995beSStanislav Mekhanoshin       break;
90*9cf995beSStanislav Mekhanoshin     }
91*9cf995beSStanislav Mekhanoshin   }
92*9cf995beSStanislav Mekhanoshin }
93*9cf995beSStanislav Mekhanoshin 
94*9cf995beSStanislav Mekhanoshin bool AMDGPUPromoteKernelArguments::promotePointer(Value *Ptr) {
95*9cf995beSStanislav Mekhanoshin   enqueueUsers(Ptr);
96*9cf995beSStanislav Mekhanoshin 
97*9cf995beSStanislav Mekhanoshin   PointerType *PT = cast<PointerType>(Ptr->getType());
98*9cf995beSStanislav Mekhanoshin   if (PT->getAddressSpace() != AMDGPUAS::FLAT_ADDRESS)
99*9cf995beSStanislav Mekhanoshin     return false;
100*9cf995beSStanislav Mekhanoshin 
101*9cf995beSStanislav Mekhanoshin   bool IsArg = isa<Argument>(Ptr);
102*9cf995beSStanislav Mekhanoshin   IRBuilder<> B(IsArg ? ArgCastInsertPt
103*9cf995beSStanislav Mekhanoshin                       : &*std::next(cast<Instruction>(Ptr)->getIterator()));
104*9cf995beSStanislav Mekhanoshin 
105*9cf995beSStanislav Mekhanoshin   // Cast pointer to global address space and back to flat and let
106*9cf995beSStanislav Mekhanoshin   // Infer Address Spaces pass to do all necessary rewriting.
107*9cf995beSStanislav Mekhanoshin   PointerType *NewPT =
108*9cf995beSStanislav Mekhanoshin       PointerType::getWithSamePointeeType(PT, AMDGPUAS::GLOBAL_ADDRESS);
109*9cf995beSStanislav Mekhanoshin   Value *Cast =
110*9cf995beSStanislav Mekhanoshin       B.CreateAddrSpaceCast(Ptr, NewPT, Twine(Ptr->getName(), ".global"));
111*9cf995beSStanislav Mekhanoshin   Value *CastBack =
112*9cf995beSStanislav Mekhanoshin       B.CreateAddrSpaceCast(Cast, PT, Twine(Ptr->getName(), ".flat"));
113*9cf995beSStanislav Mekhanoshin   Ptr->replaceUsesWithIf(CastBack,
114*9cf995beSStanislav Mekhanoshin                          [Cast](Use &U) { return U.getUser() != Cast; });
115*9cf995beSStanislav Mekhanoshin 
116*9cf995beSStanislav Mekhanoshin   return true;
117*9cf995beSStanislav Mekhanoshin }
118*9cf995beSStanislav Mekhanoshin 
119*9cf995beSStanislav Mekhanoshin // skip allocas
120*9cf995beSStanislav Mekhanoshin static BasicBlock::iterator getInsertPt(BasicBlock &BB) {
121*9cf995beSStanislav Mekhanoshin   BasicBlock::iterator InsPt = BB.getFirstInsertionPt();
122*9cf995beSStanislav Mekhanoshin   for (BasicBlock::iterator E = BB.end(); InsPt != E; ++InsPt) {
123*9cf995beSStanislav Mekhanoshin     AllocaInst *AI = dyn_cast<AllocaInst>(&*InsPt);
124*9cf995beSStanislav Mekhanoshin 
125*9cf995beSStanislav Mekhanoshin     // If this is a dynamic alloca, the value may depend on the loaded kernargs,
126*9cf995beSStanislav Mekhanoshin     // so loads will need to be inserted before it.
127*9cf995beSStanislav Mekhanoshin     if (!AI || !AI->isStaticAlloca())
128*9cf995beSStanislav Mekhanoshin       break;
129*9cf995beSStanislav Mekhanoshin   }
130*9cf995beSStanislav Mekhanoshin 
131*9cf995beSStanislav Mekhanoshin   return InsPt;
132*9cf995beSStanislav Mekhanoshin }
133*9cf995beSStanislav Mekhanoshin 
134*9cf995beSStanislav Mekhanoshin bool AMDGPUPromoteKernelArguments::run(Function &F, MemorySSA &MSSA) {
135*9cf995beSStanislav Mekhanoshin   if (skipFunction(F))
136*9cf995beSStanislav Mekhanoshin     return false;
137*9cf995beSStanislav Mekhanoshin 
138*9cf995beSStanislav Mekhanoshin   CallingConv::ID CC = F.getCallingConv();
139*9cf995beSStanislav Mekhanoshin   if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty())
140*9cf995beSStanislav Mekhanoshin     return false;
141*9cf995beSStanislav Mekhanoshin 
142*9cf995beSStanislav Mekhanoshin   ArgCastInsertPt = &*getInsertPt(*F.begin());
143*9cf995beSStanislav Mekhanoshin   this->MSSA = &MSSA;
144*9cf995beSStanislav Mekhanoshin 
145*9cf995beSStanislav Mekhanoshin   for (Argument &Arg : F.args()) {
146*9cf995beSStanislav Mekhanoshin     if (Arg.use_empty())
147*9cf995beSStanislav Mekhanoshin       continue;
148*9cf995beSStanislav Mekhanoshin 
149*9cf995beSStanislav Mekhanoshin     PointerType *PT = dyn_cast<PointerType>(Arg.getType());
150*9cf995beSStanislav Mekhanoshin     if (!PT || (PT->getAddressSpace() != AMDGPUAS::FLAT_ADDRESS &&
151*9cf995beSStanislav Mekhanoshin                 PT->getAddressSpace() != AMDGPUAS::GLOBAL_ADDRESS &&
152*9cf995beSStanislav Mekhanoshin                 PT->getAddressSpace() != AMDGPUAS::CONSTANT_ADDRESS))
153*9cf995beSStanislav Mekhanoshin       continue;
154*9cf995beSStanislav Mekhanoshin 
155*9cf995beSStanislav Mekhanoshin     Ptrs.push_back(&Arg);
156*9cf995beSStanislav Mekhanoshin   }
157*9cf995beSStanislav Mekhanoshin 
158*9cf995beSStanislav Mekhanoshin   bool Changed = false;
159*9cf995beSStanislav Mekhanoshin   while (!Ptrs.empty()) {
160*9cf995beSStanislav Mekhanoshin     Value *Ptr = Ptrs.pop_back_val();
161*9cf995beSStanislav Mekhanoshin     Changed |= promotePointer(Ptr);
162*9cf995beSStanislav Mekhanoshin   }
163*9cf995beSStanislav Mekhanoshin 
164*9cf995beSStanislav Mekhanoshin   return Changed;
165*9cf995beSStanislav Mekhanoshin }
166*9cf995beSStanislav Mekhanoshin 
167*9cf995beSStanislav Mekhanoshin bool AMDGPUPromoteKernelArguments::runOnFunction(Function &F) {
168*9cf995beSStanislav Mekhanoshin   MemorySSA &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA();
169*9cf995beSStanislav Mekhanoshin   return run(F, MSSA);
170*9cf995beSStanislav Mekhanoshin }
171*9cf995beSStanislav Mekhanoshin 
172*9cf995beSStanislav Mekhanoshin INITIALIZE_PASS_BEGIN(AMDGPUPromoteKernelArguments, DEBUG_TYPE,
173*9cf995beSStanislav Mekhanoshin                       "AMDGPU Promote Kernel Arguments", false, false)
174*9cf995beSStanislav Mekhanoshin INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
175*9cf995beSStanislav Mekhanoshin INITIALIZE_PASS_END(AMDGPUPromoteKernelArguments, DEBUG_TYPE,
176*9cf995beSStanislav Mekhanoshin                     "AMDGPU Promote Kernel Arguments", false, false)
177*9cf995beSStanislav Mekhanoshin 
178*9cf995beSStanislav Mekhanoshin char AMDGPUPromoteKernelArguments::ID = 0;
179*9cf995beSStanislav Mekhanoshin 
180*9cf995beSStanislav Mekhanoshin FunctionPass *llvm::createAMDGPUPromoteKernelArgumentsPass() {
181*9cf995beSStanislav Mekhanoshin   return new AMDGPUPromoteKernelArguments();
182*9cf995beSStanislav Mekhanoshin }
183*9cf995beSStanislav Mekhanoshin 
184*9cf995beSStanislav Mekhanoshin PreservedAnalyses
185*9cf995beSStanislav Mekhanoshin AMDGPUPromoteKernelArgumentsPass::run(Function &F,
186*9cf995beSStanislav Mekhanoshin                                       FunctionAnalysisManager &AM) {
187*9cf995beSStanislav Mekhanoshin   MemorySSA &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA();
188*9cf995beSStanislav Mekhanoshin   if (AMDGPUPromoteKernelArguments().run(F, MSSA)) {
189*9cf995beSStanislav Mekhanoshin     PreservedAnalyses PA;
190*9cf995beSStanislav Mekhanoshin     PA.preserveSet<CFGAnalyses>();
191*9cf995beSStanislav Mekhanoshin     PA.preserve<MemorySSAAnalysis>();
192*9cf995beSStanislav Mekhanoshin     return PA;
193*9cf995beSStanislav Mekhanoshin   }
194*9cf995beSStanislav Mekhanoshin   return PreservedAnalyses::all();
195*9cf995beSStanislav Mekhanoshin }
196