//===- NVGPUDialect.cpp - MLIR NVGPU ops implementation -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the NVGPU dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::nvgpu;

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"

void nvgpu::NVGPUDialect::initialize() {
  addTypes<DeviceAsyncTokenType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"
      >();
}

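// The only type currently registered by the dialect is the device-side async
// token; as an illustrative example, its expected textual form in IR is
// `!nvgpu.device.async.token`.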
Type NVGPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();
  // Handle 'device async token' types.
  if (keyword == "device.async.token")
    return DeviceAsyncTokenType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown nvgpu type: " + keyword);
  return Type();
}

void NVGPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<DeviceAsyncTokenType>([&](Type) { os << "device.async.token"; })
      .Default([](Type) { llvm_unreachable("unexpected 'nvgpu' type kind"); });
}

//===----------------------------------------------------------------------===//
// NVGPU_DeviceAsyncCopyOp
//===----------------------------------------------------------------------===//

/// Return true if the last dimension of the MemRefType has unit stride. Also
/// return true for memrefs with no strides.
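/// For example, a memref with strides [8, 1] (or with the default identity
/// layout) satisfies this, while one with strides [1, 4] does not.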
static bool isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(getStridesAndOffset(type, strides, offset))) {
    return false;
  }
  // Rank-0 memrefs have no strides; treat them as unit stride.
  return strides.empty() || strides.back() == 1;
}

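// As an illustrative example (operand names and types below are placeholders),
// a well-formed copy from global to workgroup (shared) memory looks like:
//
//   %token = nvgpu.device_async_copy %src[%i, %j], %dst[%k, %l, %m], 4
//       : memref<4x5xf32> to memref<2x7x5xf32, 3>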
LogicalResult DeviceAsyncCopyOp::verify() {
  auto srcMemref = getSrc().getType().cast<MemRefType>();
  auto dstMemref = getDst().getType().cast<MemRefType>();
  unsigned workgroupAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace();
  if (!isLastMemrefDimUnitStride(srcMemref))
    return emitError("source memref most minor dim must have unit stride");
  if (!isLastMemrefDimUnitStride(dstMemref))
    return emitError("destination memref most minor dim must have unit stride");
  if (dstMemref.getMemorySpaceAsInt() != workgroupAddressSpace)
    return emitError("destination memref must have memory space ")
           << workgroupAddressSpace;
  if (dstMemref.getElementType() != srcMemref.getElementType())
    return emitError("source and destination must have the same element type");
  if (size_t(srcMemref.getRank()) != getSrcIndices().size())
    return emitOpError() << "expected " << srcMemref.getRank()
                         << " source indices, got " << getSrcIndices().size();
  if (size_t(dstMemref.getRank()) != getDstIndices().size())
    return emitOpError() << "expected " << dstMemref.getRank()
                         << " destination indices, got "
                         << getDstIndices().size();
  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_MmaSyncOp
//===----------------------------------------------------------------------===//

LogicalResult MmaSyncOp::verify() {

  // Fundamental tensor core mma.sync op
  // For the F32 (TF32), F16, S8, and S4 data types, the fundamental tensor
  // core operation has the shape 8-by-8-by-128b; F64 is the exception. The
  // verification of mma.sync across the various shapes and data types is
  // based on this fundamental tensor core operation.
  constexpr int kThreads = 32; // 32 threads per warp
  int64_t shapeM = 8;
  int64_t shapeN = 8;
  int64_t shapeK; // set based on data type (128b for all data types except F64)

  // Number of elements A, B, and C per thread per fundamental tensor core tile
  int64_t numElementA;    // set based on data type (32b except F64)
  int64_t numElementB;    // set based on data type (32b except F64)
  int64_t numElementC{2}; // two accumulator elements per fundamental tile

  // nvgpu.mma.sync vector operands (per thread)
  auto aVector = getMatrixA().getType().cast<VectorType>();
  auto bVector = getMatrixB().getType().cast<VectorType>();
  auto cVector = getMatrixC().getType().cast<VectorType>();

  // vector shapes
  ArrayRef<int64_t> aShape = aVector.getShape();
  ArrayRef<int64_t> bShape = bVector.getShape();
  ArrayRef<int64_t> cShape = cVector.getShape();

  // vector element type
  Type aType = aVector.getElementType();

  // nvgpu.mma.sync shape (per 32 threads or per warp)
  int64_t m = getMmaShape()[0].cast<IntegerAttr>().getInt();
  int64_t n = getMmaShape()[1].cast<IntegerAttr>().getInt();
  int64_t k = getMmaShape()[2].cast<IntegerAttr>().getInt();

  if (aType.isF64()) {
    // exception to 8-by-8-128b fundamental tensor core tile size
    shapeK = 4;
    numElementA = 1;
    numElementB = 1;
  } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
             aType.isInteger(8) || aType.isInteger(4)) {
    // 8-by-8-128b fundamental tensor core tile size
    int operandBitwidth = aType.getIntOrFloatBitWidth();
    shapeK = 128 / operandBitwidth;     // 128b wide shapeK
    numElementA = 32 / operandBitwidth; // 32b wide operand A
    numElementB = 32 / operandBitwidth; // 32b wide operand B
  } else {
    return emitError() << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
                          "supported by nvgpu.mma.sync";
  }
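
  // As an illustrative example (the numbers follow from the rules above): for
  // an f16 operation with mmaShape = [16, 8, 16],
  //   shapeK      = 128 / 16 = 8  (the fundamental tile is 8-by-8-by-8)
  //   numElementA = 32 / 16  = 2
  //   numElementB = 32 / 16  = 2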

  //
  // Basic verification
  //

  // verify warp-wide size for vector a
  if (aShape[0] * aShape[1] * kThreads != m * k)
    return emitOpError() << "expected " << m * k
                         << " warp-wide matrix A elements";

  // verify warp-wide size for vector b
  if (bShape[0] * bShape[1] * kThreads != k * n)
    return emitOpError() << "expected " << k * n
                         << " warp-wide matrix B elements";

  // verify warp-wide size for vector c
  if (cShape[0] * cShape[1] * kThreads != m * n)
    return emitOpError() << "expected " << m * n
                         << " warp-wide matrix C elements";

  //
  // Extended verification
  //

  // tiles of fundamental tensor core operations
  int64_t mTile = m / shapeM;
  int64_t nTile = n / shapeN;
  int64_t kTile = k / shapeK;
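  // Continuing the illustrative f16 mmaShape = [16, 8, 16] example:
  // mTile = 2, nTile = 1, kTile = 2, so the expected per-thread operand
  // shapes are 4x2 for A, 2x2 for B, and 2x2 for the accumulator C
  // (e.g. vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>).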

  // verify shape of aVector
  if (!((aShape[0] == mTile * kTile) && (aShape[1] == numElementA)))
    return emitOpError() << "expected matrix A to be shaped (" << mTile * kTile
                         << " x " << numElementA << ")";

  // verify shape of bVector
  if (!((bShape[0] == kTile * nTile) && (bShape[1] == numElementB)))
    return emitOpError() << "expected matrix B to be shaped (" << kTile * nTile
                         << " x " << numElementB << ")";

  // verify shape of cVector
  if (!((cShape[0] == mTile * nTile) && (cShape[1] == numElementC)))
    return emitOpError() << "expected matrix C to be shaped (" << mTile * nTile
                         << " x " << numElementC << ")";

  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_LdMatrixOp
//===----------------------------------------------------------------------===//
LogicalResult LdMatrixOp::verify() {

  // ldmatrix reads data from source in shared memory
  auto srcMemref = getSrcMemref().getType().cast<MemRefType>();

  // ldmatrix writes data to result/destination in vector registers
  auto resVector = getRes().getType().cast<VectorType>();

  // vector register shape, element type, and bitwidth
  ArrayRef<int64_t> resShape = resVector.getShape();
  Type resType = resVector.getElementType();
  int64_t elementBitWidth = resType.getIntOrFloatBitWidth();

  // ldmatrix loads 32 bits into vector registers per 8-by-8 tile per thread
  int64_t numElementsPer32b = 32 / elementBitWidth;

  // number of 8-by-8 tiles
  int64_t numTiles = getNumTiles();

  // transpose elements in vector registers at 16b granularity when true
  bool isTranspose = getTranspose();

  // address space id for shared memory
  unsigned smemAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace();

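  // As an illustrative example (syntax shown for clarity only): loading four
  // 8-by-8 tiles of f16 gives elementBitWidth = 16 and numElementsPer32b = 2,
  // so the expected result type is vector<4x2xf16>:
  //
  //   %0 = nvgpu.ldmatrix %sm[%c0, %c0] {numTiles = 4 : i32, transpose = false}
  //       : memref<?x?xf16, 3> -> vector<4x2xf16>
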
  //
  // verification
  //

  if (srcMemref.getMemorySpaceAsInt() != smemAddressSpace)
    return emitError()
           << "nvgpu.ldmatrix srcMemref must have memory space "
           << smemAddressSpace;
  if (elementBitWidth > 32)
    return emitError() << "nvgpu.ldmatrix works for 32b or lower";
  if (isTranspose && elementBitWidth != 16)
    return emitError()
           << "nvgpu.ldmatrix transpose works only at 16b granularity";
  if (resShape[1] != numElementsPer32b)
    return emitError() << "expected vector register shape[1] = "
                       << numElementsPer32b;
  if (resShape[0] != numTiles)
    return emitError()
           << "expected vector register shape[0] and numTiles to match";

  return success();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"