//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to GPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"

#include "../PassDetail.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;

// Return true if the contract op can be converted to an MMA matmul.
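// For illustration, a contract of roughly the following shape (f16 operands,
// 16x16x16 tile, row-major (m, k) x (k, n) -> (m, n) maps) is one of the
// supported forms; indexing map attributes are elided here for brevity:
//
//   %d = vector.contract {iterator_types = ["parallel", "parallel", "reduction"], ...}
//          %a, %b, %c
//        : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>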
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) {
  if (llvm::size(contract.masks()) != 0)
    return false;

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(contract.getContext(), m, n, k);
  auto iteratorTypes = contract.iterator_types().getValue();
  if (!(isParallelIterator(iteratorTypes[0]) &&
        isParallelIterator(iteratorTypes[1]) &&
        isReductionIterator(iteratorTypes[2])))
    return false;

  // The contract needs to represent a matmul to be convertible to an
  // MMAMatrix matmul.
  if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
    return false;

  // Check that the sizes match what is natively supported.
  VectorType lhsType = contract.lhs().getType().cast<VectorType>();
  VectorType rhsType = contract.rhs().getType().cast<VectorType>();
  VectorType accType = contract.acc().getType().cast<VectorType>();

  std::tuple<int, int, int> dim(lhsType.getDimSize(0), rhsType.getDimSize(1),
                                lhsType.getDimSize(1));
  if (lhsType.getElementType().isInteger(8) &&
      rhsType.getElementType().isInteger(8) &&
      accType.getElementType().isInteger(32) &&
      (dim == std::make_tuple(8, 8, 32) || dim == std::make_tuple(16, 16, 32) ||
       dim == std::make_tuple(16, 8, 32)))
    return true;

  if (lhsType.getElementType().isF16() && rhsType.getElementType().isF16() &&
      (accType.getElementType().isF16() || accType.getElementType().isF32()) &&
      (dim == std::make_tuple(8, 8, 16) || dim == std::make_tuple(16, 16, 16) ||
       dim == std::make_tuple(16, 8, 16)))
    return true;
  return false;
}

// Return the stride for dimension 0 of |type| if it is a memref and has a
// constant stride.
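// For example, a memref<16x32xf16> with the default row-major layout has
// strides [32, 1], so the returned stride is 32; it is later used as the
// leading dimension of the MMA load/store.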
static llvm::Optional<int64_t>
getMemrefConstantHorizontalStride(ShapedType type) {
  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType)
    return llvm::None;
  int64_t offset = 0;
  SmallVector<int64_t, 2> strides;
  if (failed(getStridesAndOffset(memrefType, strides, offset)))
    return llvm::None;
  if (strides[0] == ShapedType::kDynamicStrideOrOffset)
    return llvm::None;
  return strides[0];
}

// Return true if the transfer op can be converted to an MMA matrix load.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
  if (readOp.mask() || readOp.hasOutOfBoundsDim() ||
      readOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
    return false;
  // TODO: Support transpose once it is added to GPU dialect ops.
  if (!readOp.permutation_map().isMinorIdentity())
    return false;
  return true;
}

// Return true if the transfer op can be converted to an MMA matrix store.
static bool
transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
  if (writeOp.mask() || writeOp.hasOutOfBoundsDim() ||
      writeOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(writeOp.getShapedType()))
    return false;
  // TODO: Support transpose once it is added to GPU dialect ops.
  if (!writeOp.permutation_map().isMinorIdentity())
    return false;
  return true;
}

/// Return true if the constant is a 2D vector splat so that it can be
/// converted to an MMA constant matrix op.
static bool constantSupportsMMAMatrixType(ConstantOp constantOp) {
  auto vecType = constantOp.getType().dyn_cast<VectorType>();
  if (!vecType || vecType.getRank() != 2)
    return false;
  return constantOp.value().isa<SplatElementsAttr>();
}

static bool supportsMMaMatrixType(Operation *op) {
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
    return transferReadSupportsMMAMatrixType(transferRead);
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
    return transferWriteSupportsMMAMatrixType(transferWrite);
  if (auto contract = dyn_cast<vector::ContractionOp>(op))
    return contractSupportsMMAMatrixType(contract);
  if (auto constant = dyn_cast<ConstantOp>(op))
    return constantSupportsMMAMatrixType(constant);
  return false;
}

// Analyze the slice of operations rooted at each contract op to figure out if
// the whole slice can be converted to MMA operations.
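// A typical convertible slice looks like this (schematically, with indices,
// padding, and attributes elided):
//
//   %a = vector.transfer_read %memA[...] : memref<?x?xf16>, vector<16x16xf16>
//   %b = vector.transfer_read %memB[...] : memref<?x?xf16>, vector<16x16xf16>
//   %c = vector.transfer_read %memC[...] : memref<?x?xf16>, vector<16x16xf16>
//   %d = vector.contract ... %a, %b, %c : ... into vector<16x16xf16>
//   vector.transfer_write %d, %memC[...] : vector<16x16xf16>, memref<?x?xf16>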
static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
  auto hasVectorDest = [](Operation *op) {
    return op->getNumResults() == 0 ||
           llvm::any_of(op->getResultTypes(),
                        [](Type t) { return t.isa<VectorType>(); });
  };
  SetVector<Operation *> opToConvert;
  op->walk([&](vector::ContractionOp contract) {
    if (opToConvert.contains(contract.getOperation()))
      return;
    SetVector<Operation *> dependentOps =
        getSlice(contract, hasVectorDest, hasVectorDest);
    // If any op in the slice cannot use the MMA matrix type, drop the whole
    // chain. MMA matrices are stored in an opaque type, so they cannot be used
    // by all operations.
    if (llvm::any_of(dependentOps,
                     [](Operation *op) { return !supportsMMaMatrixType(op); }))
      return;
    opToConvert.insert(dependentOps.begin(), dependentOps.end());
  });
  return opToConvert;
}

namespace {
// Transform a contract op into (m, k) x (k, n) x (m, n) form so that it can be
// converted to an MMA matmul.
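// For example, a contract with indexing maps {(m, k), (n, k), (m, n)} (i.e. a
// transposed RHS) is rewritten by inserting a vector.transpose on the RHS so
// that the resulting contract uses the canonical {(m, k), (k, n), (m, n)} maps
// expected by the MMA matmul.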
struct PrepareContractToGPUMMA
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();

    // Set up the parallel/reduction structure in the right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.iterator_types().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
    if (!(isParallelIterator(iteratorTypes[0]) &&
          isParallelIterator(iteratorTypes[1]) &&
          isReductionIterator(iteratorTypes[2])))
      return failure();
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      // This is the classical row-major matmul, nothing to do.
      return failure();
    }
    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return failure();
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res,
        rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
        op.iterator_types());
    return success();
  }
};

// Merge a transpose op into the transfer read op. Transposes are not supported
// on MMA types, but the MMA load can transpose the matrix when loading.
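// For example (schematically, with padding and attributes elided), the pair
//
//   %v = vector.transfer_read %mem[%i, %j], %pad
//          : memref<?x?xf16>, vector<16x16xf16>
//   %t = vector.transpose %v, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
//
// is folded into a single vector.transfer_read whose permutation map composes
// the transpose, which can later be mapped onto a transposing MMA load.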
struct CombineTransferReadOpTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto transferReadOp = op.vector().getDefiningOp<vector::TransferReadOp>();
    if (!transferReadOp)
      return failure();
    if (transferReadOp.mask() || transferReadOp.hasOutOfBoundsDim())
      return failure();
    SmallVector<int64_t, 2> perm;
    op.getTransp(perm);
    SmallVector<unsigned, 2> permU;
    for (int64_t o : perm)
      permU.push_back(unsigned(o));
    AffineMap permutationMap =
        AffineMap::getPermutationMap(permU, op.getContext());
    AffineMap newMap = permutationMap.compose(transferReadOp.permutation_map());
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        op, op.getType(), transferReadOp.source(), transferReadOp.indices(),
        newMap, transferReadOp.padding(), transferReadOp.mask(),
        transferReadOp.in_boundsAttr());
    return success();
  }
};

} // namespace

// MMA types have different layouts based on how they are used in matmul ops.
// Figure out the right layout to use by looking at op uses.
// TODO: Change the GPU dialect to abstract the layout at this level and only
// care about it during lowering to NVVM.
template <typename OpTy>
static const char *inferFragType(OpTy op) {
  for (Operation *user : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(user);
    if (!contract)
      continue;
    if (contract.lhs() == op.getResult())
      return "AOp";
    if (contract.rhs() == op.getResult())
      return "BOp";
  }
  return "COp";
}

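/// Convert a supported vector.transfer_read into a gpu.subgroup_mma_load_matrix
/// op. The fragment kind ("AOp"/"BOp"/"COp") is chosen by inferFragType and the
/// constant stride of the source memref is used as the leading dimension.
/// Roughly (attributes and padding elided):
///
///   %v = vector.transfer_read %mem[%i, %j], %pad
///          : memref<?x?xf16>, vector<16x16xf16>
///
/// becomes
///
///   %m = gpu.subgroup_mma_load_matrix %mem[%i, %j] {leadDimension = ...}
///          : memref<?x?xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">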
static void convertTransferReadOp(vector::TransferReadOp op,
                                  llvm::DenseMap<Value, Value> &valueMapping) {
  assert(transferReadSupportsMMAMatrixType(op));
  Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  assert(stride);
  const char *fragType = inferFragType(op);
  gpu::MMAMatrixType type =
      gpu::MMAMatrixType::get(op.getVectorType().getShape(),
                              op.getVectorType().getElementType(), fragType);
  OpBuilder b(op);
  Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
      op.getLoc(), type, op.source(), op.indices(), b.getIndexAttr(*stride));
  valueMapping[op.getResult()] = load;
}

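/// Convert a supported vector.transfer_write into a gpu.subgroup_mma_store_matrix
/// op that stores the MMA matrix mapped to the written vector back into the
/// destination memref, again using the memref stride as the leading dimension.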
static void convertTransferWriteOp(vector::TransferWriteOp op,
                                   llvm::DenseMap<Value, Value> &valueMapping) {
  assert(transferWriteSupportsMMAMatrixType(op));
  Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  assert(stride);
  OpBuilder b(op);
  Value matrix = valueMapping.find(op.vector())->second;
  b.create<gpu::SubgroupMmaStoreMatrixOp>(
      op.getLoc(), matrix, op.source(), op.indices(), b.getIndexAttr(*stride));
  op.erase();
}

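/// Convert a vector.contract into a gpu.subgroup_mma_compute op, using the MMA
/// matrices previously created for its lhs, rhs, and acc operands; the result
/// takes the type of the accumulator ("COp") matrix.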
static void convertContractOp(vector::ContractionOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  Value opA = valueMapping.find(op.lhs())->second;
  Value opB = valueMapping.find(op.rhs())->second;
  Value opC = valueMapping.find(op.acc())->second;
  Value matmul = b.create<gpu::SubgroupMmaComputeOp>(op.getLoc(), opC.getType(),
                                                     opA, opB, opC);
  valueMapping[op.getResult()] = matmul;
}

/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
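/// For example (schematically), a splat like
/// `constant dense<0.0> : vector<16x16xf16>` becomes a scalar f16 constant fed
/// into gpu.subgroup_mma_constant_matrix, yielding a
/// !gpu.mma_matrix<16x16xf16, ...> value.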
static void convertConstantOp(ConstantOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  assert(constantSupportsMMAMatrixType(op));
  OpBuilder b(op);
  Attribute splat = op.getValue().cast<SplatElementsAttr>().getSplatValue();
  auto scalarConstant =
      b.create<ConstantOp>(op.getLoc(), splat.getType(), splat);
  const char *fragType = inferFragType(op);
  auto vecType = op.getType().cast<VectorType>();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
                                                           scalarConstant);
  valueMapping[op.getResult()] = matrix;
}

namespace mlir {

void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) {
  patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
      patterns.getContext());
}

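/// Convert every slice of vector ops rooted at a vector.contract in `funcOp`
/// to the corresponding gpu.subgroup_mma_* ops, keeping a mapping from the
/// original vector values to the newly created MMA matrix values.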
void convertVectorToMMAOps(FuncOp funcOp) {
  SetVector<Operation *> ops = getOpToConvert(funcOp);
  llvm::DenseMap<Value, Value> valueMapping;
  for (Operation *op : ops) {
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
      convertTransferReadOp(transferRead, valueMapping);
    } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
      convertTransferWriteOp(transferWrite, valueMapping);
    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
      convertContractOp(contractOp, valueMapping);
    } else if (auto constantOp = dyn_cast<ConstantOp>(op)) {
      convertConstantOp(constantOp, valueMapping);
    }
  }
}

} // namespace mlir

namespace {

struct ConvertVectorToGPUPass
    : public ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
  void runOnFunction() override {
    RewritePatternSet patterns(getFunction().getContext());
    populatePrepareVectorToMMAPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));

    convertVectorToMMAOps(getFunction());
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass() {
  return std::make_unique<ConvertVectorToGPUPass>();
}