//===- NvGpuSupport.cpp - MLIR Vector to GPU lowering support --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides utilities to assist in the lowering of Vector operations
// to NvGPU dialect MMA operations.
//
//===----------------------------------------------------------------------===//

#include "NvGpuSupport.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

namespace mlir {
namespace nvgpu {
namespace {

/// There are always 4 threads per [128|256|512] bit row.
constexpr int64_t kThreadsPerRow = 4;

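/// Each mma.sync/ldmatrix fragment tile spans 8 rows.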
constexpr int64_t kNumRowsPerTile = 8;

bool isAccumulatorOrResult(MatMulOperandRole operandType) {
  return operandType == MatMulOperandRole::C;
}

/// Returns the number of registers which compose a matrix fragment held by a
/// single thread.
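/// For example, assuming a 16x16 f16 `A` operand (128-bit tile rows), this
/// gives (16/8) * (16*16)/128 = 4 registers per thread.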
int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
  int64_t lineSize = inferTileWidthInBits(type);
  auto shape = type.vectorType.getShape();
  return (shape[0] / kNumRowsPerTile) *
         (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
         lineSize;
}

/// Returns the number of 8 x [128|256|512] bit tiles that compose the given
/// operand shape.
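/// For example, a 16x16 f16 operand with 128-bit lines decomposes into
/// {16/8, (16*16)/128} = {2, 2} tiles.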
std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
                                    Type elementType, int64_t lineSizeBits) {
  // For each 8x128bit square, a thread is responsible for one 32bit register.
  return {operandShape[0] / kNumRowsPerTile,
          (operandShape[1] * elementType.getIntOrFloatBitWidth()) /
              lineSizeBits};
}

} // namespace

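/// Collects the vector type and MMA operand role (A, B, or C/accumulator) of
/// `op`, which is expected to be a vector.transfer_read/write,
/// vector.contraction, or arith.constant on the mma.sync conversion path.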
FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op) {
  WarpMatrixInfo info;

  // Determine the vector type.
  if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    info.vectorType = writeOp.getVectorType();
  } else if (isa<vector::TransferReadOp, vector::ContractionOp,
                 arith::ConstantOp>(op)) {
    info.vectorType = op->getResult(0).getType().cast<VectorType>();
  } else {
    return op->emitError()
           << "unhandled operation type in nvgpu.mma.sync conversion path";
  }

  // Determine the operand role. We assume it is an accumulator/result unless it
  // is directly consumed by a `vector.contract` op.
  info.operandRole = MatMulOperandRole::C;
  for (Operation *user : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(user);
    if (!contract)
      continue;
    if (contract.getLhs() == op->getResult(0)) {
      info.operandRole = MatMulOperandRole::A;
      break;
    }
    if (contract.getRhs() == op->getResult(0)) {
      info.operandRole = MatMulOperandRole::B;
      break;
    }
  }
  return info;
}

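/// Returns the bit width of one row of the tile owned by a group of
/// `kThreadsPerRow` threads: 512 for 64-bit accumulators, 256 for 64-bit
/// operands and 32-bit accumulators, and 128 otherwise.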
int64_t inferTileWidthInBits(const WarpMatrixInfo &type) {
  bool isAcc = isAccumulatorOrResult(type.operandRole);
  Type elType = type.vectorType.getElementType();
  if (isAcc && elType.getIntOrFloatBitWidth() == 32) {
    return 256;
  }
  if (elType.getIntOrFloatBitWidth() == 64) {
    return isAcc ? 512 : 256;
  }
  return 128;
}

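/// Describes the registers a single thread holds for an mma.sync fragment of
/// the given operand type (per-register element type, elements per register,
/// register bit width, and number of registers per fragment). Fails for
/// unsupported element types.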
FailureOr<FragmentElementInfo>
getMmaSyncRegisterType(const WarpMatrixInfo &type) {
  MLIRContext *ctx = type.vectorType.getContext();
  const bool isAccum = isAccumulatorOrResult(type.operandRole);

  Type elType = type.vectorType.getElementType();
  if (elType.isF16()) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // f64 operand
  Type f64Ty = Float64Type::get(ctx);
  if (elType.isF64()) {
    return isAccum
               ? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f64Ty, 1, 64,
                                     inferNumRegistersPerMatrixFragment(type)};
  }

  // int8 operand
  if (elType.isInteger(8)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // int4 operand
  if (elType.isInteger(4)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 4), 8), 8, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // Integer 32bit acc operands
  if (elType.isInteger(32)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // Floating point 32bit operands
  if (elType.isF32()) {
    Type f32Ty = Float32Type::get(ctx);
    return isAccum
               ? FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f32Ty, 1, 32,
                                     inferNumRegistersPerMatrixFragment(type)};
  }
  return failure();
}

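/// Returns an affine map, in the (laneId, logicalValueId) dimension space, from
/// the register index implied by `logicalValueId` to the (row, col) offset of
/// the 8-row tile that register belongs to.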
static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
                                                 Type elementType,
                                                 ArrayRef<int64_t> operandShape,
                                                 bool isAccumulator,
                                                 int64_t elementsPerRegister,
                                                 AffineExpr logicalValueId) {
  const int64_t elementsPerLine =
      lineSize / elementType.getIntOrFloatBitWidth();
  const std::array<int64_t, 2> num8x128bTiles =
      getTileShape(operandShape, elementType, lineSize);
  AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
  return AffineMap::get(
      2, 0,
      {(registerIdx % num8x128bTiles[0]) * 8,
       (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},
      elementType.getContext());
}

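/// Returns an affine map (laneId, logicalValueId) -> (row, col) giving the
/// operand element coordinate that a given lane accesses for a given logical
/// element of its mma.sync fragment.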
FailureOr<AffineMap>
getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
                                  const WarpMatrixInfo &fragmentType) {
  Type elementType = fragmentType.vectorType.getElementType();
  ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      getMmaSyncRegisterType(fragmentType);
  if (failed(regInfo))
    return failure();

  const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth();
  const int64_t elementsPerRegister =
      regInfo->registerWidthBits / elementBitWidth;
  const int64_t lineSize = inferTileWidthInBits(fragmentType);

  AffineExpr laneId, logicalValueIdDim;
  bindDims(builder.getContext(), laneId, logicalValueIdDim);

  // Determine what register logicalValueId corresponds to. Use that as a
  // linear index into the coordinate mapping `index -> (tile row, tile col)`.
  AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap(
      lineSize, elementType, operandShape,
      isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
      logicalValueIdDim);

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(2, 0, dimExprs, builder.getContext());
  };

  auto tileRow = registerIndexToTileCoord.getResult(0);
  auto tileCol = registerIndexToTileCoord.getResult(1);
  return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
                  tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
                      (logicalValueIdDim % elementsPerRegister)});
}

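/// Infers nvgpu.ldmatrix parameters (fragment type, target layout, contiguous
/// dimension, and the number of 8x128b tiles to load) for the given operand.
/// Fails when no complete tile fits the operand shape (numTiles would be 0).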
FailureOr<nvgpu::LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
                                                   bool transpose) {
  LdMatrixParams params;
  Type elType = type.vectorType.getElementType();
  params.fragmentType = type.vectorType;
  if (type.operandRole == MatMulOperandRole::A ||
      type.operandRole == MatMulOperandRole::C) {
    params.targetLayout = NVVM::MMALayout::row;
  } else {
    params.targetLayout = NVVM::MMALayout::col;
  }
  ArrayRef<int64_t> shape = type.vectorType.getShape();
  params.contiguousDimType =
      transpose ? IteratorType::Parallel : IteratorType::Reduction;

  if (params.contiguousDimType == IteratorType::Reduction) {
    params.numTiles = (shape[0] / kNumRowsPerTile) *
                      ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
  } else {
    params.numTiles = (shape[1] / kNumRowsPerTile) *
                      ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
  }

  if (params.numTiles == 0)
    return failure();

  return params;
}

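/// Returns an affine map from the lane id to the (row, col) coordinate of the
/// 128-bit segment that lane supplies to nvgpu.ldmatrix, taking the contiguous
/// dimension of the operand into account.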
FailureOr<AffineMap>
getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
                               const LdMatrixParams &params) {
  // One thread per 128b row.
  const int64_t kNumThreadsPerTile = kNumRowsPerTile;
  const int bitsPerElement = static_cast<int>(
      params.fragmentType.getElementType().getIntOrFloatBitWidth());
  const int kElementsPer128b = (128 / bitsPerElement);
  ArrayRef<int64_t> operandShape = params.fragmentType.getShape();
  AffineExpr d0 = getAffineDimExpr(0, builder.getContext());

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(1, 0, dimExprs, builder.getContext());
  };

  // This case corresponds to row-major A|C or col-major B operands.
  if (params.contiguousDimType == IteratorType::Reduction) {
    AffineExpr row = d0 % (operandShape[0]);
    AffineExpr col = d0.floorDiv(operandShape[0]) * (kElementsPer128b);
    return makeMap({row, col});
  }

  // This case corresponds to col-major A|C or row-major B operands. The
  // operandShape given is already pre-transposed (e.g. 8x16 = KxN).
  if (params.contiguousDimType == IteratorType::Parallel) {
    const int64_t num8x128bCols = (operandShape[0] * bitsPerElement) / 128;
    // Threads are assigned in groups of 8 first across columns, then to
    // rows. This is the transpose of what `ldmatrix` expects, but when
    // `ldmatrix` gets the `.trans` qualifier, the final effect will be to
    // transpose just the blocks.
    auto groupIdx = d0.floorDiv(kNumThreadsPerTile);
    auto tileCol = (groupIdx % num8x128bCols);
    auto tileRow = groupIdx.floorDiv(num8x128bCols);
    return makeMap({tileCol * kElementsPer128b,
                    tileRow * kNumRowsPerTile + (d0 % kNumRowsPerTile)});
  }
  return failure();
}

LogicalResult
PrepareContractToGPUMMASync::matchAndRewrite(vector::ContractionOp op,
                                             PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  Value lhs = op.getLhs();
  Value rhs = op.getRhs();
  Value res = op.getAcc();

  // Set up the parallel/reduction structure in the right form.
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m;
  AffineExpr n;
  AffineExpr k;
  bindDims(rewriter.getContext(), m, n, k);
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  auto iteratorTypes = op.getIteratorTypes().getValue();
  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
  if (iteratorTypes.size() != 3)
    return failure();
  if (!(isParallelIterator(iteratorTypes[0]) &&
        isParallelIterator(iteratorTypes[1]) &&
        isReductionIterator(iteratorTypes[2])))
    return failure();

  // The canonical form is "TNT" = A row-major, B col-major, C row-major.
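  // E.g. a plain row-major matmul, with maps {(m, k), (k, n), (m, n)}, is
  // brought into this form below by transposing the B operand into (n, k)
  // order.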
  const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
  if (maps == canonicalForm) {
    return failure();
  }
  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
    std::swap(rhs, lhs);
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
    std::swap(rhs, lhs);
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
    std::swap(lhs, rhs);
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
    std::swap(lhs, rhs);
  } else {
    return failure();
  }
  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
      op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
      op.getIteratorTypes());
  return success();
}

} // namespace nvgpu
} // namespace mlir