//===-- AMDGPULowerKernelAttributes.cpp -----------------------------------===//
2*4ba319b5SDimitry Andric //
3*4ba319b5SDimitry Andric //                     The LLVM Compiler Infrastructure
4*4ba319b5SDimitry Andric //
5*4ba319b5SDimitry Andric // This file is distributed under the University of Illinois Open Source
6*4ba319b5SDimitry Andric // License. See LICENSE.TXT for details.
7*4ba319b5SDimitry Andric //
8*4ba319b5SDimitry Andric //===----------------------------------------------------------------------===//
9*4ba319b5SDimitry Andric //
/// \file This pass attempts to make use of reqd_work_group_size metadata
/// to eliminate loads from the dispatch packet and to constant fold OpenCL
/// get_local_size-like functions.
13*4ba319b5SDimitry Andric //
14*4ba319b5SDimitry Andric //===----------------------------------------------------------------------===//
15*4ba319b5SDimitry Andric 
16*4ba319b5SDimitry Andric #include "AMDGPU.h"
17*4ba319b5SDimitry Andric #include "AMDGPUTargetMachine.h"
18*4ba319b5SDimitry Andric #include "llvm/Analysis/ValueTracking.h"
19*4ba319b5SDimitry Andric #include "llvm/CodeGen/Passes.h"
20*4ba319b5SDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h"
21*4ba319b5SDimitry Andric #include "llvm/IR/Constants.h"
22*4ba319b5SDimitry Andric #include "llvm/IR/Function.h"
23*4ba319b5SDimitry Andric #include "llvm/IR/Instructions.h"
24*4ba319b5SDimitry Andric #include "llvm/IR/PatternMatch.h"
25*4ba319b5SDimitry Andric #include "llvm/Pass.h"
26*4ba319b5SDimitry Andric 
27*4ba319b5SDimitry Andric #define DEBUG_TYPE "amdgpu-lower-kernel-attributes"
28*4ba319b5SDimitry Andric 
29*4ba319b5SDimitry Andric using namespace llvm;
30*4ba319b5SDimitry Andric 
31*4ba319b5SDimitry Andric namespace {
32*4ba319b5SDimitry Andric 
33*4ba319b5SDimitry Andric // Field offsets in hsa_kernel_dispatch_packet_t.
34*4ba319b5SDimitry Andric enum DispatchPackedOffsets {
35*4ba319b5SDimitry Andric   WORKGROUP_SIZE_X = 4,
36*4ba319b5SDimitry Andric   WORKGROUP_SIZE_Y = 6,
37*4ba319b5SDimitry Andric   WORKGROUP_SIZE_Z = 8,
38*4ba319b5SDimitry Andric 
39*4ba319b5SDimitry Andric   GRID_SIZE_X = 12,
40*4ba319b5SDimitry Andric   GRID_SIZE_Y = 16,
41*4ba319b5SDimitry Andric   GRID_SIZE_Z = 20
42*4ba319b5SDimitry Andric };
43*4ba319b5SDimitry Andric 
44*4ba319b5SDimitry Andric class AMDGPULowerKernelAttributes : public ModulePass {
45*4ba319b5SDimitry Andric   Module *Mod = nullptr;
46*4ba319b5SDimitry Andric 
47*4ba319b5SDimitry Andric public:
48*4ba319b5SDimitry Andric   static char ID;
49*4ba319b5SDimitry Andric 
AMDGPULowerKernelAttributes()50*4ba319b5SDimitry Andric   AMDGPULowerKernelAttributes() : ModulePass(ID) {}
51*4ba319b5SDimitry Andric 
52*4ba319b5SDimitry Andric   bool processUse(CallInst *CI);
53*4ba319b5SDimitry Andric 
54*4ba319b5SDimitry Andric   bool doInitialization(Module &M) override;
55*4ba319b5SDimitry Andric   bool runOnModule(Module &M) override;
56*4ba319b5SDimitry Andric 
getPassName() const57*4ba319b5SDimitry Andric   StringRef getPassName() const override {
58*4ba319b5SDimitry Andric     return "AMDGPU Kernel Attributes";
59*4ba319b5SDimitry Andric   }
60*4ba319b5SDimitry Andric 
getAnalysisUsage(AnalysisUsage & AU) const61*4ba319b5SDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
62*4ba319b5SDimitry Andric     AU.setPreservesAll();
63*4ba319b5SDimitry Andric  }
64*4ba319b5SDimitry Andric };
65*4ba319b5SDimitry Andric 
66*4ba319b5SDimitry Andric } // end anonymous namespace
67*4ba319b5SDimitry Andric 
doInitialization(Module & M)68*4ba319b5SDimitry Andric bool AMDGPULowerKernelAttributes::doInitialization(Module &M) {
69*4ba319b5SDimitry Andric   Mod = &M;
70*4ba319b5SDimitry Andric   return false;
71*4ba319b5SDimitry Andric }
72*4ba319b5SDimitry Andric 
processUse(CallInst * CI)73*4ba319b5SDimitry Andric bool AMDGPULowerKernelAttributes::processUse(CallInst *CI) {
74*4ba319b5SDimitry Andric   Function *F = CI->getParent()->getParent();
75*4ba319b5SDimitry Andric 
76*4ba319b5SDimitry Andric   auto MD = F->getMetadata("reqd_work_group_size");
77*4ba319b5SDimitry Andric   const bool HasReqdWorkGroupSize = MD && MD->getNumOperands() == 3;
78*4ba319b5SDimitry Andric 
79*4ba319b5SDimitry Andric   const bool HasUniformWorkGroupSize =
80*4ba319b5SDimitry Andric     F->getFnAttribute("uniform-work-group-size").getValueAsString() == "true";
81*4ba319b5SDimitry Andric 
82*4ba319b5SDimitry Andric   if (!HasReqdWorkGroupSize && !HasUniformWorkGroupSize)
83*4ba319b5SDimitry Andric     return false;
84*4ba319b5SDimitry Andric 
85*4ba319b5SDimitry Andric   Value *WorkGroupSizeX = nullptr;
86*4ba319b5SDimitry Andric   Value *WorkGroupSizeY = nullptr;
87*4ba319b5SDimitry Andric   Value *WorkGroupSizeZ = nullptr;
88*4ba319b5SDimitry Andric 
89*4ba319b5SDimitry Andric   Value *GridSizeX = nullptr;
90*4ba319b5SDimitry Andric   Value *GridSizeY = nullptr;
91*4ba319b5SDimitry Andric   Value *GridSizeZ = nullptr;
92*4ba319b5SDimitry Andric 
93*4ba319b5SDimitry Andric   const DataLayout &DL = Mod->getDataLayout();
94*4ba319b5SDimitry Andric 
95*4ba319b5SDimitry Andric   // We expect to see several GEP users, casted to the appropriate type and
96*4ba319b5SDimitry Andric   // loaded.
97*4ba319b5SDimitry Andric   for (User *U : CI->users()) {
98*4ba319b5SDimitry Andric     if (!U->hasOneUse())
99*4ba319b5SDimitry Andric       continue;
100*4ba319b5SDimitry Andric 
101*4ba319b5SDimitry Andric     int64_t Offset = 0;
102*4ba319b5SDimitry Andric     if (GetPointerBaseWithConstantOffset(U, Offset, DL) != CI)
103*4ba319b5SDimitry Andric       continue;
104*4ba319b5SDimitry Andric 
105*4ba319b5SDimitry Andric     auto *BCI = dyn_cast<BitCastInst>(*U->user_begin());
106*4ba319b5SDimitry Andric     if (!BCI || !BCI->hasOneUse())
107*4ba319b5SDimitry Andric       continue;
108*4ba319b5SDimitry Andric 
109*4ba319b5SDimitry Andric     auto *Load = dyn_cast<LoadInst>(*BCI->user_begin());
110*4ba319b5SDimitry Andric     if (!Load || !Load->isSimple())
111*4ba319b5SDimitry Andric       continue;
112*4ba319b5SDimitry Andric 
113*4ba319b5SDimitry Andric     unsigned LoadSize = DL.getTypeStoreSize(Load->getType());
114*4ba319b5SDimitry Andric 
115*4ba319b5SDimitry Andric     // TODO: Handle merged loads.
116*4ba319b5SDimitry Andric     switch (Offset) {
117*4ba319b5SDimitry Andric     case WORKGROUP_SIZE_X:
118*4ba319b5SDimitry Andric       if (LoadSize == 2)
119*4ba319b5SDimitry Andric         WorkGroupSizeX = Load;
120*4ba319b5SDimitry Andric       break;
121*4ba319b5SDimitry Andric     case WORKGROUP_SIZE_Y:
122*4ba319b5SDimitry Andric       if (LoadSize == 2)
123*4ba319b5SDimitry Andric         WorkGroupSizeY = Load;
124*4ba319b5SDimitry Andric       break;
125*4ba319b5SDimitry Andric     case WORKGROUP_SIZE_Z:
126*4ba319b5SDimitry Andric       if (LoadSize == 2)
127*4ba319b5SDimitry Andric         WorkGroupSizeZ = Load;
128*4ba319b5SDimitry Andric       break;
129*4ba319b5SDimitry Andric     case GRID_SIZE_X:
130*4ba319b5SDimitry Andric       if (LoadSize == 4)
131*4ba319b5SDimitry Andric         GridSizeX = Load;
132*4ba319b5SDimitry Andric       break;
133*4ba319b5SDimitry Andric     case GRID_SIZE_Y:
134*4ba319b5SDimitry Andric       if (LoadSize == 4)
135*4ba319b5SDimitry Andric         GridSizeY = Load;
136*4ba319b5SDimitry Andric       break;
137*4ba319b5SDimitry Andric     case GRID_SIZE_Z:
138*4ba319b5SDimitry Andric       if (LoadSize == 4)
139*4ba319b5SDimitry Andric         GridSizeZ = Load;
140*4ba319b5SDimitry Andric       break;
141*4ba319b5SDimitry Andric     default:
142*4ba319b5SDimitry Andric       break;
143*4ba319b5SDimitry Andric     }
144*4ba319b5SDimitry Andric   }
145*4ba319b5SDimitry Andric 
146*4ba319b5SDimitry Andric   // Pattern match the code used to handle partial workgroup dispatches in the
147*4ba319b5SDimitry Andric   // library implementation of get_local_size, so the entire function can be
148*4ba319b5SDimitry Andric   // constant folded with a known group size.
149*4ba319b5SDimitry Andric   //
150*4ba319b5SDimitry Andric   // uint r = grid_size - group_id * group_size;
151*4ba319b5SDimitry Andric   // get_local_size = (r < group_size) ? r : group_size;
152*4ba319b5SDimitry Andric   //
153*4ba319b5SDimitry Andric   // If we have uniform-work-group-size (which is the default in OpenCL 1.2),
154*4ba319b5SDimitry Andric   // the grid_size is required to be a multiple of group_size). In this case:
155*4ba319b5SDimitry Andric   //
156*4ba319b5SDimitry Andric   // grid_size - (group_id * group_size) < group_size
157*4ba319b5SDimitry Andric   // ->
158*4ba319b5SDimitry Andric   // grid_size < group_size + (group_id * group_size)
159*4ba319b5SDimitry Andric   //
160*4ba319b5SDimitry Andric   // (grid_size / group_size) < 1 + group_id
161*4ba319b5SDimitry Andric   //
162*4ba319b5SDimitry Andric   // grid_size / group_size is at least 1, so we can conclude the select
163*4ba319b5SDimitry Andric   // condition is false (except for group_id == 0, where the select result is
164*4ba319b5SDimitry Andric   // the same).
165*4ba319b5SDimitry Andric 
166*4ba319b5SDimitry Andric   bool MadeChange = false;
167*4ba319b5SDimitry Andric   Value *WorkGroupSizes[3] = { WorkGroupSizeX, WorkGroupSizeY, WorkGroupSizeZ };
168*4ba319b5SDimitry Andric   Value *GridSizes[3] = { GridSizeX, GridSizeY, GridSizeZ };
169*4ba319b5SDimitry Andric 
170*4ba319b5SDimitry Andric   for (int I = 0; HasUniformWorkGroupSize && I < 3; ++I) {
171*4ba319b5SDimitry Andric     Value *GroupSize = WorkGroupSizes[I];
172*4ba319b5SDimitry Andric     Value *GridSize = GridSizes[I];
173*4ba319b5SDimitry Andric     if (!GroupSize || !GridSize)
174*4ba319b5SDimitry Andric       continue;
175*4ba319b5SDimitry Andric 
176*4ba319b5SDimitry Andric     for (User *U : GroupSize->users()) {
177*4ba319b5SDimitry Andric       auto *ZextGroupSize = dyn_cast<ZExtInst>(U);
178*4ba319b5SDimitry Andric       if (!ZextGroupSize)
179*4ba319b5SDimitry Andric         continue;
180*4ba319b5SDimitry Andric 
181*4ba319b5SDimitry Andric       for (User *ZextUser : ZextGroupSize->users()) {
182*4ba319b5SDimitry Andric         auto *SI = dyn_cast<SelectInst>(ZextUser);
183*4ba319b5SDimitry Andric         if (!SI)
184*4ba319b5SDimitry Andric           continue;
185*4ba319b5SDimitry Andric 
186*4ba319b5SDimitry Andric         using namespace llvm::PatternMatch;
187*4ba319b5SDimitry Andric         auto GroupIDIntrin = I == 0 ?
188*4ba319b5SDimitry Andric           m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>() :
189*4ba319b5SDimitry Andric             (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>() :
190*4ba319b5SDimitry Andric                       m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());
191*4ba319b5SDimitry Andric 
192*4ba319b5SDimitry Andric         auto SubExpr = m_Sub(m_Specific(GridSize),
193*4ba319b5SDimitry Andric                              m_Mul(GroupIDIntrin, m_Specific(ZextGroupSize)));
194*4ba319b5SDimitry Andric 
195*4ba319b5SDimitry Andric         ICmpInst::Predicate Pred;
196*4ba319b5SDimitry Andric         if (match(SI,
197*4ba319b5SDimitry Andric                   m_Select(m_ICmp(Pred, SubExpr, m_Specific(ZextGroupSize)),
198*4ba319b5SDimitry Andric                            SubExpr,
199*4ba319b5SDimitry Andric                            m_Specific(ZextGroupSize))) &&
200*4ba319b5SDimitry Andric             Pred == ICmpInst::ICMP_ULT) {
201*4ba319b5SDimitry Andric           if (HasReqdWorkGroupSize) {
202*4ba319b5SDimitry Andric             ConstantInt *KnownSize
203*4ba319b5SDimitry Andric               = mdconst::extract<ConstantInt>(MD->getOperand(I));
204*4ba319b5SDimitry Andric             SI->replaceAllUsesWith(ConstantExpr::getIntegerCast(KnownSize,
205*4ba319b5SDimitry Andric                                                                 SI->getType(),
206*4ba319b5SDimitry Andric                                                                 false));
207*4ba319b5SDimitry Andric           } else {
208*4ba319b5SDimitry Andric             SI->replaceAllUsesWith(ZextGroupSize);
209*4ba319b5SDimitry Andric           }
210*4ba319b5SDimitry Andric 
211*4ba319b5SDimitry Andric           MadeChange = true;
212*4ba319b5SDimitry Andric         }
213*4ba319b5SDimitry Andric       }
214*4ba319b5SDimitry Andric     }
215*4ba319b5SDimitry Andric   }
216*4ba319b5SDimitry Andric 
217*4ba319b5SDimitry Andric   if (!HasReqdWorkGroupSize)
218*4ba319b5SDimitry Andric     return MadeChange;
219*4ba319b5SDimitry Andric 
220*4ba319b5SDimitry Andric   // Eliminate any other loads we can from the dispatch packet.
221*4ba319b5SDimitry Andric   for (int I = 0; I < 3; ++I) {
222*4ba319b5SDimitry Andric     Value *GroupSize = WorkGroupSizes[I];
223*4ba319b5SDimitry Andric     if (!GroupSize)
224*4ba319b5SDimitry Andric       continue;
225*4ba319b5SDimitry Andric 
226*4ba319b5SDimitry Andric     ConstantInt *KnownSize = mdconst::extract<ConstantInt>(MD->getOperand(I));
227*4ba319b5SDimitry Andric     GroupSize->replaceAllUsesWith(
228*4ba319b5SDimitry Andric       ConstantExpr::getIntegerCast(KnownSize,
229*4ba319b5SDimitry Andric                                    GroupSize->getType(),
230*4ba319b5SDimitry Andric                                    false));
231*4ba319b5SDimitry Andric     MadeChange = true;
232*4ba319b5SDimitry Andric   }
233*4ba319b5SDimitry Andric 
234*4ba319b5SDimitry Andric   return MadeChange;
235*4ba319b5SDimitry Andric }
236*4ba319b5SDimitry Andric 
237*4ba319b5SDimitry Andric // TODO: Move makeLIDRangeMetadata usage into here. Seem to not get
238*4ba319b5SDimitry Andric // TargetPassConfig for subtarget.
runOnModule(Module & M)239*4ba319b5SDimitry Andric bool AMDGPULowerKernelAttributes::runOnModule(Module &M) {
240*4ba319b5SDimitry Andric   StringRef DispatchPtrName
241*4ba319b5SDimitry Andric     = Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr);
242*4ba319b5SDimitry Andric 
243*4ba319b5SDimitry Andric   Function *DispatchPtr = Mod->getFunction(DispatchPtrName);
244*4ba319b5SDimitry Andric   if (!DispatchPtr) // Dispatch ptr not used.
245*4ba319b5SDimitry Andric     return false;
246*4ba319b5SDimitry Andric 
247*4ba319b5SDimitry Andric   bool MadeChange = false;
248*4ba319b5SDimitry Andric 
249*4ba319b5SDimitry Andric   SmallPtrSet<Instruction *, 4> HandledUses;
250*4ba319b5SDimitry Andric   for (auto *U : DispatchPtr->users()) {
251*4ba319b5SDimitry Andric     CallInst *CI = cast<CallInst>(U);
252*4ba319b5SDimitry Andric     if (HandledUses.insert(CI).second) {
253*4ba319b5SDimitry Andric       if (processUse(CI))
254*4ba319b5SDimitry Andric         MadeChange = true;
255*4ba319b5SDimitry Andric     }
256*4ba319b5SDimitry Andric   }
257*4ba319b5SDimitry Andric 
258*4ba319b5SDimitry Andric   return MadeChange;
259*4ba319b5SDimitry Andric }
260*4ba319b5SDimitry Andric 
261*4ba319b5SDimitry Andric INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes, DEBUG_TYPE,
262*4ba319b5SDimitry Andric                       "AMDGPU IR optimizations", false, false)
263*4ba319b5SDimitry Andric INITIALIZE_PASS_END(AMDGPULowerKernelAttributes, DEBUG_TYPE, "AMDGPU IR optimizations",
264*4ba319b5SDimitry Andric                     false, false)
265*4ba319b5SDimitry Andric 
266*4ba319b5SDimitry Andric char AMDGPULowerKernelAttributes::ID = 0;
267*4ba319b5SDimitry Andric 
createAMDGPULowerKernelAttributesPass()268*4ba319b5SDimitry Andric ModulePass *llvm::createAMDGPULowerKernelAttributesPass() {
269*4ba319b5SDimitry Andric   return new AMDGPULowerKernelAttributes();
270*4ba319b5SDimitry Andric }
271