1 //===- InferIntRangeInterfaceImpls.cpp - Integer range impls for gpu -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 10 #include "mlir/Interfaces/InferIntRangeInterface.h" 11 12 using namespace mlir; 13 using namespace mlir::gpu; 14 15 // Maximum grid and block dimensions of all known GPUs are less than 2^32. 16 static constexpr uint64_t kMaxDim = std::numeric_limits<uint32_t>::max(); 17 // Maximum subgroups are no larger than 128. 18 static constexpr uint64_t kMaxSubgroupSize = 128; 19 20 static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax) { 21 unsigned width = IndexType::kInternalStorageBitWidth; 22 return ConstantIntRanges::fromUnsigned(APInt(width, umin), 23 APInt(width, umax)); 24 } 25 26 void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>, 27 SetIntRangeFn setResultRange) { 28 setResultRange(getResult(), getIndexRange(1, kMaxDim)); 29 } 30 31 void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>, 32 SetIntRangeFn setResultRange) { 33 setResultRange(getResult(), getIndexRange(0, kMaxDim - 1)); 34 } 35 36 void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>, 37 SetIntRangeFn setResultRange) { 38 setResultRange(getResult(), getIndexRange(1, kMaxDim)); 39 } 40 41 void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>, 42 SetIntRangeFn setResultRange) { 43 setResultRange(getResult(), getIndexRange(0, kMaxDim - 1)); 44 } 45 46 void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>, 47 SetIntRangeFn setResultRange) { 48 setResultRange(getResult(), getIndexRange(0, kMaxSubgroupSize - 1)); 49 } 50 51 void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>, 52 SetIntRangeFn setResultRange) { 53 setResultRange(getResult(), getIndexRange(0, kMaxDim - 1)); 54 } 55 56 void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>, 57 SetIntRangeFn setResultRange) { 58 setResultRange(getResult(), 59 getIndexRange(0, std::numeric_limits<int64_t>::max())); 60 } 61 62 void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>, 63 SetIntRangeFn setResultRange) { 64 setResultRange(getResult(), getIndexRange(1, kMaxDim)); 65 } 66 67 void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>, 68 SetIntRangeFn setResultRange) { 69 setResultRange(getResult(), getIndexRange(1, kMaxSubgroupSize)); 70 } 71 72 void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 73 SetIntRangeFn setResultRange) { 74 auto setRange = [&](ConstantIntRanges argRange, Value dimResult, 75 Value idxResult) { 76 if (argRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth) 77 return; 78 ConstantIntRanges dimRange = 79 argRange.intersection(getIndexRange(1, kMaxDim)); 80 setResultRange(dimResult, dimRange); 81 ConstantIntRanges idxRange = 82 getIndexRange(0, dimRange.umax().getZExtValue() - 1); 83 setResultRange(idxResult, idxRange); 84 }; 85 86 argRanges = argRanges.drop_front(asyncDependencies().size()); 87 KernelDim3 gridDims = getGridSize(); 88 KernelDim3 blockIds = getBlockIds(); 89 setRange(argRanges[0], gridDims.x, blockIds.x); 90 setRange(argRanges[1], gridDims.y, blockIds.y); 91 setRange(argRanges[2], gridDims.z, blockIds.z); 92 KernelDim3 blockDims = getBlockSize(); 93 KernelDim3 threadIds = getThreadIds(); 94 setRange(argRanges[3], blockDims.x, threadIds.x); 95 setRange(argRanges[4], blockDims.y, threadIds.y); 96 setRange(argRanges[5], blockDims.z, threadIds.z); 97 } 98