//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to GPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"

#include "../PassDetail.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;

// Return true if the contraction op can be converted to an MMA matmul.
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) {
  if (llvm::size(contract.masks()) != 0)
    return false;

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(contract.getContext(), m, n, k);
  auto iteratorTypes = contract.iterator_types().getValue();
  if (!(isParallelIterator(iteratorTypes[0]) &&
        isParallelIterator(iteratorTypes[1]) &&
        isReductionIterator(iteratorTypes[2])))
    return false;

  // The contraction needs to represent a row-major matmul in order to be
  // convertible to an MMAMatrix matmul.
  if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
    return false;

  // Check that the sizes match what is natively supported.
  VectorType lhsType = contract.lhs().getType().cast<VectorType>();
  VectorType rhsType = contract.rhs().getType().cast<VectorType>();
  VectorType accType = contract.acc().getType().cast<VectorType>();

  std::tuple<int, int, int> dim(lhsType.getDimSize(0), rhsType.getDimSize(1),
                                lhsType.getDimSize(1));
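  // Natively supported (m, n, k) shapes for 8-bit integer operands with a
  // 32-bit integer accumulator.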
  if (lhsType.getElementType().isInteger(8) &&
      rhsType.getElementType().isInteger(8) &&
      accType.getElementType().isInteger(32) &&
      (dim == std::make_tuple(8, 8, 32) || dim == std::make_tuple(16, 16, 32) ||
       dim == std::make_tuple(16, 8, 32)))
    return true;

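  // Natively supported (m, n, k) shapes for 16-bit float operands with a
  // 16-bit or 32-bit float accumulator.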
  if (lhsType.getElementType().isF16() && rhsType.getElementType().isF16() &&
      (accType.getElementType().isF16() || accType.getElementType().isF32()) &&
      (dim == std::make_tuple(8, 8, 16) || dim == std::make_tuple(16, 16, 16) ||
       dim == std::make_tuple(16, 8, 16)))
    return true;
  return false;
}

// Return the stride of dimension 0 of |type| if it is a memref with a
// constant stride.
static llvm::Optional<int64_t>
getMemrefConstantHorizontalStride(ShapedType type) {
  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType)
    return llvm::None;
  int64_t offset = 0;
  SmallVector<int64_t, 2> strides;
  if (failed(getStridesAndOffset(memrefType, strides, offset)))
    return llvm::None;
  if (strides[0] == ShapedType::kDynamicStrideOrOffset)
    return llvm::None;
  return strides[0];
}

// Return true if the transfer op can be converted to an MMA matrix load.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
  if (readOp.mask() || readOp.hasOutOfBoundsDim() ||
      readOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
    return false;
  // TODO: Support transpose once it is added to GPU dialect ops.
  if (!readOp.permutation_map().isMinorIdentity())
    return false;
  return true;
}

// Return true if the transfer op can be converted to an MMA matrix store.
static bool
transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
  if (writeOp.mask() || writeOp.hasOutOfBoundsDim() ||
      writeOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(writeOp.getShapedType()))
    return false;
  // TODO: Support transpose once it is added to GPU dialect ops.
  if (!writeOp.permutation_map().isMinorIdentity())
    return false;
  return true;
}

static bool supportsMMaMatrixType(Operation *op) {
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
    return transferReadSupportsMMAMatrixType(transferRead);
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
    return transferWriteSupportsMMAMatrixType(transferWrite);
  if (auto contract = dyn_cast<vector::ContractionOp>(op))
    return contractSupportsMMAMatrixType(contract);
  return false;
}

// Analyze the slice of operations rooted at each vector.contract op to figure
// out if the whole slice can be converted to MMA operations.
static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
  auto hasVectorDest = [](Operation *op) {
    return op->getNumResults() == 0 ||
           llvm::any_of(op->getResultTypes(),
                        [](Type t) { return t.isa<VectorType>(); });
  };
  SetVector<Operation *> opToConvert;
  op->walk([&](vector::ContractionOp contract) {
    if (opToConvert.contains(contract.getOperation()))
      return;
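    // Gather the producer and consumer operations of the contraction that
    // yield vector results (or have no results).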
    SetVector<Operation *> dependentOps =
        getSlice(contract, hasVectorDest, hasVectorDest);
    // If any operation in the slice cannot use the MMA matrix type, drop the
    // whole chain. MMA matrices are stored in an opaque type, so they cannot
    // be used by all operations.
    if (llvm::any_of(dependentOps,
                     [](Operation *op) { return !supportsMMaMatrixType(op); }))
      return;
    opToConvert.insert(dependentOps.begin(), dependentOps.end());
  });
  return opToConvert;
}

namespace {
// Transform a vector.contract into (m, k) x (k, n) x (m, n) form so that it
// can be converted to an MMA matmul.
struct PrepareContractToGPUMMA
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();

    // Set up the parallel/reduction structure in the right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
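    // Permutation used with vector.transpose to swap the two dimensions of an
    // operand.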
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.iterator_types().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
    if (!(isParallelIterator(iteratorTypes[0]) &&
          isParallelIterator(iteratorTypes[1]) &&
          isReductionIterator(iteratorTypes[2])))
      return failure();
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      // This is the classical row-major matmul, nothing to do.
      return failure();
    }
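    // All other layouts are rewritten into the canonical row-major form by
    // transposing and/or swapping the operands.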
    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return failure();
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res,
        rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
        op.iterator_types());
    return success();
  }
};

// Merge a transpose op into its transfer read op. Transposes are not
// supported on MMA types, but an MMA load can transpose the matrix when
// loading.
struct CombineTransferReadOpTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto transferReadOp = op.vector().getDefiningOp<vector::TransferReadOp>();
    if (!transferReadOp)
      return failure();
    if (transferReadOp.mask() || transferReadOp.hasOutOfBoundsDim())
      return failure();
    SmallVector<int64_t, 2> perm;
    op.getTransp(perm);
    SmallVector<unsigned, 2> permU;
    for (int64_t o : perm)
      permU.push_back(unsigned(o));
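    // Fold the transpose into the read by composing its permutation with the
    // read's permutation map.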
    AffineMap permutationMap =
        AffineMap::getPermutationMap(permU, op.getContext());
    AffineMap newMap = permutationMap.compose(transferReadOp.permutation_map());
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        op, op.getType(), transferReadOp.source(), transferReadOp.indices(),
        newMap, transferReadOp.padding(), transferReadOp.mask(),
        transferReadOp.in_boundsAttr());
    return success();
  }
};

} // namespace

// MMA types have different layouts based on how they are used in matmul ops.
// Figure out the right layout to use by looking at the transfer op uses.
// TODO: Change the GPU dialect to abstract the layout at this level and only
// care about it during lowering to NVVM.
static const char *inferFragType(vector::TransferReadOp op) {
  for (Operation *user : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(user);
    if (!contract)
      continue;
    if (contract.lhs() == op.getResult())
      return "AOp";
    if (contract.rhs() == op.getResult())
      return "BOp";
  }
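  // Not consumed as the A or B operand of any contraction: treat it as an
  // accumulator (C) fragment.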
  return "COp";
}

static void convertTransferReadOp(vector::TransferReadOp op,
                                  llvm::DenseMap<Value, Value> &valueMapping) {
  assert(transferReadSupportsMMAMatrixType(op));
  Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  assert(stride);
  const char *fragType = inferFragType(op);
  gpu::MMAMatrixType type =
      gpu::MMAMatrixType::get(op.getVectorType().getShape(),
                              op.getVectorType().getElementType(), fragType);
  OpBuilder b(op);
  Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
      op.getLoc(), type, op.source(), op.indices(), b.getIndexAttr(*stride));
  valueMapping[op.getResult()] = load;
}

static void convertTransferWriteOp(vector::TransferWriteOp op,
                                   llvm::DenseMap<Value, Value> &valueMapping) {
  assert(transferWriteSupportsMMAMatrixType(op));
  Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  assert(stride);
  OpBuilder b(op);
  Value matrix = valueMapping.find(op.vector())->second;
  b.create<gpu::SubgroupMmaStoreMatrixOp>(
      op.getLoc(), matrix, op.source(), op.indices(), b.getIndexAttr(*stride));
  op.erase();
}

static void convertContractOp(vector::ContractionOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
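  // The operands are looked up in valueMapping, which holds the
  // gpu.mma_matrix values created for the vector values in the slice.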
  Value opA = valueMapping.find(op.lhs())->second;
  Value opB = valueMapping.find(op.rhs())->second;
  Value opC = valueMapping.find(op.acc())->second;
  Value matmul = b.create<gpu::SubgroupMmaComputeOp>(op.getLoc(), opC.getType(),
                                                     opA, opB, opC);
  valueMapping[op.getResult()] = matmul;
}

namespace mlir {

void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) {
  patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
      patterns.getContext());
}

void convertVectorToMMAOps(FuncOp funcOp) {
  SetVector<Operation *> ops = getOpToConvert(funcOp);
  llvm::DenseMap<Value, Value> valueMapping;
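  // valueMapping tracks, for each converted vector value, the gpu.mma_matrix
  // value that replaces it.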
  for (Operation *op : ops) {
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
      convertTransferReadOp(transferRead, valueMapping);
    } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
      convertTransferWriteOp(transferWrite, valueMapping);
    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
      convertContractOp(contractOp, valueMapping);
    }
  }
}

} // namespace mlir

namespace {

struct ConvertVectorToGPUPass
    : public ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
  void runOnFunction() override {
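    // First massage contractions and transposes into a form the MMA
    // conversion handles, then rewrite the supported slices into
    // gpu.subgroup_mma operations.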
    RewritePatternSet patterns(getFunction().getContext());
    populatePrepareVectorToMMAPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));

    convertVectorToMMAOps(getFunction());
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass() {
  return std::make_unique<ConvertVectorToGPUPass>();
}