//===- NVGPUDialect.cpp - MLIR NVGPU ops implementation -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the NVGPU dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::nvgpu;

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"

void nvgpu::NVGPUDialect::initialize() {
  addTypes<DeviceAsyncTokenType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"
      >();
}

Type NVGPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();
  // Handle 'device async token' types.
  if (keyword == "device.async.token")
    return DeviceAsyncTokenType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown nvgpu type: " + keyword);
  return Type();
}

void NVGPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<DeviceAsyncTokenType>([&](Type) { os << "device.async.token"; })
      .Default([](Type) { llvm_unreachable("unexpected 'nvgpu' type kind"); });
}
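
// Note: only the `device.async.token` keyword is parsed/printed here; the
// enclosing `!nvgpu.` prefix is handled by the dialect parsing/printing
// framework, so the full textual form of the type is
// `!nvgpu.device.async.token`.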

//===----------------------------------------------------------------------===//
// NVGPU_DeviceAsyncCopyOp
//===----------------------------------------------------------------------===//

/// Return true if the last dimension of the MemRefType has unit stride. Also
/// return true for memrefs with no strides.
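/// For example, a `memref<4x8xf32>` with the default identity layout has
/// strides `[8, 1]`, so its last dimension has unit stride.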
static bool isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(getStridesAndOffset(type, strides, offset))) {
    return false;
  }
  return strides.back() == 1;
}

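// Illustrative example (the exact assembly syntax is defined by the op's
// TableGen definition, so treat this as a sketch): a global-to-shared copy
// such as
//   %token = nvgpu.device_async_copy %src[%i, %j], %dst[%k, %l, %m], 4
//       : memref<128x128xf32> to memref<3x16x128xf32, 3>
// satisfies the checks below: 2 indices for the rank-2 source, 3 indices for
// the rank-3 destination, the destination lives in the GPU workgroup (shared)
// memory space, the element types match, and the innermost dimensions have
// unit stride.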
LogicalResult DeviceAsyncCopyOp::verify() {
  auto srcMemref = getSrc().getType().cast<MemRefType>();
  auto dstMemref = getDst().getType().cast<MemRefType>();
  unsigned workgroupAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace();
  if (!isLastMemrefDimUnitStride(srcMemref))
    return emitError("source memref most minor dim must have unit stride");
  if (!isLastMemrefDimUnitStride(dstMemref))
    return emitError("destination memref most minor dim must have unit stride");
  if (dstMemref.getMemorySpaceAsInt() != workgroupAddressSpace)
    return emitError("destination memref must have memory space ")
           << workgroupAddressSpace;
  if (dstMemref.getElementType() != srcMemref.getElementType())
    return emitError("source and destination must have the same element type");
  if (size_t(srcMemref.getRank()) != getSrcIndices().size())
    return emitOpError() << "expected " << srcMemref.getRank()
                         << " source indices, got " << getSrcIndices().size();
  if (size_t(dstMemref.getRank()) != getDstIndices().size())
    return emitOpError() << "expected " << dstMemref.getRank()
                         << " destination indices, got "
                         << getDstIndices().size();
  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_MmaSyncOp
//===----------------------------------------------------------------------===//

LogicalResult MmaSyncOp::verify() {

  // Fundamental tensor core mma.sync op
  // For the F32 (TF32), F16, S8, and S4 data types, the fundamental tensor
  // core operation has shape 8-by-8-by-128b; F64 is an exception. The
  // verification for mma.sync covering various shapes and data types is based
  // on this fundamental tensor core operation.
  constexpr int kThreads = 32; // 32 threads per warp
  int64_t shapeM = 8;
  int64_t shapeN = 8;
  int64_t shapeK; // set based on data type (128b for all types except F64)

  // Number of elements A, B, and C per thread per fundamental tensor core tile
  int64_t numElementA;    // set based on data type (32b except F64)
  int64_t numElementB;    // set based on data type (32b except F64)
  int64_t numElementC{2}; // two accumulator elements per fundamental tile

  // nvgpu.mma.sync vector operands (per thread)
  auto aVector = getMatrixA().getType().cast<VectorType>();
  auto bVector = getMatrixB().getType().cast<VectorType>();
  auto cVector = getMatrixC().getType().cast<VectorType>();

  // vector shapes
  ArrayRef<int64_t> aShape = aVector.getShape();
  ArrayRef<int64_t> bShape = bVector.getShape();
  ArrayRef<int64_t> cShape = cVector.getShape();

  // vector element type
  Type aType = aVector.getElementType();

  // nvgpu.mma.sync shape (per 32 threads or per warp)
  int64_t m = getMmaShape()[0].cast<IntegerAttr>().getInt();
  int64_t n = getMmaShape()[1].cast<IntegerAttr>().getInt();
  int64_t k = getMmaShape()[2].cast<IntegerAttr>().getInt();

  if (aType.isF64()) {
    // exception to 8-by-8-128b fundamental tensor core tile size
    shapeK = 4;
    numElementA = 1;
    numElementB = 1;
  } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
             aType.isInteger(8) || aType.isInteger(4)) {
    // 8-by-8-128b fundamental tensor core tile size
    int operandBitwidth = aType.getIntOrFloatBitWidth();
    shapeK = 128 / operandBitwidth;     // 128b wide shapeK
    numElementA = 32 / operandBitwidth; // 32b wide operand A
    numElementB = 32 / operandBitwidth; // 32b wide operand B
  } else {
    return emitError() << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
                          "supported by nvgpu.mma.sync";
  }
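
  // Worked example of the checks below: for f16 operands with
  // mmaShape = [16, 8, 16], shapeK = 128/16 = 8 and
  // numElementA = numElementB = 32/16 = 2, so mTile = 2, nTile = 1, and
  // kTile = 2; the expected per-thread vector shapes are then A: 4x2,
  // B: 2x2, and C: 2x2.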

  //
  // Basic verification
  //

  // verify warp-wide size for vector a
  if (aShape[0] * aShape[1] * kThreads != m * k)
    return emitOpError() << "expected " << m * k
                         << " warp-wide matrix A elements";

  // verify warp-wide size for vector b
  if (bShape[0] * bShape[1] * kThreads != k * n)
    return emitOpError() << "expected " << k * n
                         << " warp-wide matrix B elements";

  // verify warp-wide size for vector c
  if (cShape[0] * cShape[1] * kThreads != m * n)
    return emitOpError() << "expected " << m * n
                         << " warp-wide matrix C elements";

  //
  // Extended verification
  //

  // tiles of fundamental tensor core operations
  int64_t mTile = m / shapeM;
  int64_t nTile = n / shapeN;
  int64_t kTile = k / shapeK;

  // verify shape of aVector
  if (!((aShape[0] == mTile * kTile) && (aShape[1] == numElementA)))
    return emitOpError() << "expected matrix A to be shaped (" << mTile * kTile
                         << " x " << numElementA << ")";

  // verify shape of bVector
  if (!((bShape[0] == kTile * nTile) && (bShape[1] == numElementB)))
    return emitOpError() << "expected matrix B to be shaped (" << kTile * nTile
                         << " x " << numElementB << ")";

  // verify shape of cVector
  if (!((cShape[0] == mTile * nTile) && (cShape[1] == numElementC)))
    return emitOpError() << "expected matrix C to be shaped (" << mTile * nTile
                         << " x " << numElementC << ")";

  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_LdMatrixOp
//===----------------------------------------------------------------------===//

LogicalResult LdMatrixOp::verify() {

  // ldmatrix reads data from source in shared memory
  auto srcMemref = getSrcMemref().getType().cast<MemRefType>();

  // ldmatrix writes data to result/destination in vector registers
  auto resVector = getRes().getType().cast<VectorType>();

  // vector register shape, element type, and bitwidth
  ArrayRef<int64_t> resShape = resVector.getShape();
  Type resType = resVector.getElementType();
  int64_t elementBitWidth = resType.getIntOrFloatBitWidth();

  // ldmatrix loads 32 bits into vector registers per 8-by-8 tile per thread
  int64_t numElementsPer32b = 32 / elementBitWidth;

  // number of 8-by-8 tiles
  int64_t numTiles = getNumTiles();

  // transpose elements in vector registers at 16b granularity when true
  bool isTranspose = getTranspose();

  // address space id for shared memory
  unsigned smemAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace();

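  // Worked example of the checks below: with numTiles = 4 and f16 elements
  // (elementBitWidth = 16, so numElementsPer32b = 2), the source memref must
  // live in the workgroup (shared) memory space and the per-thread result
  // must have shape 4x2, i.e. vector<4x2xf16>.
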
  //
  // verification
  //

  if (!(srcMemref.getMemorySpaceAsInt() == smemAddressSpace))
    return emitError()
           << "expected nvgpu.ldmatrix srcMemref must have memory space "
           << smemAddressSpace;
  if (elementBitWidth > 32)
    return emitError() << "nvgpu.ldmatrix works for 32b or lower";
  if (isTranspose && !(elementBitWidth == 16))
    return emitError()
           << "nvgpu.ldmatrix transpose works only at 16b granularity";
  if (!(resShape[1] == numElementsPer32b))
    return emitError() << "expected vector register shape[1] = "
                       << numElementsPer32b;
  if (!(resShape[0] == numTiles))
    return emitError()
           << "expected vector register shape[0] and numTiles to match";

  return success();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"