//===- NvGpuSupport.cpp - MLIR Vector to GPU lowering support --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides utilities to assist in the lowering of Vector operations
// to NvGPU dialect MMA operations.
//
//===----------------------------------------------------------------------===//

#include "NvGpuSupport.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

namespace mlir {
namespace nvgpu {
namespace {

/// There are always 4 threads per [128|256|512] bit row.
constexpr int64_t kThreadsPerRow = 4;

constexpr int64_t kNumRowsPerTile = 8;
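// Illustrative arithmetic (not used directly below): a full warp covers one
// 8 x 128-bit tile, since kNumRowsPerTile * kThreadsPerRow = 8 * 4 = 32 lanes
// and each lane then holds one 32-bit register of that tile.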

bool isAccumulatorOrResult(MatMulOperandRole operandType) {
  return operandType == MatMulOperandRole::C;
}

/// Returns the number of registers which compose a matrix fragment held by a
/// single thread.
int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
  int64_t lineSize = inferTileWidthInBits(type);
  auto shape = type.vectorType.getShape();
  return (shape[0] / kNumRowsPerTile) *
         (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
         lineSize;
}
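// Worked example (illustrative only): for a 16x16 vector<f16> A operand,
// inferTileWidthInBits returns 128 bits, so a thread holds
//   (16 / 8) * (16 * 16) / 128 = 4
// registers, i.e. four 32-bit registers each carrying a <2 x f16> piece.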

/// Returns the number of 8 x [128|256|512] bit tiles that compose the given
/// operand shape.
std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
                                    Type elementType, int64_t lineSizeBits) {
  // For each 8x128-bit tile, a thread is responsible for one 32-bit register.
  return {operandShape[0] / kNumRowsPerTile,
          (operandShape[1] * elementType.getIntOrFloatBitWidth()) /
              lineSizeBits};
}
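// Worked example (illustrative only): a 16x16 f16 operand with 128-bit lines
// decomposes into {16 / 8, (16 * 16) / 128} = {2, 2} such tiles.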

} // namespace

FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op) {
  WarpMatrixInfo info;

  // Determine the vector type.
  if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    info.vectorType = writeOp.getVectorType();
  } else if (isa<vector::TransferReadOp, vector::ContractionOp,
                 arith::ConstantOp>(op)) {
    info.vectorType = op->getResult(0).getType().cast<VectorType>();
  } else {
    return op->emitError()
           << "unhandled operation type in nvgpu.mma.sync conversion path";
  }

  // Determine the operand role. We assume it is an accumulator/result unless
  // it is the LHS or RHS operand of a `vector.contract` op.
  info.operandRole = MatMulOperandRole::C;
  for (Operation *user : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(user);
    if (!contract)
      continue;
    if (contract.getLhs() == op->getResult(0)) {
      info.operandRole = MatMulOperandRole::A;
      break;
    }
    if (contract.getRhs() == op->getResult(0)) {
      info.operandRole = MatMulOperandRole::B;
      break;
    }
  }
  return info;
}
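// Typical use (sketch; the surrounding conversion pass is assumed): classify
// the op first, then query the per-thread fragment layout, e.g.
//   FailureOr<WarpMatrixInfo> info = nvgpu::getWarpMatrixInfo(op);
//   if (failed(info))
//     return failure();
//   FailureOr<FragmentElementInfo> frag = nvgpu::getMmaSyncRegisterType(*info);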

int64_t inferTileWidthInBits(const WarpMatrixInfo &type) {
  bool isAcc = isAccumulatorOrResult(type.operandRole);
  Type elType = type.vectorType.getElementType();
  if (isAcc && elType.getIntOrFloatBitWidth() == 32) {
    return 256;
  }
  if (elType.getIntOrFloatBitWidth() == 64) {
    return isAcc ? 512 : 256;
  }
  return 128;
}
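// Illustrative results: a 32-bit accumulator/result gets 256-bit lines, f64
// gets 512-bit (accumulator) or 256-bit (A/B) lines, and everything else
// (e.g. f16 or i8 operands) gets 128-bit lines.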

FailureOr<FragmentElementInfo>
getMmaSyncRegisterType(const WarpMatrixInfo &type) {
  MLIRContext *ctx = type.vectorType.getContext();
  const bool isAccum = isAccumulatorOrResult(type.operandRole);

  Type elType = type.vectorType.getElementType();
  if (elType.isF16()) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // f64 operand
  Type f64Ty = Float64Type::get(ctx);
  if (elType.isF64()) {
    return isAccum
               ? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f64Ty, 1, 64,
                                     inferNumRegistersPerMatrixFragment(type)};
  }

  // int8 operand
  if (elType.isInteger(8)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // int4 operand
  if (elType.isInteger(4)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 4), 8), 8, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // Integer 32bit acc operands
  if (elType.isInteger(32)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // Floating point 32bit operands
  if (elType.isF32()) {
    Type f32Ty = Float32Type::get(ctx);
    return isAccum
               ? FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f32Ty, 1, 32,
                                     inferNumRegistersPerMatrixFragment(type)};
  }
  return failure();
}
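// Worked example (illustrative only): an f16 operand is reported as a
// <2 x f16> register holding 2 elements in 32 bits, while an f32 accumulator
// is reported as a <2 x f32> register holding 2 elements in 64 bits.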

static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
                                                 Type elementType,
                                                 ArrayRef<int64_t> operandShape,
                                                 bool isAccumulator,
                                                 int64_t elementsPerRegister,
                                                 AffineExpr logicalValueId) {
  const int64_t elementsPerLine =
      lineSize / elementType.getIntOrFloatBitWidth();
  const std::array<int64_t, 2> num8x128bTiles =
      getTileShape(operandShape, elementType, lineSize);
  AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
  return AffineMap::get(
      2, 0,
      {(registerIdx % num8x128bTiles[0]) * 8,
       (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},
      elementType.getContext());
}
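// Worked example (illustrative only): for a 16x16 f16 A operand (lineSize =
// 128, elementsPerRegister = 2, num8x128bTiles = {2, 2}), the four registers
// 0..3 map to tile offsets (0, 0), (8, 0), (0, 8), and (8, 8) respectively.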

FailureOr<AffineMap>
getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
                                  const WarpMatrixInfo &fragmentType) {
  Type elementType = fragmentType.vectorType.getElementType();
  ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      getMmaSyncRegisterType(fragmentType);
  if (failed(regInfo))
    return failure();

  const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth();
  const int64_t elementsPerRegister =
      regInfo->registerWidthBits / elementBitWidth;
  const int64_t lineSize = inferTileWidthInBits(fragmentType);

  AffineExpr laneId, logicalValueIdDim;
  bindDims(builder.getContext(), laneId, logicalValueIdDim);

  // Determine what register logicalValueId corresponds to. Use that as a
  // linear index into the coordinate mapping `index -> (tile row, tile col)`.
  AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap(
      lineSize, elementType, operandShape,
      isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
      logicalValueIdDim);

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(2, 0, dimExprs, builder.getContext());
  };

  auto tileRow = registerIndexToTileCoord.getResult(0);
  auto tileCol = registerIndexToTileCoord.getResult(1);
  return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
                  tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
                      (logicalValueIdDim % elementsPerRegister)});
}
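// Worked example (illustrative only): for a 16x16 f16 A operand the map is
//   (laneId, valueId) -> (tileRow + laneId floordiv 4,
//                         tileCol + (laneId mod 4) * 2 + valueId mod 2),
// so lane 0 owns (0,0), (0,1), (8,0), (8,1), (0,8), (0,9), (8,8), (8,9),
// which lines up with the m16n8k16 A-fragment distribution of mma.sync.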

FailureOr<nvgpu::LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
                                                   bool transpose) {
  LdMatrixParams params;
  Type elType = type.vectorType.getElementType();
  params.fragmentType = type.vectorType;
  if (type.operandRole == MatMulOperandRole::A ||
      type.operandRole == MatMulOperandRole::C) {
    params.targetLayout = NVVM::MMALayout::row;
  } else {
    params.targetLayout = NVVM::MMALayout::col;
  }
  ArrayRef<int64_t> shape = type.vectorType.getShape();
  params.contiguousDimType =
      transpose ? IteratorType::Parallel : IteratorType::Reduction;

  if (params.contiguousDimType == IteratorType::Reduction) {
    params.numTiles = (shape[0] / kNumRowsPerTile) *
                      ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
  } else {
    params.numTiles = (shape[1] / kNumRowsPerTile) *
                      ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
  }

  if (params.numTiles == 0)
    return failure();

  return params;
}
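// Worked example (illustrative only): a non-transposed 16x16 f16 A operand
// yields numTiles = (16 / 8) * ((16 * 16) / 128) = 4, i.e. an ldmatrix `.x4`
// load of four 8 x 128-bit tiles.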

FailureOr<AffineMap>
getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
                               const LdMatrixParams &params) {
  // One thread per 128b row.
  const int64_t kNumThreadsPerTile = kNumRowsPerTile;
  const int bitsPerElement = static_cast<int>(
      params.fragmentType.getElementType().getIntOrFloatBitWidth());
  const int kElementsPer128b = (128 / bitsPerElement);
  ArrayRef<int64_t> operandShape = params.fragmentType.getShape();
  AffineExpr d0 = getAffineDimExpr(0, builder.getContext());

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(1, 0, dimExprs, builder.getContext());
  };

  // This case corresponds to row-major A|C or col-major B operands.
  if (params.contiguousDimType == IteratorType::Reduction) {
    AffineExpr row = d0 % (operandShape[0]);
    AffineExpr col = d0.floorDiv(operandShape[0]) * (kElementsPer128b);
    return makeMap({row, col});
  }

  // This case corresponds to col-major A|C or row-major B operands. The
  // operandShape given is already pre-transposed (e.g. 8x16 = KxN).
  if (params.contiguousDimType == IteratorType::Parallel) {
    const int64_t num8x128bCols = (operandShape[0] * bitsPerElement) / 128;
    // Threads are assigned in groups of 8 first across columns, then to
    // rows. This is the transpose of what `ldmatrix` expects, but when
    // `ldmatrix` gets the `.trans` qualifier, the final effect will be to
    // transpose just the blocks.
    auto groupIdx = d0.floorDiv(kNumThreadsPerTile);
    auto tileCol = (groupIdx % num8x128bCols);
    auto tileRow = groupIdx.floorDiv(num8x128bCols);
    return makeMap({tileCol * kElementsPer128b,
                    tileRow * kNumRowsPerTile + (d0 % kNumRowsPerTile)});
  }
  return failure();
}
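// Worked example (illustrative only): for the non-transposed 16x16 f16 case
// (kElementsPer128b = 8) the map is (d0) -> (d0 mod 16, (d0 floordiv 16) * 8),
// so lanes 0-15 provide the row addresses at column 0 and lanes 16-31 provide
// the row addresses at column 8.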

LogicalResult
PrepareContractToGPUMMASync::matchAndRewrite(vector::ContractionOp op,
                                             PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  Value lhs = op.getLhs();
  Value rhs = op.getRhs();
  Value res = op.getAcc();

  // Set up the parallel/reduction structure in the right form.
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m;
  AffineExpr n;
  AffineExpr k;
  bindDims(rewriter.getContext(), m, n, k);
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  auto iteratorTypes = op.getIteratorTypes().getValue();
  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
  if (iteratorTypes.size() != 3)
    return failure();
  if (!(isParallelIterator(iteratorTypes[0]) &&
        isParallelIterator(iteratorTypes[1]) &&
        isReductionIterator(iteratorTypes[2])))
    return failure();

  // The canonical form is "TNT" = A row-major, B col-major, C row-major.
  const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
  if (maps == canonicalForm) {
    return failure();
  }
  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
    std::swap(rhs, lhs);
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
    std::swap(rhs, lhs);
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
    std::swap(lhs, rhs);
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
    std::swap(lhs, rhs);
  } else {
    return failure();
  }
  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
      op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
      op.getIteratorTypes());
  return success();
}
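// Illustrative rewrite (sketch): a contraction whose indexing maps are
//   {(m, k), (k, n), (m, n)}
// has a row-major B operand, so the pattern transposes the RHS and re-emits
// the contraction with the canonical {(m, k), (n, k), (m, n)} maps.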

} // namespace nvgpu
} // namespace mlir