//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to GPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"

#include "../PassDetail.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;

// Return true if the contract op can be converted to an MMA matmul.
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) {
  if (llvm::size(contract.masks()) != 0)
    return false;

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(contract.getContext(), m, n, k);
  auto iteratorTypes = contract.iterator_types().getValue();
  if (!(isParallelIterator(iteratorTypes[0]) &&
        isParallelIterator(iteratorTypes[1]) &&
        isReductionIterator(iteratorTypes[2])))
    return false;

  // The contract needs to represent a matmul to be able to convert to
  // MMAMatrix matmul.
  if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
    return false;

  // Check that the size matches what is natively supported.
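  // The (m, n, k) tuples checked below are the tile shapes accepted natively:
  // for example, an f16 contraction with m = n = k = 16 multiplies a 16x16
  // lhs by a 16x16 rhs into a 16x16 f16 or f32 accumulator.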
  VectorType lhsType = contract.lhs().getType().cast<VectorType>();
  VectorType rhsType = contract.rhs().getType().cast<VectorType>();
  VectorType accType = contract.acc().getType().cast<VectorType>();

  std::tuple<int, int, int> dim(lhsType.getDimSize(0), rhsType.getDimSize(1),
                                lhsType.getDimSize(1));
  if (lhsType.getElementType().isInteger(8) &&
      rhsType.getElementType().isInteger(8) &&
      accType.getElementType().isInteger(32) &&
      (dim == std::make_tuple(8, 8, 32) || dim == std::make_tuple(16, 16, 32) ||
       dim == std::make_tuple(16, 8, 32)))
    return true;

  if (lhsType.getElementType().isF16() && rhsType.getElementType().isF16() &&
      (accType.getElementType().isF16() || accType.getElementType().isF32()) &&
      (dim == std::make_tuple(8, 8, 16) || dim == std::make_tuple(16, 16, 16) ||
       dim == std::make_tuple(16, 8, 16)))
    return true;
  return false;
}

// Return the stride for dimension 0 of |type| if it is a memref and has a
// constant stride.
static llvm::Optional<int64_t>
getMemrefConstantHorizontalStride(ShapedType type) {
  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType)
    return llvm::None;
  int64_t offset = 0;
  SmallVector<int64_t, 2> strides;
  if (failed(getStridesAndOffset(memrefType, strides, offset)))
    return llvm::None;
  if (strides[0] == ShapedType::kDynamicStrideOrOffset)
    return llvm::None;
  return strides[0];
}

// Return true if the transfer op can be converted to an MMA matrix load.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
  if (readOp.mask() || readOp.hasOutOfBoundsDim() ||
      readOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
    return false;
  // TODO: Support transpose once it is added to GPU dialect ops.
  if (!readOp.permutation_map().isMinorIdentity())
    return false;
  return true;
}

// Return true if the transfer op can be converted to an MMA matrix store.
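// The constraints mirror the read case: the transfer must be unmasked and
// in-bounds, write a 2-D vector through a minor identity permutation map, and
// target a memref whose leading dimension has a constant stride.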
static bool
transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
  if (writeOp.mask() || writeOp.hasOutOfBoundsDim() ||
      writeOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(writeOp.getShapedType()))
    return false;
  // TODO: Support transpose once it is added to GPU dialect ops.
  if (!writeOp.permutation_map().isMinorIdentity())
    return false;
  return true;
}

static bool supportsMMaMatrixType(Operation *op) {
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
    return transferReadSupportsMMAMatrixType(transferRead);
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
    return transferWriteSupportsMMAMatrixType(transferWrite);
  if (auto contract = dyn_cast<vector::ContractionOp>(op))
    return contractSupportsMMAMatrixType(contract);
  return false;
}

// Analyze the slice of operations around each contraction op to figure out if
// the whole slice can be converted to MMA operations.
static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
  auto hasVectorDest = [](Operation *op) {
    return op->getNumResults() == 0 ||
           llvm::any_of(op->getResultTypes(),
                        [](Type t) { return t.isa<VectorType>(); });
  };
  SetVector<Operation *> opToConvert;
  op->walk([&](vector::ContractionOp contract) {
    if (opToConvert.contains(contract.getOperation()))
      return;
    SetVector<Operation *> dependentOps =
        getSlice(contract, hasVectorDest, hasVectorDest);
    // If any instruction cannot use the MMA matrix type, drop the whole
    // chain. MMA matrices are stored in an opaque type so they cannot be used
    // by all operations.
    if (llvm::any_of(dependentOps,
                     [](Operation *op) { return !supportsMMaMatrixType(op); }))
      return;
    opToConvert.insert(dependentOps.begin(), dependentOps.end());
  });
  return opToConvert;
}

namespace {
// Transform a contract op into (m, k) x (k, n) x (m, n) form so that it can
// be converted to an MMA matmul.
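// For example, a contraction whose indexing maps are
// {(m, k), (n, k), (m, n)} (a matmul with a transposed rhs) is rewritten to
// the canonical {(m, k), (k, n), (m, n)} form by inserting a vector.transpose
// on the rhs; CombineTransferReadOpTranspose can then fold that transpose
// into the transfer read feeding the MMA load.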
struct PrepareContractToGPUMMA
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();

    // Set up the parallel/reduction structure in the right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.iterator_types().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
    if (!(isParallelIterator(iteratorTypes[0]) &&
          isParallelIterator(iteratorTypes[1]) &&
          isReductionIterator(iteratorTypes[2])))
      return failure();
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      // This is the classical row-major matmul, nothing to do.
      return failure();
    }
    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return failure();
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res,
        rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
        op.iterator_types());
    return success();
  }
};

// Merge a transpose op into the transfer read op. Transposes are not
// supported on MMA types, but the MMA load can transpose the matrix when
// loading it.
struct CombineTransferReadOpTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto transferReadOp = op.vector().getDefiningOp<vector::TransferReadOp>();
    if (!transferReadOp)
      return failure();
    if (transferReadOp.mask() || transferReadOp.hasOutOfBoundsDim())
      return failure();
    SmallVector<int64_t, 2> perm;
    op.getTransp(perm);
    SmallVector<unsigned, 2> permU;
    for (int64_t o : perm)
      permU.push_back(unsigned(o));
    AffineMap permutationMap =
        AffineMap::getPermutationMap(permU, op.getContext());
    AffineMap newMap =
        permutationMap.compose(transferReadOp.permutation_map());
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        op, op.getType(), transferReadOp.source(), transferReadOp.indices(),
        newMap, transferReadOp.padding(), transferReadOp.mask(),
        transferReadOp.in_boundsAttr());
    return success();
  }
};

} // namespace

// MMA types have different layouts based on how they are used in matmul ops.
// Figure out the right layout to use by looking at transfer op uses.
// TODO: Change the GPU dialect to abstract the layout at this level and only
// care about it during lowering to NVVM.
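// A read feeding the lhs of a contraction becomes an "AOp" fragment, one
// feeding the rhs becomes a "BOp" fragment, and any other use (typically the
// accumulator) falls back to "COp".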
static const char *inferFragType(vector::TransferReadOp op) {
  for (Operation *user : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(user);
    if (!contract)
      continue;
    if (contract.lhs() == op.getResult())
      return "AOp";
    if (contract.rhs() == op.getResult())
      return "BOp";
  }
  return "COp";
}

static void convertTransferReadOp(vector::TransferReadOp op,
                                  llvm::DenseMap<Value, Value> &valueMapping) {
  assert(transferReadSupportsMMAMatrixType(op));
  Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  assert(stride);
  const char *fragType = inferFragType(op);
  gpu::MMAMatrixType type =
      gpu::MMAMatrixType::get(op.getVectorType().getShape(),
                              op.getVectorType().getElementType(), fragType);
  OpBuilder b(op);
  Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
      op.getLoc(), type, op.source(), op.indices(), b.getIndexAttr(*stride));
  valueMapping[op.getResult()] = load;
}

static void convertTransferWriteOp(vector::TransferWriteOp op,
                                   llvm::DenseMap<Value, Value> &valueMapping) {
  assert(transferWriteSupportsMMAMatrixType(op));
  Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  assert(stride);
  OpBuilder b(op);
  Value matrix = valueMapping.find(op.vector())->second;
  b.create<gpu::SubgroupMmaStoreMatrixOp>(
      op.getLoc(), matrix, op.source(), op.indices(), b.getIndexAttr(*stride));
  op.erase();
}

static void convertContractOp(vector::ContractionOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  Value opA = valueMapping.find(op.lhs())->second;
  Value opB = valueMapping.find(op.rhs())->second;
  Value opC = valueMapping.find(op.acc())->second;
  Value matmul = b.create<gpu::SubgroupMmaComputeOp>(op.getLoc(), opC.getType(),
                                                     opA, opB, opC);
  valueMapping[op.getResult()] = matmul;
}

namespace mlir {

void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) {
  patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
      patterns.getContext());
}

void convertVectorToMMAOps(FuncOp funcOp) {
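  // Each conversion records the gpu.mma_matrix value it produces in
  // valueMapping so that later operations in the slice can look up the value
  // that replaces their original vector operand.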
  SetVector<Operation *> ops = getOpToConvert(funcOp);
  llvm::DenseMap<Value, Value> valueMapping;
  for (Operation *op : ops) {
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
      convertTransferReadOp(transferRead, valueMapping);
    } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
      convertTransferWriteOp(transferWrite, valueMapping);
    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
      convertContractOp(contractOp, valueMapping);
    }
  }
}

} // namespace mlir

namespace {

struct ConvertVectorToGPUPass
    : public ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
  void runOnFunction() override {
    RewritePatternSet patterns(getFunction().getContext());
    populatePrepareVectorToMMAPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));

    convertVectorToMMAOps(getFunction());
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass() {
  return std::make_unique<ConvertVectorToGPUPass>();
}