//===-- AMDGPUPromoteKernelArguments.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file This pass recursively promotes generic pointer arguments of a kernel
/// into the global address space.
///
/// The pass walks the kernel's pointer arguments and the loads from them. If a
/// loaded value is a pointer and the memory it is loaded from is not modified
/// in the kernel before the load, the loaded pointer is promoted to the global
/// address space. The walk then continues recursively.
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/InitializePasses.h"

#define DEBUG_TYPE "amdgpu-promote-kernel-arguments"

using namespace llvm;

namespace {

class AMDGPUPromoteKernelArguments : public FunctionPass {
  MemorySSA *MSSA;

  Instruction *ArgCastInsertPt;

  SmallVector<Value *> Ptrs;

  void enqueueUsers(Value *Ptr);

  bool promotePointer(Value *Ptr);

public:
  static char ID;

  AMDGPUPromoteKernelArguments() : FunctionPass(ID) {}

  bool run(Function &F, MemorySSA &MSSA);

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<MemorySSAWrapperPass>();
    AU.setPreservesAll();
  }
};

} // end anonymous namespace
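
// Illustrative sketch of the pattern this pass targets (a hypothetical
// example, not taken from a regression test):
//
//   define amdgpu_kernel void @kern(float* addrspace(1)* %arg) {
//     %p = load float*, float* addrspace(1)* %arg
//     %v = load float, float* %p
//     ...
//   }
//
// The load of %p yields a flat pointer whose clobbering MemoryAccess is the
// live-on-entry def, i.e. no store that could change the loaded value precedes
// it. enqueueUsers() therefore queues %p, promotePointer() rewrites its uses
// through a pair of addrspacecasts into the global address space, and the walk
// continues recursively through %p's own users.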

void AMDGPUPromoteKernelArguments::enqueueUsers(Value *Ptr) {
  SmallVector<User *> PtrUsers(Ptr->users());

  while (!PtrUsers.empty()) {
    Instruction *U = dyn_cast<Instruction>(PtrUsers.pop_back_val());
    if (!U)
      continue;

    switch (U->getOpcode()) {
    default:
      break;
    case Instruction::Load: {
      LoadInst *LD = cast<LoadInst>(U);
      PointerType *PT = dyn_cast<PointerType>(LD->getType());
      if (!PT ||
          (PT->getAddressSpace() != AMDGPUAS::FLAT_ADDRESS &&
           PT->getAddressSpace() != AMDGPUAS::GLOBAL_ADDRESS &&
           PT->getAddressSpace() != AMDGPUAS::CONSTANT_ADDRESS) ||
          LD->getPointerOperand()->stripInBoundsOffsets() != Ptr)
        break;
      const MemoryAccess *MA = MSSA->getWalker()->getClobberingMemoryAccess(LD);
      // TODO: This load probably can be promoted to constant address space.
      if (MSSA->isLiveOnEntryDef(MA))
        Ptrs.push_back(LD);
      break;
    }
    case Instruction::GetElementPtr:
    case Instruction::AddrSpaceCast:
    case Instruction::BitCast:
      if (U->getOperand(0)->stripInBoundsOffsets() == Ptr)
        PtrUsers.append(U->user_begin(), U->user_end());
      break;
    }
  }
}

bool AMDGPUPromoteKernelArguments::promotePointer(Value *Ptr) {
  enqueueUsers(Ptr);

  PointerType *PT = cast<PointerType>(Ptr->getType());
  if (PT->getAddressSpace() != AMDGPUAS::FLAT_ADDRESS)
    return false;

  bool IsArg = isa<Argument>(Ptr);
  IRBuilder<> B(IsArg ? ArgCastInsertPt
                      : &*std::next(cast<Instruction>(Ptr)->getIterator()));

  // Cast the pointer to the global address space and back to flat, and let the
  // InferAddressSpaces pass do all the necessary rewriting.
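  // For a flat pointer %p this produces, as a hypothetical sketch (the
  // ".global"/".flat" names match the suffixes used below):
  //   %p.global = addrspacecast float* %p to float addrspace(1)*
  //   %p.flat   = addrspacecast float addrspace(1)* %p.global to float*
  // All uses of %p other than the first cast are then redirected to %p.flat.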
  PointerType *NewPT =
      PointerType::getWithSamePointeeType(PT, AMDGPUAS::GLOBAL_ADDRESS);
  Value *Cast =
      B.CreateAddrSpaceCast(Ptr, NewPT, Twine(Ptr->getName(), ".global"));
  Value *CastBack =
      B.CreateAddrSpaceCast(Cast, PT, Twine(Ptr->getName(), ".flat"));
  Ptr->replaceUsesWithIf(CastBack,
                         [Cast](Use &U) { return U.getUser() != Cast; });

  return true;
}

// Skip the static allocas at the start of the entry block when choosing the
// insertion point for the kernel argument casts.
static BasicBlock::iterator getInsertPt(BasicBlock &BB) {
  BasicBlock::iterator InsPt = BB.getFirstInsertionPt();
  for (BasicBlock::iterator E = BB.end(); InsPt != E; ++InsPt) {
    AllocaInst *AI = dyn_cast<AllocaInst>(&*InsPt);

    // If this is a dynamic alloca, the value may depend on the loaded kernargs,
    // so loads will need to be inserted before it.
    if (!AI || !AI->isStaticAlloca())
      break;
  }

  return InsPt;
}

bool AMDGPUPromoteKernelArguments::run(Function &F, MemorySSA &MSSA) {
  if (skipFunction(F))
    return false;

  CallingConv::ID CC = F.getCallingConv();
  if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty())
    return false;

  ArgCastInsertPt = &*getInsertPt(*F.begin());
  this->MSSA = &MSSA;

  for (Argument &Arg : F.args()) {
    if (Arg.use_empty())
      continue;

    PointerType *PT = dyn_cast<PointerType>(Arg.getType());
    if (!PT || (PT->getAddressSpace() != AMDGPUAS::FLAT_ADDRESS &&
                PT->getAddressSpace() != AMDGPUAS::GLOBAL_ADDRESS &&
                PT->getAddressSpace() != AMDGPUAS::CONSTANT_ADDRESS))
      continue;

    Ptrs.push_back(&Arg);
  }

  bool Changed = false;
  while (!Ptrs.empty()) {
    Value *Ptr = Ptrs.pop_back_val();
    Changed |= promotePointer(Ptr);
  }

  return Changed;
}

bool AMDGPUPromoteKernelArguments::runOnFunction(Function &F) {
  MemorySSA &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA();
  return run(F, MSSA);
}

INITIALIZE_PASS_BEGIN(AMDGPUPromoteKernelArguments, DEBUG_TYPE,
                      "AMDGPU Promote Kernel Arguments", false, false)
INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
INITIALIZE_PASS_END(AMDGPUPromoteKernelArguments, DEBUG_TYPE,
                    "AMDGPU Promote Kernel Arguments", false, false)

char AMDGPUPromoteKernelArguments::ID = 0;

FunctionPass *llvm::createAMDGPUPromoteKernelArgumentsPass() {
  return new AMDGPUPromoteKernelArguments();
}

PreservedAnalyses
AMDGPUPromoteKernelArgumentsPass::run(Function &F,
                                      FunctionAnalysisManager &AM) {
  MemorySSA &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA();
  if (AMDGPUPromoteKernelArguments().run(F, MSSA)) {
    PreservedAnalyses PA;
    PA.preserveSet<CFGAnalyses>();
    PA.preserve<MemorySSAAnalysis>();
    return PA;
  }
  return PreservedAnalyses::all();
}