//===- AMDGPULowerKernelAttributes.cpp ------------------------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
9*372d796aSMatt Arsenault //
/// \file This pass attempts to make use of reqd_work_group_size metadata
/// to eliminate loads from the dispatch packet and to constant fold OpenCL
/// get_local_size-like functions.
13*372d796aSMatt Arsenault //
14*372d796aSMatt Arsenault //===----------------------------------------------------------------------===//
15*372d796aSMatt Arsenault 
16*372d796aSMatt Arsenault #include "AMDGPU.h"
17*372d796aSMatt Arsenault #include "AMDGPUTargetMachine.h"
18*372d796aSMatt Arsenault #include "llvm/Analysis/ValueTracking.h"
19*372d796aSMatt Arsenault #include "llvm/CodeGen/Passes.h"
20*372d796aSMatt Arsenault #include "llvm/CodeGen/TargetPassConfig.h"
21*372d796aSMatt Arsenault #include "llvm/IR/Constants.h"
22*372d796aSMatt Arsenault #include "llvm/IR/Function.h"
23*372d796aSMatt Arsenault #include "llvm/IR/Instructions.h"
24*372d796aSMatt Arsenault #include "llvm/IR/PatternMatch.h"
25*372d796aSMatt Arsenault #include "llvm/Pass.h"
26*372d796aSMatt Arsenault 
27*372d796aSMatt Arsenault #define DEBUG_TYPE "amdgpu-lower-kernel-attributes"
28*372d796aSMatt Arsenault 
29*372d796aSMatt Arsenault using namespace llvm;
30*372d796aSMatt Arsenault 
31*372d796aSMatt Arsenault namespace {
32*372d796aSMatt Arsenault 
// Byte offsets of fields in hsa_kernel_dispatch_packet_t.
//
// The workgroup-size fields are 16 bits wide and the grid-size fields are 32
// bits wide, which is why processUse only accepts 2-byte loads at the
// WORKGROUP_SIZE_* offsets and 4-byte loads at the GRID_SIZE_* offsets.
// (Renamed from "DispatchPackedOffsets": these index the dispatch *packet*.)
enum DispatchPacketOffsets {
  WORKGROUP_SIZE_X = 4,
  WORKGROUP_SIZE_Y = 6,
  WORKGROUP_SIZE_Z = 8,

  GRID_SIZE_X = 12,
  GRID_SIZE_Y = 16,
  GRID_SIZE_Z = 20
};
43*372d796aSMatt Arsenault 
44*372d796aSMatt Arsenault class AMDGPULowerKernelAttributes : public ModulePass {
45*372d796aSMatt Arsenault   Module *Mod = nullptr;
46*372d796aSMatt Arsenault 
47*372d796aSMatt Arsenault public:
48*372d796aSMatt Arsenault   static char ID;
49*372d796aSMatt Arsenault 
50*372d796aSMatt Arsenault   AMDGPULowerKernelAttributes() : ModulePass(ID) {}
51*372d796aSMatt Arsenault 
52*372d796aSMatt Arsenault   bool processUse(CallInst *CI);
53*372d796aSMatt Arsenault 
54*372d796aSMatt Arsenault   bool doInitialization(Module &M) override;
55*372d796aSMatt Arsenault   bool runOnModule(Module &M) override;
56*372d796aSMatt Arsenault 
57*372d796aSMatt Arsenault   StringRef getPassName() const override {
58*372d796aSMatt Arsenault     return "AMDGPU Kernel Attributes";
59*372d796aSMatt Arsenault   }
60*372d796aSMatt Arsenault 
61*372d796aSMatt Arsenault   void getAnalysisUsage(AnalysisUsage &AU) const override {
62*372d796aSMatt Arsenault     AU.setPreservesAll();
63*372d796aSMatt Arsenault  }
64*372d796aSMatt Arsenault };
65*372d796aSMatt Arsenault 
66*372d796aSMatt Arsenault } // end anonymous namespace
67*372d796aSMatt Arsenault 
// Cache the module pointer so processUse can later reach its DataLayout.
bool AMDGPULowerKernelAttributes::doInitialization(Module &M) {
  Mod = &M;
  return false; // Initialization never modifies the IR.
}
72*372d796aSMatt Arsenault 
// Fold loads from the dispatch packet reached through a single call to
// llvm.amdgcn.dispatch.ptr (CI), using the enclosing kernel's
// reqd_work_group_size metadata and/or uniform-work-group-size attribute.
// Returns true if any IR was changed.
bool AMDGPULowerKernelAttributes::processUse(CallInst *CI) {
  Function *F = CI->getParent()->getParent();

  // reqd_work_group_size supplies exact per-dimension group sizes; it is only
  // usable when all three dimension operands are present.
  auto MD = F->getMetadata("reqd_work_group_size");
  const bool HasReqdWorkGroupSize = MD && MD->getNumOperands() == 3;

  const bool HasUniformWorkGroupSize =
    F->getFnAttribute("uniform-work-group-size").getValueAsString() == "true";

  // Without either guarantee there is nothing we can fold.
  if (!HasReqdWorkGroupSize && !HasUniformWorkGroupSize)
    return false;

  // Loads of the corresponding dispatch packet fields, filled in below when
  // the matching load is found.
  Value *WorkGroupSizeX = nullptr;
  Value *WorkGroupSizeY = nullptr;
  Value *WorkGroupSizeZ = nullptr;

  Value *GridSizeX = nullptr;
  Value *GridSizeY = nullptr;
  Value *GridSizeZ = nullptr;

  const DataLayout &DL = Mod->getDataLayout();

  // We expect to see several GEP users, casted to the appropriate type and
  // loaded.
  for (User *U : CI->users()) {
    if (!U->hasOneUse())
      continue;

    // Identify which packet field this GEP addresses via its constant offset
    // from the dispatch pointer.
    int64_t Offset = 0;
    if (GetPointerBaseWithConstantOffset(U, Offset, DL) != CI)
      continue;

    auto *BCI = dyn_cast<BitCastInst>(*U->user_begin());
    if (!BCI || !BCI->hasOneUse())
      continue;

    // Only simple (non-volatile, non-atomic) loads are safe to replace.
    auto *Load = dyn_cast<LoadInst>(*BCI->user_begin());
    if (!Load || !Load->isSimple())
      continue;

    unsigned LoadSize = DL.getTypeStoreSize(Load->getType());

    // TODO: Handle merged loads.
    // The load width must match the field width exactly: 2 bytes for the
    // workgroup size fields, 4 bytes for the grid size fields.
    switch (Offset) {
    case WORKGROUP_SIZE_X:
      if (LoadSize == 2)
        WorkGroupSizeX = Load;
      break;
    case WORKGROUP_SIZE_Y:
      if (LoadSize == 2)
        WorkGroupSizeY = Load;
      break;
    case WORKGROUP_SIZE_Z:
      if (LoadSize == 2)
        WorkGroupSizeZ = Load;
      break;
    case GRID_SIZE_X:
      if (LoadSize == 4)
        GridSizeX = Load;
      break;
    case GRID_SIZE_Y:
      if (LoadSize == 4)
        GridSizeY = Load;
      break;
    case GRID_SIZE_Z:
      if (LoadSize == 4)
        GridSizeZ = Load;
      break;
    default:
      break;
    }
  }

  // Pattern match the code used to handle partial workgroup dispatches in the
  // library implementation of get_local_size, so the entire function can be
  // constant folded with a known group size.
  //
  // uint r = grid_size - group_id * group_size;
  // get_local_size = (r < group_size) ? r : group_size;
  //
  // If we have uniform-work-group-size (which is the default in OpenCL 1.2),
  // the grid_size is required to be a multiple of group_size). In this case:
  //
  // grid_size - (group_id * group_size) < group_size
  // ->
  // grid_size < group_size + (group_id * group_size)
  //
  // (grid_size / group_size) < 1 + group_id
  //
  // grid_size / group_size is at least 1, so we can conclude the select
  // condition is false (except for group_id == 0, where the select result is
  // the same).

  bool MadeChange = false;
  Value *WorkGroupSizes[3] = { WorkGroupSizeX, WorkGroupSizeY, WorkGroupSizeZ };
  Value *GridSizes[3] = { GridSizeX, GridSizeY, GridSizeZ };

  // This select simplification is only valid under uniform-work-group-size;
  // the condition guards the whole loop.
  for (int I = 0; HasUniformWorkGroupSize && I < 3; ++I) {
    Value *GroupSize = WorkGroupSizes[I];
    Value *GridSize = GridSizes[I];
    if (!GroupSize || !GridSize)
      continue;

    // The 16-bit group size load is zero-extended before entering the
    // min-style select computation.
    for (User *U : GroupSize->users()) {
      auto *ZextGroupSize = dyn_cast<ZExtInst>(U);
      if (!ZextGroupSize)
        continue;

      for (User *ZextUser : ZextGroupSize->users()) {
        auto *SI = dyn_cast<SelectInst>(ZextUser);
        if (!SI)
          continue;

        using namespace llvm::PatternMatch;
        // The workgroup id intrinsic must match the dimension being folded.
        auto GroupIDIntrin = I == 0 ?
          m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>() :
            (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>() :
                      m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());

        // r = grid_size - group_id * group_size
        auto SubExpr = m_Sub(m_Specific(GridSize),
                             m_Mul(GroupIDIntrin, m_Specific(ZextGroupSize)));

        ICmpInst::Predicate Pred;
        // Match: (r u< group_size) ? r : group_size.
        if (match(SI,
                  m_Select(m_ICmp(Pred, SubExpr, m_Specific(ZextGroupSize)),
                           SubExpr,
                           m_Specific(ZextGroupSize))) &&
            Pred == ICmpInst::ICMP_ULT) {
          if (HasReqdWorkGroupSize) {
            // The exact size is known; fold the select to that constant,
            // widened/truncated to the select's type.
            ConstantInt *KnownSize
              = mdconst::extract<ConstantInt>(MD->getOperand(I));
            SI->replaceAllUsesWith(ConstantExpr::getIntegerCast(KnownSize,
                                                                SI->getType(),
                                                                false));
          } else {
            // Uniform dispatch without a known size: the select always yields
            // group_size (see derivation above).
            SI->replaceAllUsesWith(ZextGroupSize);
          }

          MadeChange = true;
        }
      }
    }
  }

  if (!HasReqdWorkGroupSize)
    return MadeChange;

  // Eliminate any other loads we can from the dispatch packet.
  for (int I = 0; I < 3; ++I) {
    Value *GroupSize = WorkGroupSizes[I];
    if (!GroupSize)
      continue;

    // Replace each surviving group-size load with the metadata constant,
    // cast to the load's type.
    ConstantInt *KnownSize = mdconst::extract<ConstantInt>(MD->getOperand(I));
    GroupSize->replaceAllUsesWith(
      ConstantExpr::getIntegerCast(KnownSize,
                                   GroupSize->getType(),
                                   false));
    MadeChange = true;
  }

  return MadeChange;
}
236*372d796aSMatt Arsenault 
237*372d796aSMatt Arsenault // TODO: Move makeLIDRangeMetadata usage into here. Seem to not get
238*372d796aSMatt Arsenault // TargetPassConfig for subtarget.
239*372d796aSMatt Arsenault bool AMDGPULowerKernelAttributes::runOnModule(Module &M) {
240*372d796aSMatt Arsenault   StringRef DispatchPtrName
241*372d796aSMatt Arsenault     = Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr);
242*372d796aSMatt Arsenault 
243*372d796aSMatt Arsenault   Function *DispatchPtr = Mod->getFunction(DispatchPtrName);
244*372d796aSMatt Arsenault   if (!DispatchPtr) // Dispatch ptr not used.
245*372d796aSMatt Arsenault     return false;
246*372d796aSMatt Arsenault 
247*372d796aSMatt Arsenault   bool MadeChange = false;
248*372d796aSMatt Arsenault 
249*372d796aSMatt Arsenault   SmallPtrSet<Instruction *, 4> HandledUses;
250*372d796aSMatt Arsenault   for (auto *U : DispatchPtr->users()) {
251*372d796aSMatt Arsenault     CallInst *CI = cast<CallInst>(U);
252*372d796aSMatt Arsenault     if (HandledUses.insert(CI).second) {
253*372d796aSMatt Arsenault       if (processUse(CI))
254*372d796aSMatt Arsenault         MadeChange = true;
255*372d796aSMatt Arsenault     }
256*372d796aSMatt Arsenault   }
257*372d796aSMatt Arsenault 
258*372d796aSMatt Arsenault   return MadeChange;
259*372d796aSMatt Arsenault }
260*372d796aSMatt Arsenault 
261*372d796aSMatt Arsenault INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes, DEBUG_TYPE,
262*372d796aSMatt Arsenault                       "AMDGPU IR optimizations", false, false)
263*372d796aSMatt Arsenault INITIALIZE_PASS_END(AMDGPULowerKernelAttributes, DEBUG_TYPE, "AMDGPU IR optimizations",
264*372d796aSMatt Arsenault                     false, false)
265*372d796aSMatt Arsenault 
266*372d796aSMatt Arsenault char AMDGPULowerKernelAttributes::ID = 0;
267*372d796aSMatt Arsenault 
268*372d796aSMatt Arsenault ModulePass *llvm::createAMDGPULowerKernelAttributesPass() {
269*372d796aSMatt Arsenault   return new AMDGPULowerKernelAttributes();
270*372d796aSMatt Arsenault }
271