1 //===- InferIntRangeInterfaceImpls.cpp - Integer range impls for gpu -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
10 #include "mlir/Interfaces/InferIntRangeInterface.h"
11 
12 using namespace mlir;
13 using namespace mlir::gpu;
14 
// Upper bound used for grid and block dimensions: all known GPUs have
// dimensions that are less than 2^32.
static constexpr uint64_t kMaxDim{std::numeric_limits<uint32_t>::max()};
// Upper bound used for subgroup sizes: subgroups are no larger than 128.
static constexpr uint64_t kMaxSubgroupSize{128};
19 
getIndexRange(uint64_t umin,uint64_t umax)20 static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax) {
21   unsigned width = IndexType::kInternalStorageBitWidth;
22   return ConstantIntRanges::fromUnsigned(APInt(width, umin),
23                                          APInt(width, umax));
24 }
25 
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)26 void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
27                                    SetIntRangeFn setResultRange) {
28   setResultRange(getResult(), getIndexRange(1, kMaxDim));
29 }
30 
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)31 void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
32                                   SetIntRangeFn setResultRange) {
33   setResultRange(getResult(), getIndexRange(0, kMaxDim - 1));
34 }
35 
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)36 void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
37                                   SetIntRangeFn setResultRange) {
38   setResultRange(getResult(), getIndexRange(1, kMaxDim));
39 }
40 
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)41 void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
42                                    SetIntRangeFn setResultRange) {
43   setResultRange(getResult(), getIndexRange(0, kMaxDim - 1));
44 }
45 
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)46 void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
47                                  SetIntRangeFn setResultRange) {
48   setResultRange(getResult(), getIndexRange(0, kMaxSubgroupSize - 1));
49 }
50 
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)51 void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
52                                      SetIntRangeFn setResultRange) {
53   setResultRange(getResult(), getIndexRange(0, kMaxDim - 1));
54 }
55 
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)56 void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
57                                    SetIntRangeFn setResultRange) {
58   setResultRange(getResult(),
59                  getIndexRange(0, std::numeric_limits<int64_t>::max()));
60 }
61 
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)62 void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
63                                        SetIntRangeFn setResultRange) {
64   setResultRange(getResult(), getIndexRange(1, kMaxDim));
65 }
66 
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)67 void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
68                                        SetIntRangeFn setResultRange) {
69   setResultRange(getResult(), getIndexRange(1, kMaxSubgroupSize));
70 }
71 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)72 void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
73                                  SetIntRangeFn setResultRange) {
74   auto setRange = [&](ConstantIntRanges argRange, Value dimResult,
75                       Value idxResult) {
76     if (argRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
77       return;
78     ConstantIntRanges dimRange =
79         argRange.intersection(getIndexRange(1, kMaxDim));
80     setResultRange(dimResult, dimRange);
81     ConstantIntRanges idxRange =
82         getIndexRange(0, dimRange.umax().getZExtValue() - 1);
83     setResultRange(idxResult, idxRange);
84   };
85 
86   argRanges = argRanges.drop_front(asyncDependencies().size());
87   KernelDim3 gridDims = getGridSize();
88   KernelDim3 blockIds = getBlockIds();
89   setRange(argRanges[0], gridDims.x, blockIds.x);
90   setRange(argRanges[1], gridDims.y, blockIds.y);
91   setRange(argRanges[2], gridDims.z, blockIds.z);
92   KernelDim3 blockDims = getBlockSize();
93   KernelDim3 threadIds = getThreadIds();
94   setRange(argRanges[3], blockDims.x, threadIds.x);
95   setRange(argRanges[4], blockDims.y, threadIds.y);
96   setRange(argRanges[5], blockDims.z, threadIds.z);
97 }
98