10b57cec5SDimitry Andric //===-- AMDGPULowerKernelAttributes.cpp ------------------------------------------===//
20b57cec5SDimitry Andric //
30b57cec5SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
40b57cec5SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
50b57cec5SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60b57cec5SDimitry Andric //
70b57cec5SDimitry Andric //===----------------------------------------------------------------------===//
80b57cec5SDimitry Andric //
/// \file This pass attempts to make use of reqd_work_group_size metadata
100b57cec5SDimitry Andric /// to eliminate loads from the dispatch packet and to constant fold OpenCL
110b57cec5SDimitry Andric /// get_local_size-like functions.
120b57cec5SDimitry Andric //
130b57cec5SDimitry Andric //===----------------------------------------------------------------------===//
140b57cec5SDimitry Andric
150b57cec5SDimitry Andric #include "AMDGPU.h"
160b57cec5SDimitry Andric #include "llvm/Analysis/ValueTracking.h"
170b57cec5SDimitry Andric #include "llvm/CodeGen/Passes.h"
180b57cec5SDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h"
190b57cec5SDimitry Andric #include "llvm/IR/Constants.h"
200b57cec5SDimitry Andric #include "llvm/IR/Function.h"
21af732203SDimitry Andric #include "llvm/IR/InstIterator.h"
220b57cec5SDimitry Andric #include "llvm/IR/Instructions.h"
23af732203SDimitry Andric #include "llvm/IR/IntrinsicsAMDGPU.h"
240b57cec5SDimitry Andric #include "llvm/IR/PatternMatch.h"
250b57cec5SDimitry Andric #include "llvm/Pass.h"
260b57cec5SDimitry Andric
270b57cec5SDimitry Andric #define DEBUG_TYPE "amdgpu-lower-kernel-attributes"
280b57cec5SDimitry Andric
290b57cec5SDimitry Andric using namespace llvm;
300b57cec5SDimitry Andric
310b57cec5SDimitry Andric namespace {
320b57cec5SDimitry Andric
/// Byte offsets of the hsa_kernel_dispatch_packet_t fields this file looks
/// for. processUse() compares these against the constant offset of each
/// pointer derived from the dispatch pointer.
enum DispatchPackedOffsets {
  // Workgroup size fields; processUse() only accepts 2-byte loads here.
  WORKGROUP_SIZE_X = 4,
  WORKGROUP_SIZE_Y = 6,
  WORKGROUP_SIZE_Z = 8,

  // Grid size fields; processUse() only accepts 4-byte loads here.
  GRID_SIZE_X = 12,
  GRID_SIZE_Y = 16,
  GRID_SIZE_Z = 20
};
430b57cec5SDimitry Andric
440b57cec5SDimitry Andric class AMDGPULowerKernelAttributes : public ModulePass {
450b57cec5SDimitry Andric public:
460b57cec5SDimitry Andric static char ID;
470b57cec5SDimitry Andric
AMDGPULowerKernelAttributes()480b57cec5SDimitry Andric AMDGPULowerKernelAttributes() : ModulePass(ID) {}
490b57cec5SDimitry Andric
500b57cec5SDimitry Andric bool runOnModule(Module &M) override;
510b57cec5SDimitry Andric
getPassName() const520b57cec5SDimitry Andric StringRef getPassName() const override {
530b57cec5SDimitry Andric return "AMDGPU Kernel Attributes";
540b57cec5SDimitry Andric }
550b57cec5SDimitry Andric
getAnalysisUsage(AnalysisUsage & AU) const560b57cec5SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override {
570b57cec5SDimitry Andric AU.setPreservesAll();
580b57cec5SDimitry Andric }
590b57cec5SDimitry Andric };
600b57cec5SDimitry Andric
610b57cec5SDimitry Andric } // end anonymous namespace
620b57cec5SDimitry Andric
processUse(CallInst * CI)63af732203SDimitry Andric static bool processUse(CallInst *CI) {
640b57cec5SDimitry Andric Function *F = CI->getParent()->getParent();
650b57cec5SDimitry Andric
660b57cec5SDimitry Andric auto MD = F->getMetadata("reqd_work_group_size");
670b57cec5SDimitry Andric const bool HasReqdWorkGroupSize = MD && MD->getNumOperands() == 3;
680b57cec5SDimitry Andric
690b57cec5SDimitry Andric const bool HasUniformWorkGroupSize =
70*5f7ddb14SDimitry Andric F->getFnAttribute("uniform-work-group-size").getValueAsBool();
710b57cec5SDimitry Andric
720b57cec5SDimitry Andric if (!HasReqdWorkGroupSize && !HasUniformWorkGroupSize)
730b57cec5SDimitry Andric return false;
740b57cec5SDimitry Andric
750b57cec5SDimitry Andric Value *WorkGroupSizeX = nullptr;
760b57cec5SDimitry Andric Value *WorkGroupSizeY = nullptr;
770b57cec5SDimitry Andric Value *WorkGroupSizeZ = nullptr;
780b57cec5SDimitry Andric
790b57cec5SDimitry Andric Value *GridSizeX = nullptr;
800b57cec5SDimitry Andric Value *GridSizeY = nullptr;
810b57cec5SDimitry Andric Value *GridSizeZ = nullptr;
820b57cec5SDimitry Andric
83af732203SDimitry Andric const DataLayout &DL = F->getParent()->getDataLayout();
840b57cec5SDimitry Andric
850b57cec5SDimitry Andric // We expect to see several GEP users, casted to the appropriate type and
860b57cec5SDimitry Andric // loaded.
870b57cec5SDimitry Andric for (User *U : CI->users()) {
880b57cec5SDimitry Andric if (!U->hasOneUse())
890b57cec5SDimitry Andric continue;
900b57cec5SDimitry Andric
910b57cec5SDimitry Andric int64_t Offset = 0;
920b57cec5SDimitry Andric if (GetPointerBaseWithConstantOffset(U, Offset, DL) != CI)
930b57cec5SDimitry Andric continue;
940b57cec5SDimitry Andric
950b57cec5SDimitry Andric auto *BCI = dyn_cast<BitCastInst>(*U->user_begin());
960b57cec5SDimitry Andric if (!BCI || !BCI->hasOneUse())
970b57cec5SDimitry Andric continue;
980b57cec5SDimitry Andric
990b57cec5SDimitry Andric auto *Load = dyn_cast<LoadInst>(*BCI->user_begin());
1000b57cec5SDimitry Andric if (!Load || !Load->isSimple())
1010b57cec5SDimitry Andric continue;
1020b57cec5SDimitry Andric
1030b57cec5SDimitry Andric unsigned LoadSize = DL.getTypeStoreSize(Load->getType());
1040b57cec5SDimitry Andric
1050b57cec5SDimitry Andric // TODO: Handle merged loads.
1060b57cec5SDimitry Andric switch (Offset) {
1070b57cec5SDimitry Andric case WORKGROUP_SIZE_X:
1080b57cec5SDimitry Andric if (LoadSize == 2)
1090b57cec5SDimitry Andric WorkGroupSizeX = Load;
1100b57cec5SDimitry Andric break;
1110b57cec5SDimitry Andric case WORKGROUP_SIZE_Y:
1120b57cec5SDimitry Andric if (LoadSize == 2)
1130b57cec5SDimitry Andric WorkGroupSizeY = Load;
1140b57cec5SDimitry Andric break;
1150b57cec5SDimitry Andric case WORKGROUP_SIZE_Z:
1160b57cec5SDimitry Andric if (LoadSize == 2)
1170b57cec5SDimitry Andric WorkGroupSizeZ = Load;
1180b57cec5SDimitry Andric break;
1190b57cec5SDimitry Andric case GRID_SIZE_X:
1200b57cec5SDimitry Andric if (LoadSize == 4)
1210b57cec5SDimitry Andric GridSizeX = Load;
1220b57cec5SDimitry Andric break;
1230b57cec5SDimitry Andric case GRID_SIZE_Y:
1240b57cec5SDimitry Andric if (LoadSize == 4)
1250b57cec5SDimitry Andric GridSizeY = Load;
1260b57cec5SDimitry Andric break;
1270b57cec5SDimitry Andric case GRID_SIZE_Z:
1280b57cec5SDimitry Andric if (LoadSize == 4)
1290b57cec5SDimitry Andric GridSizeZ = Load;
1300b57cec5SDimitry Andric break;
1310b57cec5SDimitry Andric default:
1320b57cec5SDimitry Andric break;
1330b57cec5SDimitry Andric }
1340b57cec5SDimitry Andric }
1350b57cec5SDimitry Andric
1360b57cec5SDimitry Andric // Pattern match the code used to handle partial workgroup dispatches in the
1370b57cec5SDimitry Andric // library implementation of get_local_size, so the entire function can be
1380b57cec5SDimitry Andric // constant folded with a known group size.
1390b57cec5SDimitry Andric //
1400b57cec5SDimitry Andric // uint r = grid_size - group_id * group_size;
1410b57cec5SDimitry Andric // get_local_size = (r < group_size) ? r : group_size;
1420b57cec5SDimitry Andric //
1430b57cec5SDimitry Andric // If we have uniform-work-group-size (which is the default in OpenCL 1.2),
1440b57cec5SDimitry Andric // the grid_size is required to be a multiple of group_size). In this case:
1450b57cec5SDimitry Andric //
1460b57cec5SDimitry Andric // grid_size - (group_id * group_size) < group_size
1470b57cec5SDimitry Andric // ->
1480b57cec5SDimitry Andric // grid_size < group_size + (group_id * group_size)
1490b57cec5SDimitry Andric //
1500b57cec5SDimitry Andric // (grid_size / group_size) < 1 + group_id
1510b57cec5SDimitry Andric //
1520b57cec5SDimitry Andric // grid_size / group_size is at least 1, so we can conclude the select
1530b57cec5SDimitry Andric // condition is false (except for group_id == 0, where the select result is
1540b57cec5SDimitry Andric // the same).
1550b57cec5SDimitry Andric
1560b57cec5SDimitry Andric bool MadeChange = false;
1570b57cec5SDimitry Andric Value *WorkGroupSizes[3] = { WorkGroupSizeX, WorkGroupSizeY, WorkGroupSizeZ };
1580b57cec5SDimitry Andric Value *GridSizes[3] = { GridSizeX, GridSizeY, GridSizeZ };
1590b57cec5SDimitry Andric
1600b57cec5SDimitry Andric for (int I = 0; HasUniformWorkGroupSize && I < 3; ++I) {
1610b57cec5SDimitry Andric Value *GroupSize = WorkGroupSizes[I];
1620b57cec5SDimitry Andric Value *GridSize = GridSizes[I];
1630b57cec5SDimitry Andric if (!GroupSize || !GridSize)
1640b57cec5SDimitry Andric continue;
1650b57cec5SDimitry Andric
1660b57cec5SDimitry Andric for (User *U : GroupSize->users()) {
1670b57cec5SDimitry Andric auto *ZextGroupSize = dyn_cast<ZExtInst>(U);
1680b57cec5SDimitry Andric if (!ZextGroupSize)
1690b57cec5SDimitry Andric continue;
1700b57cec5SDimitry Andric
1710b57cec5SDimitry Andric for (User *ZextUser : ZextGroupSize->users()) {
1720b57cec5SDimitry Andric auto *SI = dyn_cast<SelectInst>(ZextUser);
1730b57cec5SDimitry Andric if (!SI)
1740b57cec5SDimitry Andric continue;
1750b57cec5SDimitry Andric
1760b57cec5SDimitry Andric using namespace llvm::PatternMatch;
1770b57cec5SDimitry Andric auto GroupIDIntrin = I == 0 ?
1780b57cec5SDimitry Andric m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>() :
1790b57cec5SDimitry Andric (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>() :
1800b57cec5SDimitry Andric m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());
1810b57cec5SDimitry Andric
1820b57cec5SDimitry Andric auto SubExpr = m_Sub(m_Specific(GridSize),
1830b57cec5SDimitry Andric m_Mul(GroupIDIntrin, m_Specific(ZextGroupSize)));
1840b57cec5SDimitry Andric
1850b57cec5SDimitry Andric ICmpInst::Predicate Pred;
1860b57cec5SDimitry Andric if (match(SI,
1870b57cec5SDimitry Andric m_Select(m_ICmp(Pred, SubExpr, m_Specific(ZextGroupSize)),
1880b57cec5SDimitry Andric SubExpr,
1890b57cec5SDimitry Andric m_Specific(ZextGroupSize))) &&
1900b57cec5SDimitry Andric Pred == ICmpInst::ICMP_ULT) {
1910b57cec5SDimitry Andric if (HasReqdWorkGroupSize) {
1920b57cec5SDimitry Andric ConstantInt *KnownSize
1930b57cec5SDimitry Andric = mdconst::extract<ConstantInt>(MD->getOperand(I));
1940b57cec5SDimitry Andric SI->replaceAllUsesWith(ConstantExpr::getIntegerCast(KnownSize,
1950b57cec5SDimitry Andric SI->getType(),
1960b57cec5SDimitry Andric false));
1970b57cec5SDimitry Andric } else {
1980b57cec5SDimitry Andric SI->replaceAllUsesWith(ZextGroupSize);
1990b57cec5SDimitry Andric }
2000b57cec5SDimitry Andric
2010b57cec5SDimitry Andric MadeChange = true;
2020b57cec5SDimitry Andric }
2030b57cec5SDimitry Andric }
2040b57cec5SDimitry Andric }
2050b57cec5SDimitry Andric }
2060b57cec5SDimitry Andric
2070b57cec5SDimitry Andric if (!HasReqdWorkGroupSize)
2080b57cec5SDimitry Andric return MadeChange;
2090b57cec5SDimitry Andric
2100b57cec5SDimitry Andric // Eliminate any other loads we can from the dispatch packet.
2110b57cec5SDimitry Andric for (int I = 0; I < 3; ++I) {
2120b57cec5SDimitry Andric Value *GroupSize = WorkGroupSizes[I];
2130b57cec5SDimitry Andric if (!GroupSize)
2140b57cec5SDimitry Andric continue;
2150b57cec5SDimitry Andric
2160b57cec5SDimitry Andric ConstantInt *KnownSize = mdconst::extract<ConstantInt>(MD->getOperand(I));
2170b57cec5SDimitry Andric GroupSize->replaceAllUsesWith(
2180b57cec5SDimitry Andric ConstantExpr::getIntegerCast(KnownSize,
2190b57cec5SDimitry Andric GroupSize->getType(),
2200b57cec5SDimitry Andric false));
2210b57cec5SDimitry Andric MadeChange = true;
2220b57cec5SDimitry Andric }
2230b57cec5SDimitry Andric
2240b57cec5SDimitry Andric return MadeChange;
2250b57cec5SDimitry Andric }
2260b57cec5SDimitry Andric
2270b57cec5SDimitry Andric // TODO: Move makeLIDRangeMetadata usage into here. Seem to not get
2280b57cec5SDimitry Andric // TargetPassConfig for subtarget.
runOnModule(Module & M)2290b57cec5SDimitry Andric bool AMDGPULowerKernelAttributes::runOnModule(Module &M) {
2300b57cec5SDimitry Andric StringRef DispatchPtrName
2310b57cec5SDimitry Andric = Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr);
2320b57cec5SDimitry Andric
233af732203SDimitry Andric Function *DispatchPtr = M.getFunction(DispatchPtrName);
2340b57cec5SDimitry Andric if (!DispatchPtr) // Dispatch ptr not used.
2350b57cec5SDimitry Andric return false;
2360b57cec5SDimitry Andric
2370b57cec5SDimitry Andric bool MadeChange = false;
2380b57cec5SDimitry Andric
2390b57cec5SDimitry Andric SmallPtrSet<Instruction *, 4> HandledUses;
2400b57cec5SDimitry Andric for (auto *U : DispatchPtr->users()) {
2410b57cec5SDimitry Andric CallInst *CI = cast<CallInst>(U);
2420b57cec5SDimitry Andric if (HandledUses.insert(CI).second) {
2430b57cec5SDimitry Andric if (processUse(CI))
2440b57cec5SDimitry Andric MadeChange = true;
2450b57cec5SDimitry Andric }
2460b57cec5SDimitry Andric }
2470b57cec5SDimitry Andric
2480b57cec5SDimitry Andric return MadeChange;
2490b57cec5SDimitry Andric }
2500b57cec5SDimitry Andric
2510b57cec5SDimitry Andric INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes, DEBUG_TYPE,
252*5f7ddb14SDimitry Andric "AMDGPU Kernel Attributes", false, false)
253*5f7ddb14SDimitry Andric INITIALIZE_PASS_END(AMDGPULowerKernelAttributes, DEBUG_TYPE,
254*5f7ddb14SDimitry Andric "AMDGPU Kernel Attributes", false, false)
2550b57cec5SDimitry Andric
2560b57cec5SDimitry Andric char AMDGPULowerKernelAttributes::ID = 0;
2570b57cec5SDimitry Andric
createAMDGPULowerKernelAttributesPass()2580b57cec5SDimitry Andric ModulePass *llvm::createAMDGPULowerKernelAttributesPass() {
2590b57cec5SDimitry Andric return new AMDGPULowerKernelAttributes();
2600b57cec5SDimitry Andric }
261af732203SDimitry Andric
262af732203SDimitry Andric PreservedAnalyses
run(Function & F,FunctionAnalysisManager & AM)263af732203SDimitry Andric AMDGPULowerKernelAttributesPass::run(Function &F, FunctionAnalysisManager &AM) {
264af732203SDimitry Andric StringRef DispatchPtrName =
265af732203SDimitry Andric Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr);
266af732203SDimitry Andric
267af732203SDimitry Andric Function *DispatchPtr = F.getParent()->getFunction(DispatchPtrName);
268af732203SDimitry Andric if (!DispatchPtr) // Dispatch ptr not used.
269af732203SDimitry Andric return PreservedAnalyses::all();
270af732203SDimitry Andric
271af732203SDimitry Andric for (Instruction &I : instructions(F)) {
272af732203SDimitry Andric if (CallInst *CI = dyn_cast<CallInst>(&I)) {
273af732203SDimitry Andric if (CI->getCalledFunction() == DispatchPtr)
274af732203SDimitry Andric processUse(CI);
275af732203SDimitry Andric }
276af732203SDimitry Andric }
277af732203SDimitry Andric
278af732203SDimitry Andric return PreservedAnalyses::all();
279af732203SDimitry Andric }
280