1 //===- InferIntRangeInterfaceImpls.cpp - Integer range impls for gpu -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
10 #include "mlir/Interfaces/InferIntRangeInterface.h"
11
12 using namespace mlir;
13 using namespace mlir::gpu;
14
// Upper bound assumed for every grid and block dimension: all known GPUs keep
// these below 2^32, so the largest 32-bit unsigned value is used, widened to
// 64 bits so it can be fed directly into index-width APInts.
static constexpr uint64_t kMaxDim =
    static_cast<uint64_t>(std::numeric_limits<uint32_t>::max());
// Upper bound assumed for the subgroup size: no known GPU uses a subgroup
// larger than 128 lanes.
static constexpr uint64_t kMaxSubgroupSize = uint64_t{128};
19
getIndexRange(uint64_t umin,uint64_t umax)20 static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax) {
21 unsigned width = IndexType::kInternalStorageBitWidth;
22 return ConstantIntRanges::fromUnsigned(APInt(width, umin),
23 APInt(width, umax));
24 }
25
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)26 void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
27 SetIntRangeFn setResultRange) {
28 setResultRange(getResult(), getIndexRange(1, kMaxDim));
29 }
30
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)31 void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
32 SetIntRangeFn setResultRange) {
33 setResultRange(getResult(), getIndexRange(0, kMaxDim - 1));
34 }
35
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)36 void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
37 SetIntRangeFn setResultRange) {
38 setResultRange(getResult(), getIndexRange(1, kMaxDim));
39 }
40
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)41 void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
42 SetIntRangeFn setResultRange) {
43 setResultRange(getResult(), getIndexRange(0, kMaxDim - 1));
44 }
45
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)46 void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
47 SetIntRangeFn setResultRange) {
48 setResultRange(getResult(), getIndexRange(0, kMaxSubgroupSize - 1));
49 }
50
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)51 void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
52 SetIntRangeFn setResultRange) {
53 setResultRange(getResult(), getIndexRange(0, kMaxDim - 1));
54 }
55
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)56 void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
57 SetIntRangeFn setResultRange) {
58 setResultRange(getResult(),
59 getIndexRange(0, std::numeric_limits<int64_t>::max()));
60 }
61
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)62 void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
63 SetIntRangeFn setResultRange) {
64 setResultRange(getResult(), getIndexRange(1, kMaxDim));
65 }
66
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)67 void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
68 SetIntRangeFn setResultRange) {
69 setResultRange(getResult(), getIndexRange(1, kMaxSubgroupSize));
70 }
71
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)72 void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
73 SetIntRangeFn setResultRange) {
74 auto setRange = [&](ConstantIntRanges argRange, Value dimResult,
75 Value idxResult) {
76 if (argRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
77 return;
78 ConstantIntRanges dimRange =
79 argRange.intersection(getIndexRange(1, kMaxDim));
80 setResultRange(dimResult, dimRange);
81 ConstantIntRanges idxRange =
82 getIndexRange(0, dimRange.umax().getZExtValue() - 1);
83 setResultRange(idxResult, idxRange);
84 };
85
86 argRanges = argRanges.drop_front(asyncDependencies().size());
87 KernelDim3 gridDims = getGridSize();
88 KernelDim3 blockIds = getBlockIds();
89 setRange(argRanges[0], gridDims.x, blockIds.x);
90 setRange(argRanges[1], gridDims.y, blockIds.y);
91 setRange(argRanges[2], gridDims.z, blockIds.z);
92 KernelDim3 blockDims = getBlockSize();
93 KernelDim3 threadIds = getThreadIds();
94 setRange(argRanges[3], blockDims.x, threadIds.x);
95 setRange(argRanges[4], blockDims.y, threadIds.y);
96 setRange(argRanges[5], blockDims.z, threadIds.z);
97 }
98