//===- NvGpuSupport.cpp - MLIR Vector to GPU lowering support -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides utilities to assist in the lowering of Vector operations
// to NvGPU dialect MMA operations.
//
//===----------------------------------------------------------------------===//

#include "NvGpuSupport.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/NVGPU/NVGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

namespace mlir {
namespace nvgpu {
namespace {

/// There are always 4 threads per [128|256|512] bit row.
constexpr int64_t kThreadsPerRow = 4;

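/// Operands are decomposed into tiles that are 8 rows tall.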
constexpr int64_t kNumRowsPerTile = 8;

bool isAccumulatorOrResult(MatMulOperandRole operandType) {
  return operandType == MatMulOperandRole::C;
}

/// Returns the number of registers which compose a matrix fragment held by a
/// single thread.
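/// For example, a 16x16 f16 'A' operand has a 128-bit tile width, so each
/// thread holds (16 / 8) * (16 * 16) / 128 = 4 registers.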
int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
  int64_t lineSize = inferTileWidthInBits(type);
  auto shape = type.vectorType.getShape();
  return (shape[0] / kNumRowsPerTile) *
         (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
         lineSize;
}

/// Returns the number of 8 x [128|256|512] bit tiles that compose the given
/// operand shape.
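/// For example, a 16x16 f16 operand with a 128-bit line size decomposes into
/// a {2, 2} arrangement of 8x128b tiles.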
std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
                                    Type elementType, int64_t lineSizeBits) {
  // For each 8x128b tile, a thread is responsible for one 32-bit register.
  return {operandShape[0] / kNumRowsPerTile,
          (operandShape[1] * elementType.getIntOrFloatBitWidth()) /
              lineSizeBits};
}

} // namespace

FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op) {
  WarpMatrixInfo info;

  // Determine the vector type.
  if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    info.vectorType = writeOp.getVectorType();
  } else if (isa<vector::TransferReadOp, vector::ContractionOp,
                 arith::ConstantOp>(op)) {
    info.vectorType = op->getResult(0).getType().cast<VectorType>();
  } else {
    return op->emitError()
           << "unhandled operation type in nvgpu.mma.sync conversion path";
  }

  // Determine the operand role. We assume it is an accumulator/result unless it
  // is directly consumed by a `vector.contract` op.
  info.operandRole = MatMulOperandRole::C;
  for (Operation *user : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(user);
    if (!contract)
      continue;
    if (contract.getLhs() == op->getResult(0)) {
      info.operandRole = MatMulOperandRole::A;
      break;
    }
    if (contract.getRhs() == op->getResult(0)) {
      info.operandRole = MatMulOperandRole::B;
      break;
    }
  }
  return info;
}

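// Returns the bit width of one tile row for the given operand: 256b for
// 32-bit accumulators, 512b/256b for 64-bit accumulators/operands, and 128b
// otherwise.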
int64_t inferTileWidthInBits(const WarpMatrixInfo &type) {
  bool isAcc = isAccumulatorOrResult(type.operandRole);
  Type elType = type.vectorType.getElementType();
  if (isAcc && elType.getIntOrFloatBitWidth() == 32) {
    return 256;
  }
  if (elType.getIntOrFloatBitWidth() == 64) {
    return isAcc ? 512 : 256;
  }
  return 128;
}

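// Returns the per-thread register description for the fragment: the packed
// register type (e.g. vector<2xf16> for f16), the number of elements per
// register, the register width in bits, and the register count per fragment.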
FailureOr<FragmentElementInfo>
getMmaSyncRegisterType(const WarpMatrixInfo &type) {
  MLIRContext *ctx = type.vectorType.getContext();
  const bool isAccum = isAccumulatorOrResult(type.operandRole);

  Type elType = type.vectorType.getElementType();
  if (elType.isF16()) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // f64 operand
  Type f64Ty = Float64Type::get(ctx);
  if (elType.isF64()) {
    return isAccum
               ? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f64Ty, 1, 64,
                                     inferNumRegistersPerMatrixFragment(type)};
  }

  // int8 operand
  if (elType.isInteger(8)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }
  // Integer 32bit acc operands
  if (elType.isInteger(32)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // Floating point 32bit operands
  if (elType.isF32()) {
    Type f32Ty = Float32Type::get(ctx);
    return isAccum
               ? FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f32Ty, 1, 32,
                                     inferNumRegistersPerMatrixFragment(type)};
  }
  return failure();
}

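/// Builds the affine map from `logicalValueId` (a thread-local element index)
/// to the (row, col) element offset of the 8x128b tile containing the
/// register that holds that element.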
static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
                                                 Type elementType,
                                                 ArrayRef<int64_t> operandShape,
                                                 bool isAccumulator,
                                                 int64_t elementsPerRegister,
                                                 AffineExpr logicalValueId) {
  const int64_t elementsPerLine =
      lineSize / elementType.getIntOrFloatBitWidth();
  const std::array<int64_t, 2> num8x128bTiles =
      getTileShape(operandShape, elementType, lineSize);
  AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
  return AffineMap::get(
      2, 0,
      {(registerIdx % num8x128bTiles[0]) * 8,
       (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},
      elementType.getContext());
}

FailureOr<AffineMap>
getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
                                  const WarpMatrixInfo &fragmentType) {
  Type elementType = fragmentType.vectorType.getElementType();
  ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      getMmaSyncRegisterType(fragmentType);
  if (failed(regInfo))
    return failure();

  const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth();
  const int64_t elementsPerRegister =
      regInfo->registerWidthBits / elementBitWidth;
  const int64_t lineSize = inferTileWidthInBits(fragmentType);

  AffineExpr laneId, logicalValueIdDim;
  bindDims(builder.getContext(), laneId, logicalValueIdDim);

  // Determine what register logicalValueId corresponds to. Use that as a
  // linear index into the coordinate mapping `index -> (tile row, tile col)`.
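  // For example, for a 16x16 f16 'A' fragment (2 elements per 32-bit
  // register), logical value id 5 lives in register 2, which starts at tile
  // offset (row 0, col 8).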
  AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap(
      lineSize, elementType, operandShape,
      isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
      logicalValueIdDim);

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(2, 0, dimExprs, builder.getContext());
  };

  auto tileRow = registerIndexToTileCoord.getResult(0);
  auto tileCol = registerIndexToTileCoord.getResult(1);
  return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
                  tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
                      (logicalValueIdDim % elementsPerRegister)});
}

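// Determines the `ldmatrix` parameters for the given fragment: the target
// MMA layout, which dimension is contiguous in memory, and how many 8x128b
// tiles the load covers.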
FailureOr<nvgpu::LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
                                                   bool transpose) {
  LdMatrixParams params;
  Type elType = type.vectorType.getElementType();
  params.fragmentType = type.vectorType;
  if (type.operandRole == MatMulOperandRole::A ||
      type.operandRole == MatMulOperandRole::C) {
    params.targetLayout = NVVM::MMALayout::row;
  } else {
    params.targetLayout = NVVM::MMALayout::col;
  }
  ArrayRef<int64_t> shape = type.vectorType.getShape();
  params.contiguousDimType =
      transpose ? IteratorType::Parallel : IteratorType::Reduction;

  if (params.targetLayout == NVVM::MMALayout::row) {
    params.numTiles = (shape[0] / kNumRowsPerTile) *
                      ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
  } else {
    params.numTiles = (shape[1] / kNumRowsPerTile) *
                      ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
  }

  if (params.numTiles == 0)
    return failure();

  return params;
}

FailureOr<AffineMap>
getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
                               const LdMatrixParams &params) {
  // One thread per 128b row.
  const int64_t kNumThreadsPerTile = kNumRowsPerTile;
  const int bitsPerElement = static_cast<int>(
      params.fragmentType.getElementType().getIntOrFloatBitWidth());
  const int kElementsPer128b = (128 / bitsPerElement);
  ArrayRef<int64_t> operandShape = params.fragmentType.getShape();
  AffineExpr d0 = getAffineDimExpr(0, builder.getContext());

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(1, 0, dimExprs, builder.getContext());
  };

  // This case corresponds to row-major A|C or col-major B operands.
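  // For example, for a 16x16 f16 fragment, lanes 0-15 map to the start of
  // rows 0-15 in the first 128b (8-element) column group, and lanes 16-31 to
  // the same rows in the second group.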
  if (params.contiguousDimType == IteratorType::Reduction) {
    AffineExpr row = d0 % (operandShape[0]);
    AffineExpr col = d0.floorDiv(operandShape[0]) * (kElementsPer128b);
    return makeMap({row, col});
  }

  // This case corresponds to col-major A|C or row-major B operands. The
  // operandShape given is already pre-transposed (e.g. 8x16 = KxN).
  if (params.contiguousDimType == IteratorType::Parallel) {
    const int64_t num8x128bCols = (operandShape[0] * bitsPerElement) / 128;
    // Threads are assigned in groups of 8 first across columns, then to
    // rows. This is the transpose of what `ldmatrix` expects, but when
    // `ldmatrix` is given the `.trans` qualifier, the final effect will be
    // to transpose just the blocks.
    auto groupIdx = d0.floorDiv(kNumThreadsPerTile);
    auto tileCol = (groupIdx % num8x128bCols);
    auto tileRow = groupIdx.floorDiv(num8x128bCols);
    return makeMap({tileCol * kElementsPer128b,
                    tileRow * kNumRowsPerTile + (d0 % kNumRowsPerTile)});
  }
  return failure();
}

LogicalResult
PrepareContractToGPUMMASync::matchAndRewrite(vector::ContractionOp op,
                                             PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  Value lhs = op.getLhs();
  Value rhs = op.getRhs();
  Value res = op.getAcc();

  // Set up the parallel/reduction structure in the right form.
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m;
  AffineExpr n;
  AffineExpr k;
  bindDims(rewriter.getContext(), m, n, k);
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  auto iteratorTypes = op.getIteratorTypes().getValue();
  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
  if (iteratorTypes.size() != 3)
    return failure();
  if (!(isParallelIterator(iteratorTypes[0]) &&
        isParallelIterator(iteratorTypes[1]) &&
        isReductionIterator(iteratorTypes[2])))
    return failure();

  // The canonical form is "TNT" = A row-major, B col-major, C row-major.
  const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
  if (maps == canonicalForm) {
    return failure();
  }
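  // Otherwise, transpose and/or swap the operands until the maps match the
  // canonical form. Swapping lhs and rhs handles the transposed-result cases,
  // since C^T = (A * B^T)^T = B * A^T.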
  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
    std::swap(rhs, lhs);
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
    std::swap(rhs, lhs);
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
    std::swap(lhs, rhs);
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
    std::swap(lhs, rhs);
  } else {
    return failure();
  }
  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
      op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
      op.getIteratorTypes());
  return success();
}

} // namespace nvgpu
} // namespace mlir