//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to GPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"

#include "../PassDetail.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;

// Return true if the contract op can be converted to an MMA matmul.
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) {
  if (llvm::size(contract.masks()) != 0)
    return false;

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(contract.getContext(), m, n, k);
  auto iteratorTypes = contract.iterator_types().getValue();
  if (!(isParallelIterator(iteratorTypes[0]) &&
        isParallelIterator(iteratorTypes[1]) &&
        isReductionIterator(iteratorTypes[2])))
    return false;

  // The contract needs to represent a matmul to be convertible to an
  // MMAMatrix matmul.
  if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
    return false;

  // Check that the size matches what is natively supported.
  VectorType lhsType = contract.lhs().getType().cast<VectorType>();
  VectorType rhsType = contract.rhs().getType().cast<VectorType>();
  VectorType accType = contract.acc().getType().cast<VectorType>();

  std::tuple<int, int, int> dim(lhsType.getDimSize(0), rhsType.getDimSize(1),
                                lhsType.getDimSize(1));
  if (lhsType.getElementType().isInteger(8) &&
      rhsType.getElementType().isInteger(8) &&
      accType.getElementType().isInteger(32) &&
      (dim == std::make_tuple(8, 8, 32) || dim == std::make_tuple(16, 16, 32) ||
       dim == std::make_tuple(16, 8, 32)))
    return true;

  if (lhsType.getElementType().isF16() && rhsType.getElementType().isF16() &&
      (accType.getElementType().isF16() || accType.getElementType().isF32()) &&
      (dim == std::make_tuple(8, 8, 16) || dim == std::make_tuple(16, 16, 16) ||
       dim == std::make_tuple(16, 8, 16)))
    return true;
  return false;
}
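
// As an illustrative sketch (not taken from the code above), a contraction
// accepted by this check would look roughly like the following, where the
// %a/%b/%c values and the 16x16 f16 shapes are hypothetical but correspond to
// one of the natively supported size/type combinations:
//
//   #map_a = affine_map<(d0, d1, d2) -> (d0, d2)>  // (m, n, k) -> (m, k)
//   #map_b = affine_map<(d0, d1, d2) -> (d2, d1)>  // (m, n, k) -> (k, n)
//   #map_c = affine_map<(d0, d1, d2) -> (d0, d1)>  // (m, n, k) -> (m, n)
//   %d = vector.contract {
//          indexing_maps = [#map_a, #map_b, #map_c],
//          iterator_types = ["parallel", "parallel", "reduction"]}
//        %a, %b, %c
//        : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>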

// Return the stride for dimension 0 of |type| if it is a memref and has a
// constant stride.
static llvm::Optional<int64_t>
getMemrefConstantHorizontalStride(ShapedType type) {
  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType)
    return llvm::None;
  int64_t offset = 0;
  SmallVector<int64_t, 2> strides;
  if (failed(getStridesAndOffset(memrefType, strides, offset)))
    return llvm::None;
  if (strides[0] == ShapedType::kDynamicStrideOrOffset)
    return llvm::None;
  return strides[0];
}

// Return true if the transfer op can be converted to an MMA matrix load.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
  if (readOp.mask() || readOp.hasOutOfBoundsDim() ||
      readOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
    return false;
  // TODO: Support transpose once it is added to GPU dialect ops.
  if (!readOp.permutation_map().isMinorIdentity())
    return false;
  return true;
}

// Return true if the transfer op can be converted to an MMA matrix store.
static bool
transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
  if (writeOp.mask() || writeOp.hasOutOfBoundsDim() ||
      writeOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(writeOp.getShapedType()))
    return false;
  // TODO: Support transpose once it is added to GPU dialect ops.
  if (!writeOp.permutation_map().isMinorIdentity())
    return false;
  return true;
}

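// Return true if the operation can be converted to an MMA matrix operation,
// i.e. it is a supported transfer read, transfer write or contraction.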
static bool supportsMMaMatrixType(Operation *op) {
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
    return transferReadSupportsMMAMatrixType(transferRead);
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
    return transferWriteSupportsMMAMatrixType(transferWrite);
  if (auto contract = dyn_cast<vector::ContractionOp>(op))
    return contractSupportsMMAMatrixType(contract);
  return false;
}

// Analyze the slice of operations around each contraction op to figure out if
// the whole slice can be converted to MMA operations.
static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
  auto hasVectorDest = [](Operation *op) {
    return op->getNumResults() == 0 ||
           llvm::any_of(op->getResultTypes(),
                        [](Type t) { return t.isa<VectorType>(); });
  };
  SetVector<Operation *> opToConvert;
  op->walk([&](vector::ContractionOp contract) {
    if (opToConvert.contains(contract.getOperation()))
      return;
    SetVector<Operation *> dependentOps =
        getSlice(contract, hasVectorDest, hasVectorDest);
    // If any instruction cannot use the MMA matrix type, drop the whole
    // chain. MMA matrices are stored in an opaque type so they cannot be used
    // by all operations.
    if (llvm::any_of(dependentOps,
                     [](Operation *op) { return !supportsMMaMatrixType(op); }))
      return;
    opToConvert.insert(dependentOps.begin(), dependentOps.end());
  });
  return opToConvert;
}

namespace {
// Transform a contraction into the (m, k) x (k, n) -> (m, n) form so that it
// can be converted to an MMA matmul.
struct PrepareContractToGPUMMA
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();

    // Set up the parallel/reduction structure in the right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.iterator_types().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
    if (!(isParallelIterator(iteratorTypes[0]) &&
          isParallelIterator(iteratorTypes[1]) &&
          isReductionIterator(iteratorTypes[2])))
      return failure();
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      // This is the classical row-major matmul, nothing to do.
      return failure();
    }
    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return failure();
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res,
        rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
        op.iterator_types());
    return success();
  }
};

// Merge a transpose op into the transfer read op. Transposes are not supported
// on MMA types, but the MMA load can transpose the matrix when loading.
struct CombineTransferReadOpTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto transferReadOp = op.vector().getDefiningOp<vector::TransferReadOp>();
    if (!transferReadOp)
      return failure();
    if (transferReadOp.mask() || transferReadOp.hasOutOfBoundsDim())
      return failure();
    SmallVector<int64_t, 2> perm;
    op.getTransp(perm);
    SmallVector<unsigned, 2> permU;
    for (int64_t o : perm)
      permU.push_back(unsigned(o));
    AffineMap permutationMap =
        AffineMap::getPermutationMap(permU, op.getContext());
    AffineMap newMap =
        permutationMap.compose(transferReadOp.permutation_map());
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        op, op.getType(), transferReadOp.source(), transferReadOp.indices(),
        newMap, transferReadOp.padding(), transferReadOp.mask(),
        transferReadOp.in_boundsAttr());
    return success();
  }
};

} // namespace

// MMA types have a different layout based on how they are used in matmul ops.
// Figure out the right layout to use by looking at the transfer op uses.
// TODO: Change the GPU dialect to abstract the layout at this level and only
// care about it during lowering to NVVM.
static const char *inferFragType(vector::TransferReadOp op) {
  for (Operation *user : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(user);
    if (!contract)
      continue;
    if (contract.lhs() == op.getResult())
      return "AOp";
    if (contract.rhs() == op.getResult())
      return "BOp";
  }
  return "COp";
}

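// Convert a supported vector.transfer_read into a gpu::SubgroupMmaLoadMatrixOp
// and record the resulting MMA matrix value in |valueMapping|.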
static void convertTransferReadOp(vector::TransferReadOp op,
                                  llvm::DenseMap<Value, Value> &valueMapping) {
  assert(transferReadSupportsMMAMatrixType(op));
  Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  assert(stride);
  const char *fragType = inferFragType(op);
  gpu::MMAMatrixType type =
      gpu::MMAMatrixType::get(op.getVectorType().getShape(),
                              op.getVectorType().getElementType(), fragType);
  OpBuilder b(op);
  Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
      op.getLoc(), type, op.source(), op.indices(), b.getIndexAttr(*stride));
  valueMapping[op.getResult()] = load;
}

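// Convert a supported vector.transfer_write into a
// gpu::SubgroupMmaStoreMatrixOp, storing the MMA matrix previously recorded
// for the written vector.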
static void convertTransferWriteOp(vector::TransferWriteOp op,
                                   llvm::DenseMap<Value, Value> &valueMapping) {
  assert(transferWriteSupportsMMAMatrixType(op));
  Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  assert(stride);
  OpBuilder b(op);
  Value matrix = valueMapping.find(op.vector())->second;
  b.create<gpu::SubgroupMmaStoreMatrixOp>(
      op.getLoc(), matrix, op.source(), op.indices(), b.getIndexAttr(*stride));
  op.erase();
}

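// Convert a vector.contract into a gpu::SubgroupMmaComputeOp using the MMA
// matrices previously recorded for its operands.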
static void convertContractOp(vector::ContractionOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  Value opA = valueMapping.find(op.lhs())->second;
  Value opB = valueMapping.find(op.rhs())->second;
  Value opC = valueMapping.find(op.acc())->second;
  Value matmul = b.create<gpu::SubgroupMmaComputeOp>(op.getLoc(), opC.getType(),
                                                     opA, opB, opC);
  valueMapping[op.getResult()] = matmul;
}

namespace mlir {

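// Populate patterns that put contractions and transfer reads into the form
// expected by the MMA conversion (see PrepareContractToGPUMMA and
// CombineTransferReadOpTranspose above).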
void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) {
  patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
      patterns.getContext());
}

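// Convert the supported vector ops in |funcOp| to GPU MMA ops. Ops that are
// not part of a fully convertible slice are left untouched.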
void convertVectorToMMAOps(FuncOp funcOp) {
  SetVector<Operation *> ops = getOpToConvert(funcOp);
  llvm::DenseMap<Value, Value> valueMapping;
  for (Operation *op : ops) {
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
      convertTransferReadOp(transferRead, valueMapping);
    } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
      convertTransferWriteOp(transferWrite, valueMapping);
    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
      convertContractOp(contractOp, valueMapping);
    }
  }
}

} // namespace mlir

namespace {

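// Pass that first applies the preparation patterns and then converts the
// supported vector operations to GPU MMA operations.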
struct ConvertVectorToGPUPass
    : public ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
  void runOnFunction() override {
    RewritePatternSet patterns(getFunction().getContext());
    populatePrepareVectorToMMAPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));

    convertVectorToMMAOps(getFunction());
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass() {
  return std::make_unique<ConvertVectorToGPUPass>();
}