//===-- AMDGPULowerKernelAttributes.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file This pass attempts to make use of reqd_work_group_size metadata
/// to eliminate loads from the dispatch packet and to constant fold OpenCL
/// get_local_size-like functions.
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"

27372d796aSMatt Arsenault #define DEBUG_TYPE "amdgpu-lower-kernel-attributes"
28372d796aSMatt Arsenault
29372d796aSMatt Arsenault using namespace llvm;
30372d796aSMatt Arsenault
31372d796aSMatt Arsenault namespace {
32372d796aSMatt Arsenault
33372d796aSMatt Arsenault // Field offsets in hsa_kernel_dispatch_packet_t.
34372d796aSMatt Arsenault enum DispatchPackedOffsets {
35372d796aSMatt Arsenault WORKGROUP_SIZE_X = 4,
36372d796aSMatt Arsenault WORKGROUP_SIZE_Y = 6,
37372d796aSMatt Arsenault WORKGROUP_SIZE_Z = 8,
38372d796aSMatt Arsenault
39372d796aSMatt Arsenault GRID_SIZE_X = 12,
40372d796aSMatt Arsenault GRID_SIZE_Y = 16,
41372d796aSMatt Arsenault GRID_SIZE_Z = 20
42372d796aSMatt Arsenault };
43372d796aSMatt Arsenault
44372d796aSMatt Arsenault class AMDGPULowerKernelAttributes : public ModulePass {
45372d796aSMatt Arsenault public:
46372d796aSMatt Arsenault static char ID;
47372d796aSMatt Arsenault
AMDGPULowerKernelAttributes()48372d796aSMatt Arsenault AMDGPULowerKernelAttributes() : ModulePass(ID) {}
49372d796aSMatt Arsenault
50372d796aSMatt Arsenault bool runOnModule(Module &M) override;
51372d796aSMatt Arsenault
getPassName() const52372d796aSMatt Arsenault StringRef getPassName() const override {
53372d796aSMatt Arsenault return "AMDGPU Kernel Attributes";
54372d796aSMatt Arsenault }
55372d796aSMatt Arsenault
getAnalysisUsage(AnalysisUsage & AU) const56372d796aSMatt Arsenault void getAnalysisUsage(AnalysisUsage &AU) const override {
57372d796aSMatt Arsenault AU.setPreservesAll();
58372d796aSMatt Arsenault }
59372d796aSMatt Arsenault };
60372d796aSMatt Arsenault
61372d796aSMatt Arsenault } // end anonymous namespace
62372d796aSMatt Arsenault
processUse(CallInst * CI)637ecbe0c7SArthur Eubanks static bool processUse(CallInst *CI) {
64372d796aSMatt Arsenault Function *F = CI->getParent()->getParent();
65372d796aSMatt Arsenault
66372d796aSMatt Arsenault auto MD = F->getMetadata("reqd_work_group_size");
67372d796aSMatt Arsenault const bool HasReqdWorkGroupSize = MD && MD->getNumOperands() == 3;
68372d796aSMatt Arsenault
69372d796aSMatt Arsenault const bool HasUniformWorkGroupSize =
70d6de1e1aSSerge Guelton F->getFnAttribute("uniform-work-group-size").getValueAsBool();
71372d796aSMatt Arsenault
72372d796aSMatt Arsenault if (!HasReqdWorkGroupSize && !HasUniformWorkGroupSize)
73372d796aSMatt Arsenault return false;
74372d796aSMatt Arsenault
75372d796aSMatt Arsenault Value *WorkGroupSizeX = nullptr;
76372d796aSMatt Arsenault Value *WorkGroupSizeY = nullptr;
77372d796aSMatt Arsenault Value *WorkGroupSizeZ = nullptr;
78372d796aSMatt Arsenault
79372d796aSMatt Arsenault Value *GridSizeX = nullptr;
80372d796aSMatt Arsenault Value *GridSizeY = nullptr;
81372d796aSMatt Arsenault Value *GridSizeZ = nullptr;
82372d796aSMatt Arsenault
837ecbe0c7SArthur Eubanks const DataLayout &DL = F->getParent()->getDataLayout();
84372d796aSMatt Arsenault
85372d796aSMatt Arsenault // We expect to see several GEP users, casted to the appropriate type and
86372d796aSMatt Arsenault // loaded.
87372d796aSMatt Arsenault for (User *U : CI->users()) {
88372d796aSMatt Arsenault if (!U->hasOneUse())
89372d796aSMatt Arsenault continue;
90372d796aSMatt Arsenault
91372d796aSMatt Arsenault int64_t Offset = 0;
92372d796aSMatt Arsenault if (GetPointerBaseWithConstantOffset(U, Offset, DL) != CI)
93372d796aSMatt Arsenault continue;
94372d796aSMatt Arsenault
95372d796aSMatt Arsenault auto *BCI = dyn_cast<BitCastInst>(*U->user_begin());
96372d796aSMatt Arsenault if (!BCI || !BCI->hasOneUse())
97372d796aSMatt Arsenault continue;
98372d796aSMatt Arsenault
99372d796aSMatt Arsenault auto *Load = dyn_cast<LoadInst>(*BCI->user_begin());
100372d796aSMatt Arsenault if (!Load || !Load->isSimple())
101372d796aSMatt Arsenault continue;
102372d796aSMatt Arsenault
103372d796aSMatt Arsenault unsigned LoadSize = DL.getTypeStoreSize(Load->getType());
104372d796aSMatt Arsenault
105372d796aSMatt Arsenault // TODO: Handle merged loads.
106372d796aSMatt Arsenault switch (Offset) {
107372d796aSMatt Arsenault case WORKGROUP_SIZE_X:
108372d796aSMatt Arsenault if (LoadSize == 2)
109372d796aSMatt Arsenault WorkGroupSizeX = Load;
110372d796aSMatt Arsenault break;
111372d796aSMatt Arsenault case WORKGROUP_SIZE_Y:
112372d796aSMatt Arsenault if (LoadSize == 2)
113372d796aSMatt Arsenault WorkGroupSizeY = Load;
114372d796aSMatt Arsenault break;
115372d796aSMatt Arsenault case WORKGROUP_SIZE_Z:
116372d796aSMatt Arsenault if (LoadSize == 2)
117372d796aSMatt Arsenault WorkGroupSizeZ = Load;
118372d796aSMatt Arsenault break;
119372d796aSMatt Arsenault case GRID_SIZE_X:
120372d796aSMatt Arsenault if (LoadSize == 4)
121372d796aSMatt Arsenault GridSizeX = Load;
122372d796aSMatt Arsenault break;
123372d796aSMatt Arsenault case GRID_SIZE_Y:
124372d796aSMatt Arsenault if (LoadSize == 4)
125372d796aSMatt Arsenault GridSizeY = Load;
126372d796aSMatt Arsenault break;
127372d796aSMatt Arsenault case GRID_SIZE_Z:
128372d796aSMatt Arsenault if (LoadSize == 4)
129372d796aSMatt Arsenault GridSizeZ = Load;
130372d796aSMatt Arsenault break;
131372d796aSMatt Arsenault default:
132372d796aSMatt Arsenault break;
133372d796aSMatt Arsenault }
134372d796aSMatt Arsenault }
135372d796aSMatt Arsenault
136372d796aSMatt Arsenault // Pattern match the code used to handle partial workgroup dispatches in the
137372d796aSMatt Arsenault // library implementation of get_local_size, so the entire function can be
138372d796aSMatt Arsenault // constant folded with a known group size.
139372d796aSMatt Arsenault //
140372d796aSMatt Arsenault // uint r = grid_size - group_id * group_size;
141372d796aSMatt Arsenault // get_local_size = (r < group_size) ? r : group_size;
142372d796aSMatt Arsenault //
143372d796aSMatt Arsenault // If we have uniform-work-group-size (which is the default in OpenCL 1.2),
144372d796aSMatt Arsenault // the grid_size is required to be a multiple of group_size). In this case:
145372d796aSMatt Arsenault //
146372d796aSMatt Arsenault // grid_size - (group_id * group_size) < group_size
147372d796aSMatt Arsenault // ->
148372d796aSMatt Arsenault // grid_size < group_size + (group_id * group_size)
149372d796aSMatt Arsenault //
150372d796aSMatt Arsenault // (grid_size / group_size) < 1 + group_id
151372d796aSMatt Arsenault //
152372d796aSMatt Arsenault // grid_size / group_size is at least 1, so we can conclude the select
153372d796aSMatt Arsenault // condition is false (except for group_id == 0, where the select result is
154372d796aSMatt Arsenault // the same).
155372d796aSMatt Arsenault
156372d796aSMatt Arsenault bool MadeChange = false;
157372d796aSMatt Arsenault Value *WorkGroupSizes[3] = { WorkGroupSizeX, WorkGroupSizeY, WorkGroupSizeZ };
158372d796aSMatt Arsenault Value *GridSizes[3] = { GridSizeX, GridSizeY, GridSizeZ };
159372d796aSMatt Arsenault
160372d796aSMatt Arsenault for (int I = 0; HasUniformWorkGroupSize && I < 3; ++I) {
161372d796aSMatt Arsenault Value *GroupSize = WorkGroupSizes[I];
162372d796aSMatt Arsenault Value *GridSize = GridSizes[I];
163372d796aSMatt Arsenault if (!GroupSize || !GridSize)
164372d796aSMatt Arsenault continue;
165372d796aSMatt Arsenault
166*c986d476SMatt Arsenault using namespace llvm::PatternMatch;
167*c986d476SMatt Arsenault auto GroupIDIntrin =
168*c986d476SMatt Arsenault I == 0 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>()
169*c986d476SMatt Arsenault : (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>()
170*c986d476SMatt Arsenault : m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());
171*c986d476SMatt Arsenault
172372d796aSMatt Arsenault for (User *U : GroupSize->users()) {
173372d796aSMatt Arsenault auto *ZextGroupSize = dyn_cast<ZExtInst>(U);
174372d796aSMatt Arsenault if (!ZextGroupSize)
175372d796aSMatt Arsenault continue;
176372d796aSMatt Arsenault
177*c986d476SMatt Arsenault for (User *UMin : ZextGroupSize->users()) {
178*c986d476SMatt Arsenault if (match(UMin,
179*c986d476SMatt Arsenault m_UMin(m_Sub(m_Specific(GridSize),
180*c986d476SMatt Arsenault m_Mul(GroupIDIntrin, m_Specific(ZextGroupSize))),
181*c986d476SMatt Arsenault m_Specific(ZextGroupSize)))) {
182372d796aSMatt Arsenault if (HasReqdWorkGroupSize) {
183372d796aSMatt Arsenault ConstantInt *KnownSize
184372d796aSMatt Arsenault = mdconst::extract<ConstantInt>(MD->getOperand(I));
185*c986d476SMatt Arsenault UMin->replaceAllUsesWith(ConstantExpr::getIntegerCast(
186*c986d476SMatt Arsenault KnownSize, UMin->getType(), false));
187372d796aSMatt Arsenault } else {
188*c986d476SMatt Arsenault UMin->replaceAllUsesWith(ZextGroupSize);
189372d796aSMatt Arsenault }
190372d796aSMatt Arsenault
191372d796aSMatt Arsenault MadeChange = true;
192372d796aSMatt Arsenault }
193372d796aSMatt Arsenault }
194372d796aSMatt Arsenault }
195372d796aSMatt Arsenault }
196372d796aSMatt Arsenault
197372d796aSMatt Arsenault if (!HasReqdWorkGroupSize)
198372d796aSMatt Arsenault return MadeChange;
199372d796aSMatt Arsenault
200372d796aSMatt Arsenault // Eliminate any other loads we can from the dispatch packet.
201372d796aSMatt Arsenault for (int I = 0; I < 3; ++I) {
202372d796aSMatt Arsenault Value *GroupSize = WorkGroupSizes[I];
203372d796aSMatt Arsenault if (!GroupSize)
204372d796aSMatt Arsenault continue;
205372d796aSMatt Arsenault
206372d796aSMatt Arsenault ConstantInt *KnownSize = mdconst::extract<ConstantInt>(MD->getOperand(I));
207372d796aSMatt Arsenault GroupSize->replaceAllUsesWith(
208372d796aSMatt Arsenault ConstantExpr::getIntegerCast(KnownSize,
209372d796aSMatt Arsenault GroupSize->getType(),
210372d796aSMatt Arsenault false));
211372d796aSMatt Arsenault MadeChange = true;
212372d796aSMatt Arsenault }
213372d796aSMatt Arsenault
214372d796aSMatt Arsenault return MadeChange;
215372d796aSMatt Arsenault }
216372d796aSMatt Arsenault
217372d796aSMatt Arsenault // TODO: Move makeLIDRangeMetadata usage into here. Seem to not get
218372d796aSMatt Arsenault // TargetPassConfig for subtarget.
runOnModule(Module & M)219372d796aSMatt Arsenault bool AMDGPULowerKernelAttributes::runOnModule(Module &M) {
220372d796aSMatt Arsenault StringRef DispatchPtrName
221372d796aSMatt Arsenault = Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr);
222372d796aSMatt Arsenault
2237ecbe0c7SArthur Eubanks Function *DispatchPtr = M.getFunction(DispatchPtrName);
224372d796aSMatt Arsenault if (!DispatchPtr) // Dispatch ptr not used.
225372d796aSMatt Arsenault return false;
226372d796aSMatt Arsenault
227372d796aSMatt Arsenault bool MadeChange = false;
228372d796aSMatt Arsenault
229372d796aSMatt Arsenault SmallPtrSet<Instruction *, 4> HandledUses;
230372d796aSMatt Arsenault for (auto *U : DispatchPtr->users()) {
231372d796aSMatt Arsenault CallInst *CI = cast<CallInst>(U);
232372d796aSMatt Arsenault if (HandledUses.insert(CI).second) {
233372d796aSMatt Arsenault if (processUse(CI))
234372d796aSMatt Arsenault MadeChange = true;
235372d796aSMatt Arsenault }
236372d796aSMatt Arsenault }
237372d796aSMatt Arsenault
238372d796aSMatt Arsenault return MadeChange;
239372d796aSMatt Arsenault }
240372d796aSMatt Arsenault
241372d796aSMatt Arsenault INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes, DEBUG_TYPE,
242aff66b7eSStanislav Mekhanoshin "AMDGPU Kernel Attributes", false, false)
243aff66b7eSStanislav Mekhanoshin INITIALIZE_PASS_END(AMDGPULowerKernelAttributes, DEBUG_TYPE,
244aff66b7eSStanislav Mekhanoshin "AMDGPU Kernel Attributes", false, false)
245372d796aSMatt Arsenault
246372d796aSMatt Arsenault char AMDGPULowerKernelAttributes::ID = 0;
247372d796aSMatt Arsenault
createAMDGPULowerKernelAttributesPass()248372d796aSMatt Arsenault ModulePass *llvm::createAMDGPULowerKernelAttributesPass() {
249372d796aSMatt Arsenault return new AMDGPULowerKernelAttributes();
250372d796aSMatt Arsenault }
2517ecbe0c7SArthur Eubanks
2527ecbe0c7SArthur Eubanks PreservedAnalyses
run(Function & F,FunctionAnalysisManager & AM)2537ecbe0c7SArthur Eubanks AMDGPULowerKernelAttributesPass::run(Function &F, FunctionAnalysisManager &AM) {
2547ecbe0c7SArthur Eubanks StringRef DispatchPtrName =
2557ecbe0c7SArthur Eubanks Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr);
2567ecbe0c7SArthur Eubanks
2577ecbe0c7SArthur Eubanks Function *DispatchPtr = F.getParent()->getFunction(DispatchPtrName);
2587ecbe0c7SArthur Eubanks if (!DispatchPtr) // Dispatch ptr not used.
2597ecbe0c7SArthur Eubanks return PreservedAnalyses::all();
2607ecbe0c7SArthur Eubanks
2617ecbe0c7SArthur Eubanks for (Instruction &I : instructions(F)) {
2627ecbe0c7SArthur Eubanks if (CallInst *CI = dyn_cast<CallInst>(&I)) {
2637ecbe0c7SArthur Eubanks if (CI->getCalledFunction() == DispatchPtr)
2647ecbe0c7SArthur Eubanks processUse(CI);
2657ecbe0c7SArthur Eubanks }
2667ecbe0c7SArthur Eubanks }
2677ecbe0c7SArthur Eubanks
2687ecbe0c7SArthur Eubanks return PreservedAnalyses::all();
2697ecbe0c7SArthur Eubanks }
270