//===- NVGPUDialect.cpp - MLIR NVGPU ops implementation ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the NVGPU dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::nvgpu;

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"

void nvgpu::NVGPUDialect::initialize() {
  addTypes<DeviceAsyncTokenType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"
      >();
}

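/// Parses the dialect's custom types, e.g. `!nvgpu.device.async.token`.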
Type NVGPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();
  // Handle 'device async token' types.
  if (keyword == "device.async.token")
    return DeviceAsyncTokenType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown nvgpu type: " + keyword);
  return Type();
}

void NVGPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<DeviceAsyncTokenType>([&](Type) { os << "device.async.token"; })
      .Default([](Type) { llvm_unreachable("unexpected 'nvgpu' type kind"); });
}
//===----------------------------------------------------------------------===//
// NVGPU_DeviceAsyncCopyOp
//===----------------------------------------------------------------------===//

/// Return true if the last dimension of the MemRefType has unit stride. Also
/// return true for memrefs with no strides.
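/// For example, a memref<4x8xf32> with the default identity layout has
/// strides [8, 1], so its last dimension has unit stride.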
static bool isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(getStridesAndOffset(type, strides, offset))) {
    return false;
  }
  // Rank-0 memrefs have no strides; treat them as unit stride, as documented
  // above.
  return strides.empty() || strides.back() == 1;
}

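// DeviceAsyncCopyOp verification: both memrefs must have unit stride in their
// innermost dimension and the same element type, each side takes one index
// per memref dimension, and the destination must live in the workgroup
// (shared) address space. Roughly (illustrative syntax sketch; see NVGPU.td
// for the authoritative assembly format):
//
//   %token = nvgpu.device_async_copy %src[%i, %j], %dst[%k, %l], 4 :
//     memref<128x128xf32> to memref<64x64xf32, 3>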
LogicalResult DeviceAsyncCopyOp::verify() {
  auto srcMemref = getSrc().getType().cast<MemRefType>();
  auto dstMemref = getDst().getType().cast<MemRefType>();
  unsigned workgroupAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace();
  if (!isLastMemrefDimUnitStride(srcMemref))
    return emitError("source memref most minor dim must have unit stride");
  if (!isLastMemrefDimUnitStride(dstMemref))
    return emitError("destination memref most minor dim must have unit stride");
  if (dstMemref.getMemorySpaceAsInt() != workgroupAddressSpace)
    return emitError("destination memref must have memory space ")
           << workgroupAddressSpace;
  if (dstMemref.getElementType() != srcMemref.getElementType())
    return emitError("source and destination must have the same element type");
  if (size_t(srcMemref.getRank()) != getSrcIndices().size())
    return emitOpError() << "expected " << srcMemref.getRank()
                         << " source indices, got " << getSrcIndices().size();
  if (size_t(dstMemref.getRank()) != getDstIndices().size())
    return emitOpError() << "expected " << dstMemref.getRank()
                         << " destination indices, got "
                         << getDstIndices().size();
  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_MmaSyncOp
//===----------------------------------------------------------------------===//

LogicalResult MmaSyncOp::verify() {

  // Fundamental tensor core mma.sync op.
  // For the F16, BF16, TF32 (F32), S8, and S4 data types, the fundamental
  // tensor core operation has shape 8-by-8-by-128b; F64 is the exception.
  // The verification for mma.sync, covering its various shapes and data
  // types, is based on this fundamental tensor core operation.
  constexpr int kThreads = 32; // 32 threads per warp
  int64_t shapeM = 8;
  int64_t shapeN = 8;
  int64_t shapeK; // set based on data type (128b for all data types except F64)

  // Number of elements A, B, and C per thread per fundamental tensor core tile
  int64_t numElementA;    // set based on data type (32b except F64)
  int64_t numElementB;    // set based on data type (32b except F64)
  int64_t numElementC{2}; // two accumulator elements per fundamental tile

  // nvgpu.mma.sync vector operands (per thread)
  auto aVector = getMatrixA().getType().cast<VectorType>();
  auto bVector = getMatrixB().getType().cast<VectorType>();
  auto cVector = getMatrixC().getType().cast<VectorType>();

  // vector shapes
  ArrayRef<int64_t> aShape = aVector.getShape();
  ArrayRef<int64_t> bShape = bVector.getShape();
  ArrayRef<int64_t> cShape = cVector.getShape();

  // vector element type
  Type aType = aVector.getElementType();

  // nvgpu.mma.sync shape (per 32 threads or per warp)
  int64_t m = getMmaShape()[0].cast<IntegerAttr>().getInt();
  int64_t n = getMmaShape()[1].cast<IntegerAttr>().getInt();
  int64_t k = getMmaShape()[2].cast<IntegerAttr>().getInt();

  if (aType.isF64()) {
    // exception to the 8-by-8-by-128b fundamental tensor core tile size
    shapeK = 4;
    numElementA = 1;
    numElementB = 1;
  } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
             aType.isInteger(8) || aType.isInteger(4)) {
    // 8-by-8-by-128b fundamental tensor core tile size
    int operandBitwidth = aType.getIntOrFloatBitWidth();
    shapeK = 128 / operandBitwidth;     // 128b wide shapeK
    numElementA = 32 / operandBitwidth; // 32b wide operand A
    numElementB = 32 / operandBitwidth; // 32b wide operand B
  } else {
    return emitError() << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
                          "supported by nvgpu.mma.sync";
  }
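
  // Worked example (f16): operandBitwidth = 16, so shapeK = 128 / 16 = 8 and
  // numElementA = numElementB = 32 / 16 = 2 elements per thread per
  // fundamental tile.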

  //
  // Basic verification
  //

  // verify warp-wide size for vector a
  if (aShape[0] * aShape[1] * kThreads != m * k)
    return emitOpError() << "expected " << m * k
                         << " warp-wide matrix A elements";

  // verify warp-wide size for vector b
  if (bShape[0] * bShape[1] * kThreads != k * n)
    return emitOpError() << "expected " << k * n
                         << " warp-wide matrix B elements";

  // verify warp-wide size for vector c
  if (cShape[0] * cShape[1] * kThreads != m * n)
    return emitOpError() << "expected " << m * n
                         << " warp-wide matrix C elements";

  //
  // Extended verification
  //

  // tiles of fundamental tensor core operations
  int64_t mTile = m / shapeM;
  int64_t nTile = n / shapeN;
  int64_t kTile = k / shapeK;
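
  // Continuing the f16 example with mmaShape = [16, 8, 16]: mTile = 2,
  // nTile = 1, kTile = 2, so matrix A must be vector<4x2xf16>, matrix B
  // vector<2x2xf16>, and accumulator C vector<2x2xf16>.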

  // verify shape of aVector
  if (!((aShape[0] == mTile * kTile) && (aShape[1] == numElementA)))
    return emitOpError() << "expected matrix A to be shaped (" << mTile * kTile
                         << " x " << numElementA << ")";

  // verify shape of bVector
  if (!((bShape[0] == kTile * nTile) && (bShape[1] == numElementB)))
    return emitOpError() << "expected matrix B to be shaped (" << kTile * nTile
                         << " x " << numElementB << ")";

  // verify shape of cVector
  if (!((cShape[0] == mTile * nTile) && (cShape[1] == numElementC)))
    return emitOpError() << "expected matrix C to be shaped (" << mTile * nTile
                         << " x " << numElementC << ")";

  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_LdMatrixOp
//===----------------------------------------------------------------------===//
LogicalResult LdMatrixOp::verify() {

  // ldmatrix reads data from source in shared memory
  auto srcMemref = getSrcMemref().getType().cast<MemRefType>();

  // ldmatrix writes data to result/destination in vector registers
  auto resVector = getRes().getType().cast<VectorType>();

  // vector register shape, element type, and bitwidth
  ArrayRef<int64_t> resShape = resVector.getShape();
  Type resType = resVector.getElementType();
  int64_t elementBitWidth = resType.getIntOrFloatBitWidth();

  // ldmatrix loads 32 bits into vector registers per 8-by-8 tile per thread
  int64_t numElementsPer32b = 32 / elementBitWidth;
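  // For example, an f16 result has numElementsPer32b = 32 / 16 = 2; with
  // numTiles = 4 the expected result type is vector<4x2xf16>.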

  // number of 8-by-8 tiles
  int64_t numTiles = getNumTiles();

  // transpose elements in vector registers at 16b granularity when true
  bool isTranspose = getTranspose();

  // address space id for shared memory
  unsigned smemAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace();

  //
  // verification
  //

  if (!(srcMemref.getMemorySpaceAsInt() == smemAddressSpace))
    return emitError()
           << "expected nvgpu.ldmatrix srcMemref must have memory space "
           << smemAddressSpace;
  if (elementBitWidth > 32)
    return emitError() << "nvgpu.ldmatrix works for 32b or lower";
  if (isTranspose && !(elementBitWidth == 16))
    return emitError()
           << "nvgpu.ldmatrix transpose works only at 16b granularity";
  if (!(resShape[1] == numElementsPer32b))
    return emitError() << "expected vector register shape[1] = "
                       << numElementsPer32b;
  if (!(resShape[0] == numTiles))
    return emitError()
           << "expected vector register shape[0] and numTiles to match";

  return success();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"