//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for gpu -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Interfaces/InferIntRangeInterface.h" using namespace mlir; using namespace mlir::gpu; // Maximum grid and block dimensions of all known GPUs are less than 2^32. static constexpr uint64_t kMaxDim = std::numeric_limits::max(); // Maximum subgroups are no larger than 128. static constexpr uint64_t kMaxSubgroupSize = 128; static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax) { unsigned width = IndexType::kInternalStorageBitWidth; return ConstantIntRanges::fromUnsigned(APInt(width, umin), APInt(width, umax)); } void BlockDimOp::inferResultRanges(ArrayRef, SetIntRangeFn setResultRange) { setResultRange(getResult(), getIndexRange(1, kMaxDim)); } void BlockIdOp::inferResultRanges(ArrayRef, SetIntRangeFn setResultRange) { setResultRange(getResult(), getIndexRange(0, kMaxDim - 1)); } void GridDimOp::inferResultRanges(ArrayRef, SetIntRangeFn setResultRange) { setResultRange(getResult(), getIndexRange(1, kMaxDim)); } void ThreadIdOp::inferResultRanges(ArrayRef, SetIntRangeFn setResultRange) { setResultRange(getResult(), getIndexRange(0, kMaxDim - 1)); } void LaneIdOp::inferResultRanges(ArrayRef, SetIntRangeFn setResultRange) { setResultRange(getResult(), getIndexRange(0, kMaxSubgroupSize - 1)); } void SubgroupIdOp::inferResultRanges(ArrayRef, SetIntRangeFn setResultRange) { setResultRange(getResult(), getIndexRange(0, kMaxDim - 1)); } void GlobalIdOp::inferResultRanges(ArrayRef, SetIntRangeFn setResultRange) { setResultRange(getResult(), getIndexRange(0, std::numeric_limits::max())); } void NumSubgroupsOp::inferResultRanges(ArrayRef, SetIntRangeFn setResultRange) { setResultRange(getResult(), getIndexRange(1, kMaxDim)); } void SubgroupSizeOp::inferResultRanges(ArrayRef, SetIntRangeFn setResultRange) { setResultRange(getResult(), getIndexRange(1, kMaxSubgroupSize)); } void LaunchOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { auto setRange = [&](ConstantIntRanges argRange, Value dimResult, Value idxResult) { if (argRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth) return; ConstantIntRanges dimRange = argRange.intersection(getIndexRange(1, kMaxDim)); setResultRange(dimResult, dimRange); ConstantIntRanges idxRange = getIndexRange(0, dimRange.umax().getZExtValue() - 1); setResultRange(idxResult, idxRange); }; argRanges = argRanges.drop_front(asyncDependencies().size()); KernelDim3 gridDims = getGridSize(); KernelDim3 blockIds = getBlockIds(); setRange(argRanges[0], gridDims.x, blockIds.x); setRange(argRanges[1], gridDims.y, blockIds.y); setRange(argRanges[2], gridDims.z, blockIds.z); KernelDim3 blockDims = getBlockSize(); KernelDim3 threadIds = getThreadIds(); setRange(argRanges[3], blockDims.x, threadIds.x); setRange(argRanges[4], blockDims.y, threadIds.y); setRange(argRanges[5], blockDims.z, threadIds.z); }