1372d796aSMatt Arsenault //===-- AMDGPULowerKernelAttributes.cpp ------------------------------------------===// 2372d796aSMatt Arsenault // 32946cd70SChandler Carruth // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 42946cd70SChandler Carruth // See https://llvm.org/LICENSE.txt for license information. 52946cd70SChandler Carruth // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6372d796aSMatt Arsenault // 7372d796aSMatt Arsenault //===----------------------------------------------------------------------===// 8372d796aSMatt Arsenault // 9372d796aSMatt Arsenault /// \file This pass does attempts to make use of reqd_work_group_size metadata 10372d796aSMatt Arsenault /// to eliminate loads from the dispatch packet and to constant fold OpenCL 11372d796aSMatt Arsenault /// get_local_size-like functions. 12372d796aSMatt Arsenault // 13372d796aSMatt Arsenault //===----------------------------------------------------------------------===// 14372d796aSMatt Arsenault 15372d796aSMatt Arsenault #include "AMDGPU.h" 16372d796aSMatt Arsenault #include "AMDGPUTargetMachine.h" 17372d796aSMatt Arsenault #include "llvm/Analysis/ValueTracking.h" 18372d796aSMatt Arsenault #include "llvm/CodeGen/Passes.h" 19372d796aSMatt Arsenault #include "llvm/CodeGen/TargetPassConfig.h" 20372d796aSMatt Arsenault #include "llvm/IR/Constants.h" 21372d796aSMatt Arsenault #include "llvm/IR/Function.h" 22*7ecbe0c7SArthur Eubanks #include "llvm/IR/InstIterator.h" 23372d796aSMatt Arsenault #include "llvm/IR/Instructions.h" 24*7ecbe0c7SArthur Eubanks #include "llvm/IR/PassManager.h" 25372d796aSMatt Arsenault #include "llvm/IR/PatternMatch.h" 26372d796aSMatt Arsenault #include "llvm/Pass.h" 27372d796aSMatt Arsenault 28372d796aSMatt Arsenault #define DEBUG_TYPE "amdgpu-lower-kernel-attributes" 29372d796aSMatt Arsenault 30372d796aSMatt Arsenault using namespace llvm; 31372d796aSMatt Arsenault 32372d796aSMatt Arsenault namespace { 33372d796aSMatt Arsenault 34372d796aSMatt Arsenault // Field offsets in hsa_kernel_dispatch_packet_t. 35372d796aSMatt Arsenault enum DispatchPackedOffsets { 36372d796aSMatt Arsenault WORKGROUP_SIZE_X = 4, 37372d796aSMatt Arsenault WORKGROUP_SIZE_Y = 6, 38372d796aSMatt Arsenault WORKGROUP_SIZE_Z = 8, 39372d796aSMatt Arsenault 40372d796aSMatt Arsenault GRID_SIZE_X = 12, 41372d796aSMatt Arsenault GRID_SIZE_Y = 16, 42372d796aSMatt Arsenault GRID_SIZE_Z = 20 43372d796aSMatt Arsenault }; 44372d796aSMatt Arsenault 45372d796aSMatt Arsenault class AMDGPULowerKernelAttributes : public ModulePass { 46372d796aSMatt Arsenault public: 47372d796aSMatt Arsenault static char ID; 48372d796aSMatt Arsenault 49372d796aSMatt Arsenault AMDGPULowerKernelAttributes() : ModulePass(ID) {} 50372d796aSMatt Arsenault 51372d796aSMatt Arsenault bool runOnModule(Module &M) override; 52372d796aSMatt Arsenault 53372d796aSMatt Arsenault StringRef getPassName() const override { 54372d796aSMatt Arsenault return "AMDGPU Kernel Attributes"; 55372d796aSMatt Arsenault } 56372d796aSMatt Arsenault 57372d796aSMatt Arsenault void getAnalysisUsage(AnalysisUsage &AU) const override { 58372d796aSMatt Arsenault AU.setPreservesAll(); 59372d796aSMatt Arsenault } 60372d796aSMatt Arsenault }; 61372d796aSMatt Arsenault 62372d796aSMatt Arsenault } // end anonymous namespace 63372d796aSMatt Arsenault 64*7ecbe0c7SArthur Eubanks static bool processUse(CallInst *CI) { 65372d796aSMatt Arsenault Function *F = CI->getParent()->getParent(); 66372d796aSMatt Arsenault 67372d796aSMatt Arsenault auto MD = F->getMetadata("reqd_work_group_size"); 68372d796aSMatt Arsenault const bool HasReqdWorkGroupSize = MD && MD->getNumOperands() == 3; 69372d796aSMatt Arsenault 70372d796aSMatt Arsenault const bool HasUniformWorkGroupSize = 71372d796aSMatt Arsenault F->getFnAttribute("uniform-work-group-size").getValueAsString() == "true"; 72372d796aSMatt Arsenault 73372d796aSMatt Arsenault if (!HasReqdWorkGroupSize && !HasUniformWorkGroupSize) 74372d796aSMatt Arsenault return false; 75372d796aSMatt Arsenault 76372d796aSMatt Arsenault Value *WorkGroupSizeX = nullptr; 77372d796aSMatt Arsenault Value *WorkGroupSizeY = nullptr; 78372d796aSMatt Arsenault Value *WorkGroupSizeZ = nullptr; 79372d796aSMatt Arsenault 80372d796aSMatt Arsenault Value *GridSizeX = nullptr; 81372d796aSMatt Arsenault Value *GridSizeY = nullptr; 82372d796aSMatt Arsenault Value *GridSizeZ = nullptr; 83372d796aSMatt Arsenault 84*7ecbe0c7SArthur Eubanks const DataLayout &DL = F->getParent()->getDataLayout(); 85372d796aSMatt Arsenault 86372d796aSMatt Arsenault // We expect to see several GEP users, casted to the appropriate type and 87372d796aSMatt Arsenault // loaded. 88372d796aSMatt Arsenault for (User *U : CI->users()) { 89372d796aSMatt Arsenault if (!U->hasOneUse()) 90372d796aSMatt Arsenault continue; 91372d796aSMatt Arsenault 92372d796aSMatt Arsenault int64_t Offset = 0; 93372d796aSMatt Arsenault if (GetPointerBaseWithConstantOffset(U, Offset, DL) != CI) 94372d796aSMatt Arsenault continue; 95372d796aSMatt Arsenault 96372d796aSMatt Arsenault auto *BCI = dyn_cast<BitCastInst>(*U->user_begin()); 97372d796aSMatt Arsenault if (!BCI || !BCI->hasOneUse()) 98372d796aSMatt Arsenault continue; 99372d796aSMatt Arsenault 100372d796aSMatt Arsenault auto *Load = dyn_cast<LoadInst>(*BCI->user_begin()); 101372d796aSMatt Arsenault if (!Load || !Load->isSimple()) 102372d796aSMatt Arsenault continue; 103372d796aSMatt Arsenault 104372d796aSMatt Arsenault unsigned LoadSize = DL.getTypeStoreSize(Load->getType()); 105372d796aSMatt Arsenault 106372d796aSMatt Arsenault // TODO: Handle merged loads. 107372d796aSMatt Arsenault switch (Offset) { 108372d796aSMatt Arsenault case WORKGROUP_SIZE_X: 109372d796aSMatt Arsenault if (LoadSize == 2) 110372d796aSMatt Arsenault WorkGroupSizeX = Load; 111372d796aSMatt Arsenault break; 112372d796aSMatt Arsenault case WORKGROUP_SIZE_Y: 113372d796aSMatt Arsenault if (LoadSize == 2) 114372d796aSMatt Arsenault WorkGroupSizeY = Load; 115372d796aSMatt Arsenault break; 116372d796aSMatt Arsenault case WORKGROUP_SIZE_Z: 117372d796aSMatt Arsenault if (LoadSize == 2) 118372d796aSMatt Arsenault WorkGroupSizeZ = Load; 119372d796aSMatt Arsenault break; 120372d796aSMatt Arsenault case GRID_SIZE_X: 121372d796aSMatt Arsenault if (LoadSize == 4) 122372d796aSMatt Arsenault GridSizeX = Load; 123372d796aSMatt Arsenault break; 124372d796aSMatt Arsenault case GRID_SIZE_Y: 125372d796aSMatt Arsenault if (LoadSize == 4) 126372d796aSMatt Arsenault GridSizeY = Load; 127372d796aSMatt Arsenault break; 128372d796aSMatt Arsenault case GRID_SIZE_Z: 129372d796aSMatt Arsenault if (LoadSize == 4) 130372d796aSMatt Arsenault GridSizeZ = Load; 131372d796aSMatt Arsenault break; 132372d796aSMatt Arsenault default: 133372d796aSMatt Arsenault break; 134372d796aSMatt Arsenault } 135372d796aSMatt Arsenault } 136372d796aSMatt Arsenault 137372d796aSMatt Arsenault // Pattern match the code used to handle partial workgroup dispatches in the 138372d796aSMatt Arsenault // library implementation of get_local_size, so the entire function can be 139372d796aSMatt Arsenault // constant folded with a known group size. 140372d796aSMatt Arsenault // 141372d796aSMatt Arsenault // uint r = grid_size - group_id * group_size; 142372d796aSMatt Arsenault // get_local_size = (r < group_size) ? r : group_size; 143372d796aSMatt Arsenault // 144372d796aSMatt Arsenault // If we have uniform-work-group-size (which is the default in OpenCL 1.2), 145372d796aSMatt Arsenault // the grid_size is required to be a multiple of group_size). In this case: 146372d796aSMatt Arsenault // 147372d796aSMatt Arsenault // grid_size - (group_id * group_size) < group_size 148372d796aSMatt Arsenault // -> 149372d796aSMatt Arsenault // grid_size < group_size + (group_id * group_size) 150372d796aSMatt Arsenault // 151372d796aSMatt Arsenault // (grid_size / group_size) < 1 + group_id 152372d796aSMatt Arsenault // 153372d796aSMatt Arsenault // grid_size / group_size is at least 1, so we can conclude the select 154372d796aSMatt Arsenault // condition is false (except for group_id == 0, where the select result is 155372d796aSMatt Arsenault // the same). 156372d796aSMatt Arsenault 157372d796aSMatt Arsenault bool MadeChange = false; 158372d796aSMatt Arsenault Value *WorkGroupSizes[3] = { WorkGroupSizeX, WorkGroupSizeY, WorkGroupSizeZ }; 159372d796aSMatt Arsenault Value *GridSizes[3] = { GridSizeX, GridSizeY, GridSizeZ }; 160372d796aSMatt Arsenault 161372d796aSMatt Arsenault for (int I = 0; HasUniformWorkGroupSize && I < 3; ++I) { 162372d796aSMatt Arsenault Value *GroupSize = WorkGroupSizes[I]; 163372d796aSMatt Arsenault Value *GridSize = GridSizes[I]; 164372d796aSMatt Arsenault if (!GroupSize || !GridSize) 165372d796aSMatt Arsenault continue; 166372d796aSMatt Arsenault 167372d796aSMatt Arsenault for (User *U : GroupSize->users()) { 168372d796aSMatt Arsenault auto *ZextGroupSize = dyn_cast<ZExtInst>(U); 169372d796aSMatt Arsenault if (!ZextGroupSize) 170372d796aSMatt Arsenault continue; 171372d796aSMatt Arsenault 172372d796aSMatt Arsenault for (User *ZextUser : ZextGroupSize->users()) { 173372d796aSMatt Arsenault auto *SI = dyn_cast<SelectInst>(ZextUser); 174372d796aSMatt Arsenault if (!SI) 175372d796aSMatt Arsenault continue; 176372d796aSMatt Arsenault 177372d796aSMatt Arsenault using namespace llvm::PatternMatch; 178372d796aSMatt Arsenault auto GroupIDIntrin = I == 0 ? 179372d796aSMatt Arsenault m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>() : 180372d796aSMatt Arsenault (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>() : 181372d796aSMatt Arsenault m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>()); 182372d796aSMatt Arsenault 183372d796aSMatt Arsenault auto SubExpr = m_Sub(m_Specific(GridSize), 184372d796aSMatt Arsenault m_Mul(GroupIDIntrin, m_Specific(ZextGroupSize))); 185372d796aSMatt Arsenault 186372d796aSMatt Arsenault ICmpInst::Predicate Pred; 187372d796aSMatt Arsenault if (match(SI, 188372d796aSMatt Arsenault m_Select(m_ICmp(Pred, SubExpr, m_Specific(ZextGroupSize)), 189372d796aSMatt Arsenault SubExpr, 190372d796aSMatt Arsenault m_Specific(ZextGroupSize))) && 191372d796aSMatt Arsenault Pred == ICmpInst::ICMP_ULT) { 192372d796aSMatt Arsenault if (HasReqdWorkGroupSize) { 193372d796aSMatt Arsenault ConstantInt *KnownSize 194372d796aSMatt Arsenault = mdconst::extract<ConstantInt>(MD->getOperand(I)); 195372d796aSMatt Arsenault SI->replaceAllUsesWith(ConstantExpr::getIntegerCast(KnownSize, 196372d796aSMatt Arsenault SI->getType(), 197372d796aSMatt Arsenault false)); 198372d796aSMatt Arsenault } else { 199372d796aSMatt Arsenault SI->replaceAllUsesWith(ZextGroupSize); 200372d796aSMatt Arsenault } 201372d796aSMatt Arsenault 202372d796aSMatt Arsenault MadeChange = true; 203372d796aSMatt Arsenault } 204372d796aSMatt Arsenault } 205372d796aSMatt Arsenault } 206372d796aSMatt Arsenault } 207372d796aSMatt Arsenault 208372d796aSMatt Arsenault if (!HasReqdWorkGroupSize) 209372d796aSMatt Arsenault return MadeChange; 210372d796aSMatt Arsenault 211372d796aSMatt Arsenault // Eliminate any other loads we can from the dispatch packet. 212372d796aSMatt Arsenault for (int I = 0; I < 3; ++I) { 213372d796aSMatt Arsenault Value *GroupSize = WorkGroupSizes[I]; 214372d796aSMatt Arsenault if (!GroupSize) 215372d796aSMatt Arsenault continue; 216372d796aSMatt Arsenault 217372d796aSMatt Arsenault ConstantInt *KnownSize = mdconst::extract<ConstantInt>(MD->getOperand(I)); 218372d796aSMatt Arsenault GroupSize->replaceAllUsesWith( 219372d796aSMatt Arsenault ConstantExpr::getIntegerCast(KnownSize, 220372d796aSMatt Arsenault GroupSize->getType(), 221372d796aSMatt Arsenault false)); 222372d796aSMatt Arsenault MadeChange = true; 223372d796aSMatt Arsenault } 224372d796aSMatt Arsenault 225372d796aSMatt Arsenault return MadeChange; 226372d796aSMatt Arsenault } 227372d796aSMatt Arsenault 228372d796aSMatt Arsenault // TODO: Move makeLIDRangeMetadata usage into here. Seem to not get 229372d796aSMatt Arsenault // TargetPassConfig for subtarget. 230372d796aSMatt Arsenault bool AMDGPULowerKernelAttributes::runOnModule(Module &M) { 231372d796aSMatt Arsenault StringRef DispatchPtrName 232372d796aSMatt Arsenault = Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr); 233372d796aSMatt Arsenault 234*7ecbe0c7SArthur Eubanks Function *DispatchPtr = M.getFunction(DispatchPtrName); 235372d796aSMatt Arsenault if (!DispatchPtr) // Dispatch ptr not used. 236372d796aSMatt Arsenault return false; 237372d796aSMatt Arsenault 238372d796aSMatt Arsenault bool MadeChange = false; 239372d796aSMatt Arsenault 240372d796aSMatt Arsenault SmallPtrSet<Instruction *, 4> HandledUses; 241372d796aSMatt Arsenault for (auto *U : DispatchPtr->users()) { 242372d796aSMatt Arsenault CallInst *CI = cast<CallInst>(U); 243372d796aSMatt Arsenault if (HandledUses.insert(CI).second) { 244372d796aSMatt Arsenault if (processUse(CI)) 245372d796aSMatt Arsenault MadeChange = true; 246372d796aSMatt Arsenault } 247372d796aSMatt Arsenault } 248372d796aSMatt Arsenault 249372d796aSMatt Arsenault return MadeChange; 250372d796aSMatt Arsenault } 251372d796aSMatt Arsenault 252372d796aSMatt Arsenault INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes, DEBUG_TYPE, 253372d796aSMatt Arsenault "AMDGPU IR optimizations", false, false) 254372d796aSMatt Arsenault INITIALIZE_PASS_END(AMDGPULowerKernelAttributes, DEBUG_TYPE, "AMDGPU IR optimizations", 255372d796aSMatt Arsenault false, false) 256372d796aSMatt Arsenault 257372d796aSMatt Arsenault char AMDGPULowerKernelAttributes::ID = 0; 258372d796aSMatt Arsenault 259372d796aSMatt Arsenault ModulePass *llvm::createAMDGPULowerKernelAttributesPass() { 260372d796aSMatt Arsenault return new AMDGPULowerKernelAttributes(); 261372d796aSMatt Arsenault } 262*7ecbe0c7SArthur Eubanks 263*7ecbe0c7SArthur Eubanks PreservedAnalyses 264*7ecbe0c7SArthur Eubanks AMDGPULowerKernelAttributesPass::run(Function &F, FunctionAnalysisManager &AM) { 265*7ecbe0c7SArthur Eubanks StringRef DispatchPtrName = 266*7ecbe0c7SArthur Eubanks Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr); 267*7ecbe0c7SArthur Eubanks 268*7ecbe0c7SArthur Eubanks Function *DispatchPtr = F.getParent()->getFunction(DispatchPtrName); 269*7ecbe0c7SArthur Eubanks if (!DispatchPtr) // Dispatch ptr not used. 270*7ecbe0c7SArthur Eubanks return PreservedAnalyses::all(); 271*7ecbe0c7SArthur Eubanks 272*7ecbe0c7SArthur Eubanks for (Instruction &I : instructions(F)) { 273*7ecbe0c7SArthur Eubanks if (CallInst *CI = dyn_cast<CallInst>(&I)) { 274*7ecbe0c7SArthur Eubanks if (CI->getCalledFunction() == DispatchPtr) 275*7ecbe0c7SArthur Eubanks processUse(CI); 276*7ecbe0c7SArthur Eubanks } 277*7ecbe0c7SArthur Eubanks } 278*7ecbe0c7SArthur Eubanks 279*7ecbe0c7SArthur Eubanks return PreservedAnalyses::all(); 280*7ecbe0c7SArthur Eubanks } 281