//===-- AMDGPULowerKernelAttributes.cpp -----------------------------------===//
2372d796aSMatt Arsenault //
32946cd70SChandler Carruth // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42946cd70SChandler Carruth // See https://llvm.org/LICENSE.txt for license information.
52946cd70SChandler Carruth // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6372d796aSMatt Arsenault //
7372d796aSMatt Arsenault //===----------------------------------------------------------------------===//
8372d796aSMatt Arsenault //
/// \file This pass attempts to make use of reqd_work_group_size metadata
10372d796aSMatt Arsenault /// to eliminate loads from the dispatch packet and to constant fold OpenCL
11372d796aSMatt Arsenault /// get_local_size-like functions.
12372d796aSMatt Arsenault //
13372d796aSMatt Arsenault //===----------------------------------------------------------------------===//
14372d796aSMatt Arsenault 
15372d796aSMatt Arsenault #include "AMDGPU.h"
16372d796aSMatt Arsenault #include "AMDGPUTargetMachine.h"
17372d796aSMatt Arsenault #include "llvm/Analysis/ValueTracking.h"
18372d796aSMatt Arsenault #include "llvm/CodeGen/Passes.h"
19372d796aSMatt Arsenault #include "llvm/CodeGen/TargetPassConfig.h"
20372d796aSMatt Arsenault #include "llvm/IR/Constants.h"
21372d796aSMatt Arsenault #include "llvm/IR/Function.h"
22*7ecbe0c7SArthur Eubanks #include "llvm/IR/InstIterator.h"
23372d796aSMatt Arsenault #include "llvm/IR/Instructions.h"
24*7ecbe0c7SArthur Eubanks #include "llvm/IR/PassManager.h"
25372d796aSMatt Arsenault #include "llvm/IR/PatternMatch.h"
26372d796aSMatt Arsenault #include "llvm/Pass.h"
27372d796aSMatt Arsenault 
28372d796aSMatt Arsenault #define DEBUG_TYPE "amdgpu-lower-kernel-attributes"
29372d796aSMatt Arsenault 
30372d796aSMatt Arsenault using namespace llvm;
31372d796aSMatt Arsenault 
32372d796aSMatt Arsenault namespace {
33372d796aSMatt Arsenault 
// Field offsets in hsa_kernel_dispatch_packet_t.
enum DispatchPackedOffsets {
  // Workgroup sizes are 16-bit packet fields; the matching loads below are
  // only accepted when they read exactly 2 bytes.
  WORKGROUP_SIZE_X = 4,
  WORKGROUP_SIZE_Y = 6,
  WORKGROUP_SIZE_Z = 8,

  // Grid sizes are 32-bit packet fields; the matching loads below are only
  // accepted when they read exactly 4 bytes.
  GRID_SIZE_X = 12,
  GRID_SIZE_Y = 16,
  GRID_SIZE_Z = 20
};
44372d796aSMatt Arsenault 
// Legacy pass manager wrapper. All of the folding logic lives in
// processUse(); this pass only locates calls to llvm.amdgcn.dispatch.ptr and
// feeds them to it.
class AMDGPULowerKernelAttributes : public ModulePass {
public:
  static char ID; // Pass identification, replacement for typeid.

  AMDGPULowerKernelAttributes() : ModulePass(ID) {}

  bool runOnModule(Module &M) override;

  StringRef getPassName() const override {
    return "AMDGPU Kernel Attributes";
  }

  // NOTE(review): the pass does replace values (see processUse), so
  // setPreservesAll() over-claims for value-based analyses; it mirrors the
  // pass's historical behavior — confirm before tightening.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesAll();
 }
};
61372d796aSMatt Arsenault 
62372d796aSMatt Arsenault } // end anonymous namespace
63372d796aSMatt Arsenault 
// Fold uses of one call to llvm.amdgcn.dispatch.ptr based on the enclosing
// kernel's guarantees:
//  * With reqd_work_group_size metadata, loads of the workgroup-size packet
//    fields are replaced with the known constants.
//  * With uniform-work-group-size (or the metadata), the partial-workgroup
//    select emitted by the library get_local_size is folded away.
// Returns true if any replacement was made.
static bool processUse(CallInst *CI) {
  Function *F = CI->getParent()->getParent();

  // reqd_work_group_size carries the exact X/Y/Z group size as 3 constants.
  auto MD = F->getMetadata("reqd_work_group_size");
  const bool HasReqdWorkGroupSize = MD && MD->getNumOperands() == 3;

  // uniform-work-group-size promises the grid size is a multiple of the
  // group size, which is enough to fold the select even without constants.
  const bool HasUniformWorkGroupSize =
    F->getFnAttribute("uniform-work-group-size").getValueAsString() == "true";

  // Nothing can be folded without at least one of the two guarantees.
  if (!HasReqdWorkGroupSize && !HasUniformWorkGroupSize)
    return false;

  // Loads of the individual dispatch packet fields, filled in below if found.
  Value *WorkGroupSizeX = nullptr;
  Value *WorkGroupSizeY = nullptr;
  Value *WorkGroupSizeZ = nullptr;

  Value *GridSizeX = nullptr;
  Value *GridSizeY = nullptr;
  Value *GridSizeZ = nullptr;

  const DataLayout &DL = F->getParent()->getDataLayout();

  // We expect to see several GEP users, casted to the appropriate type and
  // loaded.
  for (User *U : CI->users()) {
    if (!U->hasOneUse())
      continue;

    // The access must be at a constant offset from the packet base so it can
    // be identified as one of the known fields.
    int64_t Offset = 0;
    if (GetPointerBaseWithConstantOffset(U, Offset, DL) != CI)
      continue;

    auto *BCI = dyn_cast<BitCastInst>(*U->user_begin());
    if (!BCI || !BCI->hasOneUse())
      continue;

    // Only simple (non-volatile, non-atomic) loads are safe to replace.
    auto *Load = dyn_cast<LoadInst>(*BCI->user_begin());
    if (!Load || !Load->isSimple())
      continue;

    unsigned LoadSize = DL.getTypeStoreSize(Load->getType());

    // Record the load only when it reads the field's exact width: workgroup
    // sizes are 2-byte fields, grid sizes 4-byte fields.
    // TODO: Handle merged loads.
    switch (Offset) {
    case WORKGROUP_SIZE_X:
      if (LoadSize == 2)
        WorkGroupSizeX = Load;
      break;
    case WORKGROUP_SIZE_Y:
      if (LoadSize == 2)
        WorkGroupSizeY = Load;
      break;
    case WORKGROUP_SIZE_Z:
      if (LoadSize == 2)
        WorkGroupSizeZ = Load;
      break;
    case GRID_SIZE_X:
      if (LoadSize == 4)
        GridSizeX = Load;
      break;
    case GRID_SIZE_Y:
      if (LoadSize == 4)
        GridSizeY = Load;
      break;
    case GRID_SIZE_Z:
      if (LoadSize == 4)
        GridSizeZ = Load;
      break;
    default:
      break;
    }
  }

  // Pattern match the code used to handle partial workgroup dispatches in the
  // library implementation of get_local_size, so the entire function can be
  // constant folded with a known group size.
  //
  // uint r = grid_size - group_id * group_size;
  // get_local_size = (r < group_size) ? r : group_size;
  //
  // If we have uniform-work-group-size (which is the default in OpenCL 1.2),
  // the grid_size is required to be a multiple of group_size). In this case:
  //
  // grid_size - (group_id * group_size) < group_size
  // ->
  // grid_size < group_size + (group_id * group_size)
  //
  // (grid_size / group_size) < 1 + group_id
  //
  // grid_size / group_size is at least 1, so we can conclude the select
  // condition is false (except for group_id == 0, where the select result is
  // the same).

  bool MadeChange = false;
  Value *WorkGroupSizes[3] = { WorkGroupSizeX, WorkGroupSizeY, WorkGroupSizeZ };
  Value *GridSizes[3] = { GridSizeX, GridSizeY, GridSizeZ };

  for (int I = 0; HasUniformWorkGroupSize && I < 3; ++I) {
    Value *GroupSize = WorkGroupSizes[I];
    Value *GridSize = GridSizes[I];
    if (!GroupSize || !GridSize)
      continue;

    // The 16-bit group size is zero-extended before being compared/selected.
    for (User *U : GroupSize->users()) {
      auto *ZextGroupSize = dyn_cast<ZExtInst>(U);
      if (!ZextGroupSize)
        continue;

      for (User *ZextUser : ZextGroupSize->users()) {
        auto *SI = dyn_cast<SelectInst>(ZextUser);
        if (!SI)
          continue;

        using namespace llvm::PatternMatch;
        // Match the workgroup-id intrinsic for the dimension being folded.
        auto GroupIDIntrin = I == 0 ?
          m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>() :
            (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>() :
                      m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());

        // r = grid_size - group_id * group_size
        auto SubExpr = m_Sub(m_Specific(GridSize),
                             m_Mul(GroupIDIntrin, m_Specific(ZextGroupSize)));

        // select (icmp ult r, group_size), r, group_size  -->  group_size
        ICmpInst::Predicate Pred;
        if (match(SI,
                  m_Select(m_ICmp(Pred, SubExpr, m_Specific(ZextGroupSize)),
                           SubExpr,
                           m_Specific(ZextGroupSize))) &&
            Pred == ICmpInst::ICMP_ULT) {
          if (HasReqdWorkGroupSize) {
            // Fold straight to the known constant, zero-extended (isSigned ==
            // false) to the select's type.
            ConstantInt *KnownSize
              = mdconst::extract<ConstantInt>(MD->getOperand(I));
            SI->replaceAllUsesWith(ConstantExpr::getIntegerCast(KnownSize,
                                                                SI->getType(),
                                                                false));
          } else {
            SI->replaceAllUsesWith(ZextGroupSize);
          }

          MadeChange = true;
        }
      }
    }
  }

  // Without exact sizes from metadata there is nothing more to fold.
  if (!HasReqdWorkGroupSize)
    return MadeChange;

  // Eliminate any other loads we can from the dispatch packet.
  for (int I = 0; I < 3; ++I) {
    Value *GroupSize = WorkGroupSizes[I];
    if (!GroupSize)
      continue;

    ConstantInt *KnownSize = mdconst::extract<ConstantInt>(MD->getOperand(I));
    GroupSize->replaceAllUsesWith(
      ConstantExpr::getIntegerCast(KnownSize,
                                   GroupSize->getType(),
                                   false));
    MadeChange = true;
  }

  return MadeChange;
}
227372d796aSMatt Arsenault 
228372d796aSMatt Arsenault // TODO: Move makeLIDRangeMetadata usage into here. Seem to not get
229372d796aSMatt Arsenault // TargetPassConfig for subtarget.
230372d796aSMatt Arsenault bool AMDGPULowerKernelAttributes::runOnModule(Module &M) {
231372d796aSMatt Arsenault   StringRef DispatchPtrName
232372d796aSMatt Arsenault     = Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr);
233372d796aSMatt Arsenault 
234*7ecbe0c7SArthur Eubanks   Function *DispatchPtr = M.getFunction(DispatchPtrName);
235372d796aSMatt Arsenault   if (!DispatchPtr) // Dispatch ptr not used.
236372d796aSMatt Arsenault     return false;
237372d796aSMatt Arsenault 
238372d796aSMatt Arsenault   bool MadeChange = false;
239372d796aSMatt Arsenault 
240372d796aSMatt Arsenault   SmallPtrSet<Instruction *, 4> HandledUses;
241372d796aSMatt Arsenault   for (auto *U : DispatchPtr->users()) {
242372d796aSMatt Arsenault     CallInst *CI = cast<CallInst>(U);
243372d796aSMatt Arsenault     if (HandledUses.insert(CI).second) {
244372d796aSMatt Arsenault       if (processUse(CI))
245372d796aSMatt Arsenault         MadeChange = true;
246372d796aSMatt Arsenault     }
247372d796aSMatt Arsenault   }
248372d796aSMatt Arsenault 
249372d796aSMatt Arsenault   return MadeChange;
250372d796aSMatt Arsenault }
251372d796aSMatt Arsenault 
252372d796aSMatt Arsenault INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes, DEBUG_TYPE,
253372d796aSMatt Arsenault                       "AMDGPU IR optimizations", false, false)
254372d796aSMatt Arsenault INITIALIZE_PASS_END(AMDGPULowerKernelAttributes, DEBUG_TYPE, "AMDGPU IR optimizations",
255372d796aSMatt Arsenault                     false, false)
256372d796aSMatt Arsenault 
257372d796aSMatt Arsenault char AMDGPULowerKernelAttributes::ID = 0;
258372d796aSMatt Arsenault 
259372d796aSMatt Arsenault ModulePass *llvm::createAMDGPULowerKernelAttributesPass() {
260372d796aSMatt Arsenault   return new AMDGPULowerKernelAttributes();
261372d796aSMatt Arsenault }
262*7ecbe0c7SArthur Eubanks 
263*7ecbe0c7SArthur Eubanks PreservedAnalyses
264*7ecbe0c7SArthur Eubanks AMDGPULowerKernelAttributesPass::run(Function &F, FunctionAnalysisManager &AM) {
265*7ecbe0c7SArthur Eubanks   StringRef DispatchPtrName =
266*7ecbe0c7SArthur Eubanks       Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr);
267*7ecbe0c7SArthur Eubanks 
268*7ecbe0c7SArthur Eubanks   Function *DispatchPtr = F.getParent()->getFunction(DispatchPtrName);
269*7ecbe0c7SArthur Eubanks   if (!DispatchPtr) // Dispatch ptr not used.
270*7ecbe0c7SArthur Eubanks     return PreservedAnalyses::all();
271*7ecbe0c7SArthur Eubanks 
272*7ecbe0c7SArthur Eubanks   for (Instruction &I : instructions(F)) {
273*7ecbe0c7SArthur Eubanks     if (CallInst *CI = dyn_cast<CallInst>(&I)) {
274*7ecbe0c7SArthur Eubanks       if (CI->getCalledFunction() == DispatchPtr)
275*7ecbe0c7SArthur Eubanks         processUse(CI);
276*7ecbe0c7SArthur Eubanks     }
277*7ecbe0c7SArthur Eubanks   }
278*7ecbe0c7SArthur Eubanks 
279*7ecbe0c7SArthur Eubanks   return PreservedAnalyses::all();
280*7ecbe0c7SArthur Eubanks }
281