//===-- AMDGPULowerKernelAttributes.cpp -----------------------------------===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
/// \file This pass attempts to make use of reqd_work_group_size metadata
/// to eliminate loads from the dispatch packet and to constant fold OpenCL
/// get_local_size-like functions.
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPUTargetMachine.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"

#define DEBUG_TYPE "amdgpu-lower-kernel-attributes"

using namespace llvm;

namespace {

// Field offsets in hsa_kernel_dispatch_packet_t.
34*372d796aSMatt Arsenault enum DispatchPackedOffsets { 35*372d796aSMatt Arsenault WORKGROUP_SIZE_X = 4, 36*372d796aSMatt Arsenault WORKGROUP_SIZE_Y = 6, 37*372d796aSMatt Arsenault WORKGROUP_SIZE_Z = 8, 38*372d796aSMatt Arsenault 39*372d796aSMatt Arsenault GRID_SIZE_X = 12, 40*372d796aSMatt Arsenault GRID_SIZE_Y = 16, 41*372d796aSMatt Arsenault GRID_SIZE_Z = 20 42*372d796aSMatt Arsenault }; 43*372d796aSMatt Arsenault 44*372d796aSMatt Arsenault class AMDGPULowerKernelAttributes : public ModulePass { 45*372d796aSMatt Arsenault Module *Mod = nullptr; 46*372d796aSMatt Arsenault 47*372d796aSMatt Arsenault public: 48*372d796aSMatt Arsenault static char ID; 49*372d796aSMatt Arsenault 50*372d796aSMatt Arsenault AMDGPULowerKernelAttributes() : ModulePass(ID) {} 51*372d796aSMatt Arsenault 52*372d796aSMatt Arsenault bool processUse(CallInst *CI); 53*372d796aSMatt Arsenault 54*372d796aSMatt Arsenault bool doInitialization(Module &M) override; 55*372d796aSMatt Arsenault bool runOnModule(Module &M) override; 56*372d796aSMatt Arsenault 57*372d796aSMatt Arsenault StringRef getPassName() const override { 58*372d796aSMatt Arsenault return "AMDGPU Kernel Attributes"; 59*372d796aSMatt Arsenault } 60*372d796aSMatt Arsenault 61*372d796aSMatt Arsenault void getAnalysisUsage(AnalysisUsage &AU) const override { 62*372d796aSMatt Arsenault AU.setPreservesAll(); 63*372d796aSMatt Arsenault } 64*372d796aSMatt Arsenault }; 65*372d796aSMatt Arsenault 66*372d796aSMatt Arsenault } // end anonymous namespace 67*372d796aSMatt Arsenault 68*372d796aSMatt Arsenault bool AMDGPULowerKernelAttributes::doInitialization(Module &M) { 69*372d796aSMatt Arsenault Mod = &M; 70*372d796aSMatt Arsenault return false; 71*372d796aSMatt Arsenault } 72*372d796aSMatt Arsenault 73*372d796aSMatt Arsenault bool AMDGPULowerKernelAttributes::processUse(CallInst *CI) { 74*372d796aSMatt Arsenault Function *F = CI->getParent()->getParent(); 75*372d796aSMatt Arsenault 76*372d796aSMatt Arsenault auto MD = 
F->getMetadata("reqd_work_group_size"); 77*372d796aSMatt Arsenault const bool HasReqdWorkGroupSize = MD && MD->getNumOperands() == 3; 78*372d796aSMatt Arsenault 79*372d796aSMatt Arsenault const bool HasUniformWorkGroupSize = 80*372d796aSMatt Arsenault F->getFnAttribute("uniform-work-group-size").getValueAsString() == "true"; 81*372d796aSMatt Arsenault 82*372d796aSMatt Arsenault if (!HasReqdWorkGroupSize && !HasUniformWorkGroupSize) 83*372d796aSMatt Arsenault return false; 84*372d796aSMatt Arsenault 85*372d796aSMatt Arsenault Value *WorkGroupSizeX = nullptr; 86*372d796aSMatt Arsenault Value *WorkGroupSizeY = nullptr; 87*372d796aSMatt Arsenault Value *WorkGroupSizeZ = nullptr; 88*372d796aSMatt Arsenault 89*372d796aSMatt Arsenault Value *GridSizeX = nullptr; 90*372d796aSMatt Arsenault Value *GridSizeY = nullptr; 91*372d796aSMatt Arsenault Value *GridSizeZ = nullptr; 92*372d796aSMatt Arsenault 93*372d796aSMatt Arsenault const DataLayout &DL = Mod->getDataLayout(); 94*372d796aSMatt Arsenault 95*372d796aSMatt Arsenault // We expect to see several GEP users, casted to the appropriate type and 96*372d796aSMatt Arsenault // loaded. 
97*372d796aSMatt Arsenault for (User *U : CI->users()) { 98*372d796aSMatt Arsenault if (!U->hasOneUse()) 99*372d796aSMatt Arsenault continue; 100*372d796aSMatt Arsenault 101*372d796aSMatt Arsenault int64_t Offset = 0; 102*372d796aSMatt Arsenault if (GetPointerBaseWithConstantOffset(U, Offset, DL) != CI) 103*372d796aSMatt Arsenault continue; 104*372d796aSMatt Arsenault 105*372d796aSMatt Arsenault auto *BCI = dyn_cast<BitCastInst>(*U->user_begin()); 106*372d796aSMatt Arsenault if (!BCI || !BCI->hasOneUse()) 107*372d796aSMatt Arsenault continue; 108*372d796aSMatt Arsenault 109*372d796aSMatt Arsenault auto *Load = dyn_cast<LoadInst>(*BCI->user_begin()); 110*372d796aSMatt Arsenault if (!Load || !Load->isSimple()) 111*372d796aSMatt Arsenault continue; 112*372d796aSMatt Arsenault 113*372d796aSMatt Arsenault unsigned LoadSize = DL.getTypeStoreSize(Load->getType()); 114*372d796aSMatt Arsenault 115*372d796aSMatt Arsenault // TODO: Handle merged loads. 116*372d796aSMatt Arsenault switch (Offset) { 117*372d796aSMatt Arsenault case WORKGROUP_SIZE_X: 118*372d796aSMatt Arsenault if (LoadSize == 2) 119*372d796aSMatt Arsenault WorkGroupSizeX = Load; 120*372d796aSMatt Arsenault break; 121*372d796aSMatt Arsenault case WORKGROUP_SIZE_Y: 122*372d796aSMatt Arsenault if (LoadSize == 2) 123*372d796aSMatt Arsenault WorkGroupSizeY = Load; 124*372d796aSMatt Arsenault break; 125*372d796aSMatt Arsenault case WORKGROUP_SIZE_Z: 126*372d796aSMatt Arsenault if (LoadSize == 2) 127*372d796aSMatt Arsenault WorkGroupSizeZ = Load; 128*372d796aSMatt Arsenault break; 129*372d796aSMatt Arsenault case GRID_SIZE_X: 130*372d796aSMatt Arsenault if (LoadSize == 4) 131*372d796aSMatt Arsenault GridSizeX = Load; 132*372d796aSMatt Arsenault break; 133*372d796aSMatt Arsenault case GRID_SIZE_Y: 134*372d796aSMatt Arsenault if (LoadSize == 4) 135*372d796aSMatt Arsenault GridSizeY = Load; 136*372d796aSMatt Arsenault break; 137*372d796aSMatt Arsenault case GRID_SIZE_Z: 138*372d796aSMatt Arsenault if (LoadSize == 4) 
139*372d796aSMatt Arsenault GridSizeZ = Load; 140*372d796aSMatt Arsenault break; 141*372d796aSMatt Arsenault default: 142*372d796aSMatt Arsenault break; 143*372d796aSMatt Arsenault } 144*372d796aSMatt Arsenault } 145*372d796aSMatt Arsenault 146*372d796aSMatt Arsenault // Pattern match the code used to handle partial workgroup dispatches in the 147*372d796aSMatt Arsenault // library implementation of get_local_size, so the entire function can be 148*372d796aSMatt Arsenault // constant folded with a known group size. 149*372d796aSMatt Arsenault // 150*372d796aSMatt Arsenault // uint r = grid_size - group_id * group_size; 151*372d796aSMatt Arsenault // get_local_size = (r < group_size) ? r : group_size; 152*372d796aSMatt Arsenault // 153*372d796aSMatt Arsenault // If we have uniform-work-group-size (which is the default in OpenCL 1.2), 154*372d796aSMatt Arsenault // the grid_size is required to be a multiple of group_size). In this case: 155*372d796aSMatt Arsenault // 156*372d796aSMatt Arsenault // grid_size - (group_id * group_size) < group_size 157*372d796aSMatt Arsenault // -> 158*372d796aSMatt Arsenault // grid_size < group_size + (group_id * group_size) 159*372d796aSMatt Arsenault // 160*372d796aSMatt Arsenault // (grid_size / group_size) < 1 + group_id 161*372d796aSMatt Arsenault // 162*372d796aSMatt Arsenault // grid_size / group_size is at least 1, so we can conclude the select 163*372d796aSMatt Arsenault // condition is false (except for group_id == 0, where the select result is 164*372d796aSMatt Arsenault // the same). 
165*372d796aSMatt Arsenault 166*372d796aSMatt Arsenault bool MadeChange = false; 167*372d796aSMatt Arsenault Value *WorkGroupSizes[3] = { WorkGroupSizeX, WorkGroupSizeY, WorkGroupSizeZ }; 168*372d796aSMatt Arsenault Value *GridSizes[3] = { GridSizeX, GridSizeY, GridSizeZ }; 169*372d796aSMatt Arsenault 170*372d796aSMatt Arsenault for (int I = 0; HasUniformWorkGroupSize && I < 3; ++I) { 171*372d796aSMatt Arsenault Value *GroupSize = WorkGroupSizes[I]; 172*372d796aSMatt Arsenault Value *GridSize = GridSizes[I]; 173*372d796aSMatt Arsenault if (!GroupSize || !GridSize) 174*372d796aSMatt Arsenault continue; 175*372d796aSMatt Arsenault 176*372d796aSMatt Arsenault for (User *U : GroupSize->users()) { 177*372d796aSMatt Arsenault auto *ZextGroupSize = dyn_cast<ZExtInst>(U); 178*372d796aSMatt Arsenault if (!ZextGroupSize) 179*372d796aSMatt Arsenault continue; 180*372d796aSMatt Arsenault 181*372d796aSMatt Arsenault for (User *ZextUser : ZextGroupSize->users()) { 182*372d796aSMatt Arsenault auto *SI = dyn_cast<SelectInst>(ZextUser); 183*372d796aSMatt Arsenault if (!SI) 184*372d796aSMatt Arsenault continue; 185*372d796aSMatt Arsenault 186*372d796aSMatt Arsenault using namespace llvm::PatternMatch; 187*372d796aSMatt Arsenault auto GroupIDIntrin = I == 0 ? 188*372d796aSMatt Arsenault m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>() : 189*372d796aSMatt Arsenault (I == 1 ? 
m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>() : 190*372d796aSMatt Arsenault m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>()); 191*372d796aSMatt Arsenault 192*372d796aSMatt Arsenault auto SubExpr = m_Sub(m_Specific(GridSize), 193*372d796aSMatt Arsenault m_Mul(GroupIDIntrin, m_Specific(ZextGroupSize))); 194*372d796aSMatt Arsenault 195*372d796aSMatt Arsenault ICmpInst::Predicate Pred; 196*372d796aSMatt Arsenault if (match(SI, 197*372d796aSMatt Arsenault m_Select(m_ICmp(Pred, SubExpr, m_Specific(ZextGroupSize)), 198*372d796aSMatt Arsenault SubExpr, 199*372d796aSMatt Arsenault m_Specific(ZextGroupSize))) && 200*372d796aSMatt Arsenault Pred == ICmpInst::ICMP_ULT) { 201*372d796aSMatt Arsenault if (HasReqdWorkGroupSize) { 202*372d796aSMatt Arsenault ConstantInt *KnownSize 203*372d796aSMatt Arsenault = mdconst::extract<ConstantInt>(MD->getOperand(I)); 204*372d796aSMatt Arsenault SI->replaceAllUsesWith(ConstantExpr::getIntegerCast(KnownSize, 205*372d796aSMatt Arsenault SI->getType(), 206*372d796aSMatt Arsenault false)); 207*372d796aSMatt Arsenault } else { 208*372d796aSMatt Arsenault SI->replaceAllUsesWith(ZextGroupSize); 209*372d796aSMatt Arsenault } 210*372d796aSMatt Arsenault 211*372d796aSMatt Arsenault MadeChange = true; 212*372d796aSMatt Arsenault } 213*372d796aSMatt Arsenault } 214*372d796aSMatt Arsenault } 215*372d796aSMatt Arsenault } 216*372d796aSMatt Arsenault 217*372d796aSMatt Arsenault if (!HasReqdWorkGroupSize) 218*372d796aSMatt Arsenault return MadeChange; 219*372d796aSMatt Arsenault 220*372d796aSMatt Arsenault // Eliminate any other loads we can from the dispatch packet. 
221*372d796aSMatt Arsenault for (int I = 0; I < 3; ++I) { 222*372d796aSMatt Arsenault Value *GroupSize = WorkGroupSizes[I]; 223*372d796aSMatt Arsenault if (!GroupSize) 224*372d796aSMatt Arsenault continue; 225*372d796aSMatt Arsenault 226*372d796aSMatt Arsenault ConstantInt *KnownSize = mdconst::extract<ConstantInt>(MD->getOperand(I)); 227*372d796aSMatt Arsenault GroupSize->replaceAllUsesWith( 228*372d796aSMatt Arsenault ConstantExpr::getIntegerCast(KnownSize, 229*372d796aSMatt Arsenault GroupSize->getType(), 230*372d796aSMatt Arsenault false)); 231*372d796aSMatt Arsenault MadeChange = true; 232*372d796aSMatt Arsenault } 233*372d796aSMatt Arsenault 234*372d796aSMatt Arsenault return MadeChange; 235*372d796aSMatt Arsenault } 236*372d796aSMatt Arsenault 237*372d796aSMatt Arsenault // TODO: Move makeLIDRangeMetadata usage into here. Seem to not get 238*372d796aSMatt Arsenault // TargetPassConfig for subtarget. 239*372d796aSMatt Arsenault bool AMDGPULowerKernelAttributes::runOnModule(Module &M) { 240*372d796aSMatt Arsenault StringRef DispatchPtrName 241*372d796aSMatt Arsenault = Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr); 242*372d796aSMatt Arsenault 243*372d796aSMatt Arsenault Function *DispatchPtr = Mod->getFunction(DispatchPtrName); 244*372d796aSMatt Arsenault if (!DispatchPtr) // Dispatch ptr not used. 
245*372d796aSMatt Arsenault return false; 246*372d796aSMatt Arsenault 247*372d796aSMatt Arsenault bool MadeChange = false; 248*372d796aSMatt Arsenault 249*372d796aSMatt Arsenault SmallPtrSet<Instruction *, 4> HandledUses; 250*372d796aSMatt Arsenault for (auto *U : DispatchPtr->users()) { 251*372d796aSMatt Arsenault CallInst *CI = cast<CallInst>(U); 252*372d796aSMatt Arsenault if (HandledUses.insert(CI).second) { 253*372d796aSMatt Arsenault if (processUse(CI)) 254*372d796aSMatt Arsenault MadeChange = true; 255*372d796aSMatt Arsenault } 256*372d796aSMatt Arsenault } 257*372d796aSMatt Arsenault 258*372d796aSMatt Arsenault return MadeChange; 259*372d796aSMatt Arsenault } 260*372d796aSMatt Arsenault 261*372d796aSMatt Arsenault INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes, DEBUG_TYPE, 262*372d796aSMatt Arsenault "AMDGPU IR optimizations", false, false) 263*372d796aSMatt Arsenault INITIALIZE_PASS_END(AMDGPULowerKernelAttributes, DEBUG_TYPE, "AMDGPU IR optimizations", 264*372d796aSMatt Arsenault false, false) 265*372d796aSMatt Arsenault 266*372d796aSMatt Arsenault char AMDGPULowerKernelAttributes::ID = 0; 267*372d796aSMatt Arsenault 268*372d796aSMatt Arsenault ModulePass *llvm::createAMDGPULowerKernelAttributesPass() { 269*372d796aSMatt Arsenault return new AMDGPULowerKernelAttributes(); 270*372d796aSMatt Arsenault } 271