//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to GPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"

#include "../PassDetail.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;

// Return true if the contract op can be converted to an MMA matmul.
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) {
  if (llvm::size(contract.masks()) != 0)
    return false;

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(contract.getContext(), m, n, k);
  auto iteratorTypes = contract.iterator_types().getValue();
  if (!(isParallelIterator(iteratorTypes[0]) &&
        isParallelIterator(iteratorTypes[1]) &&
        isReductionIterator(iteratorTypes[2])))
    return false;

  // The contraction needs to represent a matmul to be convertible to an
  // MMAMatrix matmul.
  if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
    return false;

  // Check that the size matches what is natively supported.
  VectorType lhsType = contract.lhs().getType().cast<VectorType>();
  VectorType rhsType = contract.rhs().getType().cast<VectorType>();
  VectorType accType = contract.acc().getType().cast<VectorType>();

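  // dim holds (m, n, k): the LHS rows, the RHS columns, and the shared
  // contraction dimension, matched below against the natively supported
  // MMA shapes.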
  std::tuple<int, int, int> dim(lhsType.getDimSize(0), rhsType.getDimSize(1),
                                lhsType.getDimSize(1));
  if (lhsType.getElementType().isInteger(8) &&
      rhsType.getElementType().isInteger(8) &&
      accType.getElementType().isInteger(32) &&
      (dim == std::make_tuple(8, 8, 32) || dim == std::make_tuple(16, 16, 32) ||
       dim == std::make_tuple(16, 8, 32)))
    return true;

  if (lhsType.getElementType().isF16() && rhsType.getElementType().isF16() &&
      (accType.getElementType().isF16() || accType.getElementType().isF32()) &&
      (dim == std::make_tuple(8, 8, 16) || dim == std::make_tuple(16, 16, 16) ||
       dim == std::make_tuple(16, 8, 16)))
    return true;
  return false;
}

// Return the stride of dimension 0 of |type| if it is a memref with a
// constant stride.
static llvm::Optional<int64_t>
getMemrefConstantHorizontalStride(ShapedType type) {
  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType)
    return llvm::None;
  int64_t offset = 0;
  SmallVector<int64_t, 2> strides;
  if (failed(getStridesAndOffset(memrefType, strides, offset)))
    return llvm::None;
  if (strides[0] == ShapedType::kDynamicStrideOrOffset)
    return llvm::None;
  return strides[0];
}
// Return true if the transfer op can be converted to an MMA matrix load.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
  if (readOp.mask() || readOp.hasOutOfBoundsDim() ||
      readOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
    return false;
  // TODO: Support transpose once it is added to GPU dialect ops.
  if (!readOp.permutation_map().isMinorIdentity())
    return false;
  return true;
}

// Return true if the transfer op can be converted to an MMA matrix store.
static bool
transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
  if (writeOp.mask() || writeOp.hasOutOfBoundsDim() ||
      writeOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(writeOp.getShapedType()))
    return false;
  // TODO: Support transpose once it is added to GPU dialect ops.
  if (!writeOp.permutation_map().isMinorIdentity())
    return false;
  return true;
}

/// Return true if the constant is a splat of a 2D vector so that it can be
/// converted to an MMA constant matrix op.
static bool constantSupportsMMAMatrixType(ConstantOp constantOp) {
  auto vecType = constantOp.getType().dyn_cast<VectorType>();
  if (!vecType || vecType.getRank() != 2)
    return false;
  return constantOp.value().isa<SplatElementsAttr>();
}

static bool supportsMMaMatrixType(Operation *op) {
  if (isa<scf::ForOp, scf::YieldOp>(op))
    return true;
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
    return transferReadSupportsMMAMatrixType(transferRead);
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
    return transferWriteSupportsMMAMatrixType(transferWrite);
  if (auto contract = dyn_cast<vector::ContractionOp>(op))
    return contractSupportsMMAMatrixType(contract);
  if (auto constant = dyn_cast<ConstantOp>(op))
    return constantSupportsMMAMatrixType(constant);
  return false;
}

// Analyze the slice of operations anchored on each contraction op to figure
// out if the whole slice can be converted to MMA operations.
static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
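  // An op belongs to the slice if it produces at least one vector result or
  // has no results at all (e.g. a terminating transfer_write).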
  auto hasVectorDest = [](Operation *op) {
    return op->getNumResults() == 0 ||
           llvm::any_of(op->getResultTypes(),
                        [](Type t) { return t.isa<VectorType>(); });
  };
  SetVector<Operation *> opToConvert;
  op->walk([&](vector::ContractionOp contract) {
    if (opToConvert.contains(contract.getOperation()))
      return;
    SetVector<Operation *> dependentOps =
        getSlice(contract, hasVectorDest, hasVectorDest);
    // If any op in the slice cannot use the MMA matrix type, drop the whole
    // chain. MMA matrices are stored in an opaque type so they cannot be used
    // by all operations.
    if (llvm::any_of(dependentOps,
                     [](Operation *op) { return !supportsMMaMatrixType(op); }))
      return;
    opToConvert.insert(dependentOps.begin(), dependentOps.end());
  });
  return opToConvert;
}

namespace {
// Transform a contraction into (m, k)x(k, n)x(m, n) form so that it can be
// converted to an MMA matmul.
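// For example, a contraction whose indexing maps are
// {(m, k), (n, k), (m, n)} (i.e. a transposed RHS) gets an explicit
// vector.transpose on the RHS so the maps become the canonical
// {(m, k), (k, n), (m, n)}.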
struct PrepareContractToGPUMMA
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();

    // Set up the parallel/reduction structure in the right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.iterator_types().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
    if (!(isParallelIterator(iteratorTypes[0]) &&
          isParallelIterator(iteratorTypes[1]) &&
          isReductionIterator(iteratorTypes[2])))
      return failure();
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      // This is the classical row-major matmul, nothing to do.
      return failure();
    }
    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return failure();
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res,
        rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
        op.iterator_types());
    return success();
  }
};

// Merge the transpose op into the transfer read op. Transposes are not
// supported on MMA types, but the MMA load can transpose the matrix when
// loading.
struct CombineTransferReadOpTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto transferReadOp = op.vector().getDefiningOp<vector::TransferReadOp>();
    if (!transferReadOp)
      return failure();
    if (transferReadOp.mask() || transferReadOp.hasOutOfBoundsDim())
      return failure();
    SmallVector<int64_t, 2> perm;
    op.getTransp(perm);
    SmallVector<unsigned, 2> permU;
    for (int64_t o : perm)
      permU.push_back(unsigned(o));
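    // Fold the transpose into the read by composing its permutation with the
    // read's permutation map; the resulting transfer read produces the
    // transposed vector directly.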
    AffineMap permutationMap =
        AffineMap::getPermutationMap(permU, op.getContext());
    AffineMap newMap = permutationMap.compose(transferReadOp.permutation_map());
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        op, op.getType(), transferReadOp.source(), transferReadOp.indices(),
        newMap, transferReadOp.padding(), transferReadOp.mask(),
        transferReadOp.in_boundsAttr());
    return success();
  }
};

} // namespace

// MMA types have a different layout depending on how they are used in matmul
// ops. Figure out the right layout to use by looking at op uses.
// TODO: Change the GPU dialect to abstract the layout at this level and only
// care about it during lowering to NVVM.
template <typename OpTy>
static const char *inferFragType(OpTy op) {
  for (Operation *user : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(user);
    if (!contract)
      continue;
    if (contract.lhs() == op.getResult())
      return "AOp";
    if (contract.rhs() == op.getResult())
      return "BOp";
  }
  return "COp";
}

static void convertTransferReadOp(vector::TransferReadOp op,
                                  llvm::DenseMap<Value, Value> &valueMapping) {
  assert(transferReadSupportsMMAMatrixType(op));
  Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  assert(stride);
  const char *fragType = inferFragType(op);
  gpu::MMAMatrixType type =
      gpu::MMAMatrixType::get(op.getVectorType().getShape(),
                              op.getVectorType().getElementType(), fragType);
  OpBuilder b(op);
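  // Replace the transfer read with an MMA load, passing the constant row
  // stride of the source memref as the leading dimension.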
  Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
      op.getLoc(), type, op.source(), op.indices(), b.getIndexAttr(*stride));
  valueMapping[op.getResult()] = load;
}

static void convertTransferWriteOp(vector::TransferWriteOp op,
                                   llvm::DenseMap<Value, Value> &valueMapping) {
  assert(transferWriteSupportsMMAMatrixType(op));
  Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  assert(stride);
  OpBuilder b(op);
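  // Look up the MMA matrix created for the stored vector and emit an MMA
  // store with the constant row stride as the leading dimension.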
  Value matrix = valueMapping.find(op.vector())->second;
  b.create<gpu::SubgroupMmaStoreMatrixOp>(
      op.getLoc(), matrix, op.source(), op.indices(), b.getIndexAttr(*stride));
  op.erase();
}

static void convertContractOp(vector::ContractionOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
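  // All three operands have already been converted to MMA matrices; the
  // compute op multiplies A and B and accumulates into the C fragment.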
  Value opA = valueMapping.find(op.lhs())->second;
  Value opB = valueMapping.find(op.rhs())->second;
  Value opC = valueMapping.find(op.acc())->second;
  Value matmul = b.create<gpu::SubgroupMmaComputeOp>(op.getLoc(), opC.getType(),
                                                     opA, opB, opC);
  valueMapping[op.getResult()] = matmul;
}

/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static void convertConstantOp(ConstantOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  assert(constantSupportsMMAMatrixType(op));
  OpBuilder b(op);
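  // Materialize the splat value as a scalar constant and broadcast it into
  // an MMA constant matrix of the matching fragment type.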
  Attribute splat = op.getValue().cast<SplatElementsAttr>().getSplatValue();
  auto scalarConstant =
      b.create<ConstantOp>(op.getLoc(), splat.getType(), splat);
  const char *fragType = inferFragType(op);
  auto vecType = op.getType().cast<VectorType>();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
                                                           scalarConstant);
  valueMapping[op.getResult()] = matrix;
}

// Replace the ForOp with a new ForOp that has extra operands. The YieldOp is
// not updated and needs to be updated separately for the loop to be correct.
static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
                                               ValueRange newIterOperands) {
  // Create a new loop before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(loop);
  auto operands = llvm::to_vector<4>(loop.getIterOperands());
  operands.append(newIterOperands.begin(), newIterOperands.end());
  scf::ForOp newLoop =
      b.create<scf::ForOp>(loop.getLoc(), loop.lowerBound(), loop.upperBound(),
                           loop.step(), operands);
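  // Move the body of the old loop into the new loop and append block
  // arguments for the extra iter_args.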
  newLoop.getBody()->erase();
  newLoop.getLoopBody().getBlocks().splice(
      newLoop.getLoopBody().getBlocks().begin(),
      loop.getLoopBody().getBlocks());
  for (auto operand : newIterOperands)
    newLoop.getBody()->addArgument(operand.getType());

  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
                                                  loop.getNumResults())))
    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
  loop.erase();
  return newLoop;
}

static void convertForOp(scf::ForOp op,
                         llvm::DenseMap<Value, Value> &valueMapping) {
  SmallVector<Value> newOperands;
  SmallVector<std::pair<size_t, size_t>> argMapping;
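  // For every iter operand that has an MMA counterpart, record the pair
  // (original index, index of the appended operand carrying the MMA value).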
  for (auto operand : llvm::enumerate(op.getIterOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;
    argMapping.push_back(std::make_pair(
        operand.index(), op.getNumIterOperands() + newOperands.size()));
    newOperands.push_back(it->second);
  }
  OpBuilder b(op);
  scf::ForOp newForOp = replaceForOpWithNewSignature(b, op, newOperands);
  Block &loopBody = *newForOp.getBody();
  for (auto mapping : argMapping) {
    valueMapping[newForOp.getResult(mapping.first)] =
        newForOp.getResult(mapping.second);
    valueMapping[loopBody.getArgument(mapping.first +
                                      newForOp.getNumInductionVars())] =
        loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
  }
}

static void convertYieldOp(scf::YieldOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  auto loop = cast<scf::ForOp>(op->getParentOp());
  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
  for (auto operand : llvm::enumerate(op.getOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;
    // Replace the yield of the old value with the for op argument to make it
    // easier to remove the dead code.
    yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
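    // Yield the mapped MMA value through the extra iter_args appended to
    // the loop.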
    yieldOperands.push_back(it->second);
  }
  b.create<scf::YieldOp>(op.getLoc(), yieldOperands);
  op.erase();
}

namespace mlir {

void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) {
  patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
      patterns.getContext());
}

void convertVectorToMMAOps(FuncOp funcOp) {
  SetVector<Operation *> ops = getOpToConvert(funcOp);
  llvm::DenseMap<Value, Value> valueMapping;
  for (Operation *op : ops) {
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
      convertTransferReadOp(transferRead, valueMapping);
    } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
      convertTransferWriteOp(transferWrite, valueMapping);
    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
      convertContractOp(contractOp, valueMapping);
    } else if (auto constantOp = dyn_cast<ConstantOp>(op)) {
      convertConstantOp(constantOp, valueMapping);
    } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      convertForOp(forOp, valueMapping);
    } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
      convertYieldOp(yieldOp, valueMapping);
    }
  }
}

} // namespace mlir

namespace {

struct ConvertVectorToGPUPass
    : public ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
  void runOnFunction() override {
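    // First rewrite contractions and transposes into the canonical form
    // expected by the MMA conversion, then convert the resulting slices to
    // GPU MMA ops.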
    RewritePatternSet patterns(getFunction().getContext());
    populatePrepareVectorToMMAPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));

    convertVectorToMMAOps(getFunction());
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass() {
  return std::make_unique<ConvertVectorToGPUPass>();
}