1*3e01af09SChristian Sigg //===- InferIntRangeInterfaceImpls.cpp - Integer range impls for gpu -===//
2*3e01af09SChristian Sigg //
3*3e01af09SChristian Sigg // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*3e01af09SChristian Sigg // See https://llvm.org/LICENSE.txt for license information.
5*3e01af09SChristian Sigg // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*3e01af09SChristian Sigg //
7*3e01af09SChristian Sigg //===----------------------------------------------------------------------===//
8*3e01af09SChristian Sigg
9*3e01af09SChristian Sigg #include "mlir/Dialect/GPU/IR/GPUDialect.h"
10*3e01af09SChristian Sigg #include "mlir/Interfaces/InferIntRangeInterface.h"
11*3e01af09SChristian Sigg
12*3e01af09SChristian Sigg using namespace mlir;
13*3e01af09SChristian Sigg using namespace mlir::gpu;
14*3e01af09SChristian Sigg
15*3e01af09SChristian Sigg // Maximum grid and block dimensions of all known GPUs are less than 2^32.
16*3e01af09SChristian Sigg static constexpr uint64_t kMaxDim = std::numeric_limits<uint32_t>::max();
17*3e01af09SChristian Sigg // Maximum subgroups are no larger than 128.
18*3e01af09SChristian Sigg static constexpr uint64_t kMaxSubgroupSize = 128;
19*3e01af09SChristian Sigg
getIndexRange(uint64_t umin,uint64_t umax)20*3e01af09SChristian Sigg static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax) {
21*3e01af09SChristian Sigg unsigned width = IndexType::kInternalStorageBitWidth;
22*3e01af09SChristian Sigg return ConstantIntRanges::fromUnsigned(APInt(width, umin),
23*3e01af09SChristian Sigg APInt(width, umax));
24*3e01af09SChristian Sigg }
25*3e01af09SChristian Sigg
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)26*3e01af09SChristian Sigg void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
27*3e01af09SChristian Sigg SetIntRangeFn setResultRange) {
28*3e01af09SChristian Sigg setResultRange(getResult(), getIndexRange(1, kMaxDim));
29*3e01af09SChristian Sigg }
30*3e01af09SChristian Sigg
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)31*3e01af09SChristian Sigg void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
32*3e01af09SChristian Sigg SetIntRangeFn setResultRange) {
33*3e01af09SChristian Sigg setResultRange(getResult(), getIndexRange(0, kMaxDim - 1));
34*3e01af09SChristian Sigg }
35*3e01af09SChristian Sigg
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)36*3e01af09SChristian Sigg void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
37*3e01af09SChristian Sigg SetIntRangeFn setResultRange) {
38*3e01af09SChristian Sigg setResultRange(getResult(), getIndexRange(1, kMaxDim));
39*3e01af09SChristian Sigg }
40*3e01af09SChristian Sigg
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)41*3e01af09SChristian Sigg void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
42*3e01af09SChristian Sigg SetIntRangeFn setResultRange) {
43*3e01af09SChristian Sigg setResultRange(getResult(), getIndexRange(0, kMaxDim - 1));
44*3e01af09SChristian Sigg }
45*3e01af09SChristian Sigg
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)46*3e01af09SChristian Sigg void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
47*3e01af09SChristian Sigg SetIntRangeFn setResultRange) {
48*3e01af09SChristian Sigg setResultRange(getResult(), getIndexRange(0, kMaxSubgroupSize - 1));
49*3e01af09SChristian Sigg }
50*3e01af09SChristian Sigg
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)51*3e01af09SChristian Sigg void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
52*3e01af09SChristian Sigg SetIntRangeFn setResultRange) {
53*3e01af09SChristian Sigg setResultRange(getResult(), getIndexRange(0, kMaxDim - 1));
54*3e01af09SChristian Sigg }
55*3e01af09SChristian Sigg
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)56*3e01af09SChristian Sigg void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
57*3e01af09SChristian Sigg SetIntRangeFn setResultRange) {
58*3e01af09SChristian Sigg setResultRange(getResult(),
59*3e01af09SChristian Sigg getIndexRange(0, std::numeric_limits<int64_t>::max()));
60*3e01af09SChristian Sigg }
61*3e01af09SChristian Sigg
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)62*3e01af09SChristian Sigg void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
63*3e01af09SChristian Sigg SetIntRangeFn setResultRange) {
64*3e01af09SChristian Sigg setResultRange(getResult(), getIndexRange(1, kMaxDim));
65*3e01af09SChristian Sigg }
66*3e01af09SChristian Sigg
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)67*3e01af09SChristian Sigg void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
68*3e01af09SChristian Sigg SetIntRangeFn setResultRange) {
69*3e01af09SChristian Sigg setResultRange(getResult(), getIndexRange(1, kMaxSubgroupSize));
70*3e01af09SChristian Sigg }
71*3e01af09SChristian Sigg
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)72*3e01af09SChristian Sigg void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
73*3e01af09SChristian Sigg SetIntRangeFn setResultRange) {
74*3e01af09SChristian Sigg auto setRange = [&](ConstantIntRanges argRange, Value dimResult,
75*3e01af09SChristian Sigg Value idxResult) {
76*3e01af09SChristian Sigg if (argRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
77*3e01af09SChristian Sigg return;
78*3e01af09SChristian Sigg ConstantIntRanges dimRange =
79*3e01af09SChristian Sigg argRange.intersection(getIndexRange(1, kMaxDim));
80*3e01af09SChristian Sigg setResultRange(dimResult, dimRange);
81*3e01af09SChristian Sigg ConstantIntRanges idxRange =
82*3e01af09SChristian Sigg getIndexRange(0, dimRange.umax().getZExtValue() - 1);
83*3e01af09SChristian Sigg setResultRange(idxResult, idxRange);
84*3e01af09SChristian Sigg };
85*3e01af09SChristian Sigg
86*3e01af09SChristian Sigg argRanges = argRanges.drop_front(asyncDependencies().size());
87*3e01af09SChristian Sigg KernelDim3 gridDims = getGridSize();
88*3e01af09SChristian Sigg KernelDim3 blockIds = getBlockIds();
89*3e01af09SChristian Sigg setRange(argRanges[0], gridDims.x, blockIds.x);
90*3e01af09SChristian Sigg setRange(argRanges[1], gridDims.y, blockIds.y);
91*3e01af09SChristian Sigg setRange(argRanges[2], gridDims.z, blockIds.z);
92*3e01af09SChristian Sigg KernelDim3 blockDims = getBlockSize();
93*3e01af09SChristian Sigg KernelDim3 threadIds = getThreadIds();
94*3e01af09SChristian Sigg setRange(argRanges[3], blockDims.x, threadIds.x);
95*3e01af09SChristian Sigg setRange(argRanges[4], blockDims.y, threadIds.y);
96*3e01af09SChristian Sigg setRange(argRanges[5], blockDims.z, threadIds.z);
97*3e01af09SChristian Sigg }
98