//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to GPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "NvGpuSupport.h"
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"

#include "../PassDetail.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/NVGPUDialect.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;

/// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
/// AffineMap representing offsets to apply to indices, the function fills
/// `indices` with the original indices plus the offsets. The offsets are
/// applied by taking into account the permutation map of the transfer op. If
/// the `offsetMap` has dimension placeholders, those should be provided in
/// `dimValues`.
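/// Illustrative example (hypothetical values, not taken from a test): for a
/// transfer op with indices (%i, %j), a minor-identity permutation map, an
/// `offsetMap` (d0) -> (d0 floordiv 4, d0 mod 4), and `dimValues` = {%laneId},
/// the resulting indices are
/// (%i + %laneId floordiv 4, %j + %laneId mod 4).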
template <typename TransferOpType>
static void getXferIndices(OpBuilder &b, TransferOpType xferOp,
                           AffineMap offsetMap, ArrayRef<Value> dimValues,
                           SmallVector<Value, 4> &indices) {
  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
  Location loc = xferOp.getLoc();
  unsigned offsetsIdx = 0;
  for (auto expr : xferOp.getPermutationMap().getResults()) {
    if (auto dim = expr.template dyn_cast<AffineDimExpr>()) {
      Value prevIdx = indices[dim.getPosition()];
      SmallVector<Value, 3> dims(dimValues.begin(), dimValues.end());
      dims.push_back(prevIdx);
      AffineExpr d0 = b.getAffineDimExpr(offsetMap.getNumDims());
      indices[dim.getPosition()] = makeComposedAffineApply(
          b, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
      continue;
    }
  }
}

// Return true if the contract op can be converted to an MMA matmul.
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
                                          bool useNvGpu) {
  if (llvm::size(contract.getMasks()) != 0)
    return false;

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(contract.getContext(), m, n, k);
  auto iteratorTypes = contract.getIteratorTypes().getValue();
  if (!(isParallelIterator(iteratorTypes[0]) &&
        isParallelIterator(iteratorTypes[1]) &&
        isReductionIterator(iteratorTypes[2])))
    return false;

  // The contract needs to represent a matmul to be able to convert to
  // MMAMatrix matmul.
  if (!useNvGpu &&
      contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
    return false;
  if (useNvGpu && contract.getIndexingMaps() != infer({{m, k}, {n, k}, {m, n}}))
    return false;

  return true;
}

// Return the stride of the leading (second-to-last) dimension of |type| if it
// is a memref with a constant stride.
static llvm::Optional<int64_t>
getMemrefConstantHorizontalStride(ShapedType type) {
  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType)
    return llvm::None;
  // If the memref is 0 or 1D the horizontal stride is 0.
  if (memrefType.getRank() < 2)
    return 0;
  int64_t offset = 0;
  SmallVector<int64_t, 2> strides;
  if (failed(getStridesAndOffset(memrefType, strides, offset)) ||
      strides.back() != 1)
    return llvm::None;
  int64_t stride = strides[strides.size() - 2];
  if (stride == ShapedType::kDynamicStrideOrOffset)
    return llvm::None;
  return stride;
}

// Return true if the transfer op can be converted to an MMA matrix load.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
                                              bool useNvGpu) {
  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
      readOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
    return false;
  AffineMap map = readOp.getPermutationMap();
  OpBuilder b(readOp.getContext());
  AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1);
  AffineExpr zero = b.getAffineConstantExpr(0);
  auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim},
                                          readOp.getContext());

  if (!useNvGpu) {
    // TODO: Support transpose once it is added to GPU dialect ops.
    // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1).
    return map.isMinorIdentity() || map == broadcastInnerDim;
  }

  return true;
}

// Return true if the transfer op can be converted to an MMA matrix store.
static bool
transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
  // TODO: support 0-d corner case.
  if (writeOp.getTransferRank() == 0)
    return false;

  if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
      writeOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(writeOp.getShapedType()))
    return false;
  // TODO: Support transpose once it is added to GPU dialect ops.
  if (!writeOp.getPermutationMap().isMinorIdentity())
    return false;
  return true;
}

/// Return true if the constant is a splat to a 2D vector so that it can be
/// converted to an MMA constant matrix op.
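/// For instance (illustrative only), `arith.constant dense<0.0> :
/// vector<16x16xf16>` qualifies, while a 1-D or non-splat constant does not.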
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
  auto vecType = constantOp.getType().dyn_cast<VectorType>();
  if (!vecType || vecType.getRank() != 2)
    return false;
  return constantOp.getValue().isa<SplatElementsAttr>();
}

/// Return true if this is a broadcast from scalar to a 2D vector.
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
  return broadcastOp.getVectorType().getRank() == 2 &&
         broadcastOp.getSource().getType().isa<FloatType>();
}

/// Return the MMA elementwise enum associated with `op` if it is supported.
/// Return `llvm::None` otherwise.
static llvm::Optional<gpu::MMAElementwiseOp>
convertElementwiseOpToMMA(Operation *op) {
  if (isa<arith::AddFOp>(op))
    return gpu::MMAElementwiseOp::ADDF;
  if (isa<arith::MulFOp>(op))
    return gpu::MMAElementwiseOp::MULF;
  if (isa<arith::MaxFOp>(op))
    return gpu::MMAElementwiseOp::MAXF;
  if (isa<arith::MinFOp>(op))
    return gpu::MMAElementwiseOp::MINF;
  if (isa<arith::DivFOp>(op))
    return gpu::MMAElementwiseOp::DIVF;
  return llvm::None;
}

/// Return true if the op is supported as elementwise op on MMAMatrix type.
static bool elementwiseSupportsMMAMatrixType(Operation *op) {
  return convertElementwiseOpToMMA(op).hasValue();
}

/// Return true if `op` can be converted to the matching MMA matrix form:
/// GPU subgroup MMA ops when `useNvGpu` is false, nvgpu.mma.sync otherwise.
static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
  if (isa<scf::ForOp, scf::YieldOp>(op))
    return true;
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
    return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
    return transferWriteSupportsMMAMatrixType(transferWrite);
  if (auto contract = dyn_cast<vector::ContractionOp>(op))
    return contractSupportsMMAMatrixType(contract, useNvGpu);
  if (auto constant = dyn_cast<arith::ConstantOp>(op))
    return constantSupportsMMAMatrixType(constant);
  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
    return broadcastSupportsMMAMatrixType(broadcast);
  return elementwiseSupportsMMAMatrixType(op);
}

/// Return an unsorted slice, handling scf.for regions differently than
/// `getSlice`: in scf.for we only want to include as part of the slice the
/// elements that are part of the use/def chain.
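/// For example (an informal illustration of the special case below): if a
/// contract accumulator is carried through an scf.for iter_arg, the ops
/// feeding that iter_arg and the users of the corresponding loop result are
/// included, but unrelated ops inside the loop body are not.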
static SetVector<Operation *> getSliceContract(Operation *op,
                                               TransitiveFilter backwardFilter,
                                               TransitiveFilter forwardFilter) {
  SetVector<Operation *> slice;
  slice.insert(op);
  unsigned currentIndex = 0;
  SetVector<Operation *> backwardSlice;
  SetVector<Operation *> forwardSlice;
  while (currentIndex != slice.size()) {
    auto *currentOp = slice[currentIndex];
    // Compute and insert the backwardSlice starting from currentOp.
    backwardSlice.clear();
    getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
    slice.insert(backwardSlice.begin(), backwardSlice.end());

    // Compute and insert the forwardSlice starting from currentOp.
    forwardSlice.clear();
    // Special case for ForOp: we don't want to include the whole region but
    // only the values using the region arguments.
    // TODO: We should refine this to only care about the region arguments
    // being converted to matrix type.
    if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
      for (Value forOpResult : forOp.getResults())
        getForwardSlice(forOpResult, &forwardSlice, forwardFilter);
      for (BlockArgument &arg : forOp.getRegionIterArgs())
        getForwardSlice(arg, &forwardSlice, forwardFilter);
    } else {
      getForwardSlice(currentOp, &forwardSlice, forwardFilter);
    }
    slice.insert(forwardSlice.begin(), forwardSlice.end());
    ++currentIndex;
  }
  return slice;
}

// Analyze the slice of operations based on the convert op to figure out if
// the whole slice can be converted to MMA operations.
static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
                                             bool useNvGpu) {
  auto hasVectorDest = [](Operation *op) {
    return llvm::any_of(op->getResultTypes(),
                        [](Type t) { return t.isa<VectorType>(); });
  };
  auto hasVectorSrc = [](Operation *op) {
    return llvm::any_of(op->getOperandTypes(),
                        [](Type t) { return t.isa<VectorType>(); });
  };
  SetVector<Operation *> opToConvert;
  op->walk([&](vector::ContractionOp contract) {
    if (opToConvert.contains(contract.getOperation()))
      return;
    SetVector<Operation *> dependentOps =
        getSliceContract(contract, hasVectorDest, hasVectorSrc);
    // If any instruction cannot use the MMA matrix type, drop the whole
    // chain. MMA matrices are stored in an opaque type, so they cannot be
    // used by all operations.
    if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
          return !supportsMMaMatrixType(op, useNvGpu);
        }))
      return;
    opToConvert.insert(dependentOps.begin(), dependentOps.end());
  });
  // Sort the operations so that we can convert them in topological order.
  return topologicalSort(opToConvert);
}

namespace {
// Transform contract into (m, k)x(k, n)x(m, n) form so that it can be
// converted to MMA matmul.
struct PrepareContractToGPUMMA
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();

    // Set up the parallel/reduction structure in the right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
    if (!(isParallelIterator(iteratorTypes[0]) &&
          isParallelIterator(iteratorTypes[1]) &&
          isReductionIterator(iteratorTypes[2])))
      return failure();
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      // This is the classical row-major matmul; nothing to do.
      return failure();
    }
    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return failure();
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res,
        rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
        op.getIteratorTypes());
    return success();
  }
};

// Merge transpose op into the transfer read op. Transposes are not supported
// on MMA types, but the MMA load can transpose the matrix when loading.
struct CombineTransferReadOpTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto transferReadOp =
        op.getVector().getDefiningOp<vector::TransferReadOp>();
    if (!transferReadOp)
      return failure();

    // TODO: support 0-d corner case.
    if (transferReadOp.getTransferRank() == 0)
      return failure();

    if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
      return failure();
    SmallVector<int64_t, 2> perm;
    op.getTransp(perm);
    SmallVector<unsigned, 2> permU;
    for (int64_t o : perm)
      permU.push_back(unsigned(o));
    AffineMap permutationMap =
        AffineMap::getPermutationMap(permU, op.getContext());
    AffineMap newMap =
        permutationMap.compose(transferReadOp.getPermutationMap());
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        op, op.getType(), transferReadOp.getSource(),
        transferReadOp.getIndices(), AffineMapAttr::get(newMap),
        transferReadOp.getPadding(), transferReadOp.getMask(),
        transferReadOp.getInBoundsAttr());
    return success();
  }
};

} // namespace

// MMA types have different layouts based on how they are used in matmul ops.
// Figure out the right layout to use by looking at op uses.
// TODO: Change the GPU dialect to abstract the layout at this level and
// only care about it during lowering to NVVM.
template <typename OpTy>
static const char *inferFragType(OpTy op) {
  for (Operation *user : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(user);
    if (!contract)
      continue;
    if (contract.getLhs() == op.getResult())
      return "AOp";
    if (contract.getRhs() == op.getResult())
      return "BOp";
  }
  return "COp";
}

static void convertTransferReadOp(vector::TransferReadOp op,
                                  llvm::DenseMap<Value, Value> &valueMapping) {
  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
  assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));
  Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  AffineMap map = op.getPermutationMap();
  // Handle broadcast by setting the stride to 0.
  if (map.getResult(0).isa<AffineConstantExpr>()) {
    assert(map.getResult(0).cast<AffineConstantExpr>().getValue() == 0);
    stride = 0;
  }
  assert(stride);
  const char *fragType = inferFragType(op);
  gpu::MMAMatrixType type =
      gpu::MMAMatrixType::get(op.getVectorType().getShape(),
                              op.getVectorType().getElementType(), fragType);
  OpBuilder b(op);
  Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
      op.getLoc(), type, op.getSource(), op.getIndices(),
      b.getIndexAttr(*stride));
  valueMapping[op.getResult()] = load;
}

static void convertTransferWriteOp(vector::TransferWriteOp op,
                                   llvm::DenseMap<Value, Value> &valueMapping) {
  assert(transferWriteSupportsMMAMatrixType(op));
  Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  assert(stride);
  OpBuilder b(op);
  Value matrix = valueMapping.find(op.getVector())->second;
  b.create<gpu::SubgroupMmaStoreMatrixOp>(op.getLoc(), matrix, op.getSource(),
                                          op.getIndices(),
                                          b.getIndexAttr(*stride));
  op.erase();
}

/// Returns the vector type which represents a matrix fragment.
static VectorType
getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
  SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
                             regInfo.elementsPerRegister};
  Type elType = regInfo.registerLLVMType;
  if (auto vecType = elType.dyn_cast<VectorType>())
    elType = vecType.getElementType();
  return VectorType::get(shape, elType);
}

/// Convert a 2D splat ConstantOp to an arith.constant with the per-thread
/// fragment vector type expected by nvgpu.mma.sync.
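/// For example (assuming an accumulator fragment of 2 registers with 2
/// elements each), a `dense<0.0> : vector<16x16xf16>` splat would be
/// rewritten as `arith.constant dense<0.0> : vector<2x2xf16>`.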
static LogicalResult
convertConstantOpMmaSync(arith::ConstantOp op,
                         llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();

  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo))
    return failure();

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
  auto dense = op.getValue().dyn_cast<SplatElementsAttr>();
  if (!dense)
    return failure();
  Value result = b.create<arith::ConstantOp>(
      op.getLoc(), vectorType,
      DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
  valueMapping[op.getResult()] = result;
  return success();
}

static LogicalResult
createLdMatrixCompatibleLoads(vector::TransferReadOp op, OpBuilder &builder,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  Location loc = op->getLoc();

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();

  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo))
    return failure();

  FailureOr<nvgpu::LdMatrixParams> params = nvgpu::getLdMatrixParams(
      *warpMatrixInfo,
      /*transpose=*/!op.getPermutationMap().isMinorIdentity());
  if (failed(params)) {
    return op->emitError()
           << "failed to convert vector.transfer_read to ldmatrix; this op "
              "likely should not be converted to a nvgpu.ldmatrix call.";
  }

  // Adjust the load offset.
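  // The lane id determines which piece of the tile each thread addresses; the
  // mapping from lane id to matrix coordinates is obtained as an affine map
  // and folded into the transfer indices below.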
  auto laneId = builder.create<gpu::LaneIdOp>(loc);
  FailureOr<AffineMap> offsets =
      nvgpu::getLaneIdToLdMatrixMatrixCoord(loc, builder, *params);
  if (failed(offsets))
    return failure();

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);

  SmallVector<Value, 4> indices;
  getXferIndices<vector::TransferReadOp>(builder, op, *offsets, {laneId},
                                         indices);
  nvgpu::LdMatrixOp newOp = builder.create<nvgpu::LdMatrixOp>(
      loc, vectorType, op.getSource(), indices,
      !op.getPermutationMap().isMinorIdentity(), params->numTiles);
  valueMapping[op] = newOp->getResult(0);
  return success();
}

static LogicalResult
createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
                       llvm::DenseMap<Value, Value> &valueMapping) {
  Location loc = op.getLoc();
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    op->emitError() << "failed to deduce register fragment type during "
                       "conversion to distributed non-ldmatrix compatible "
                       "load";
    return failure();
  }

  NVVM::MMALayout targetLayout =
      warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B
          ? NVVM::MMALayout::col
          : NVVM::MMALayout::row;

  Value laneId = builder.create<gpu::LaneIdOp>(loc);

  // This is the individual element type.
  Type loadedElType = regInfo->registerLLVMType;
  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);

  Value fill = builder.create<arith::ConstantOp>(
      op.getLoc(), vectorType.getElementType(),
      builder.getZeroAttr(vectorType.getElementType()));
  Value result = builder.create<vector::SplatOp>(op.getLoc(), fill, vectorType);

  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();

  // Vectorized loads.
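  // In the non-transposed, row-major case each thread owns
  // `elementsPerRegister` contiguous elements per register, so a single
  // vector.load per register is enough.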
  if (!isTransposeLoad && targetLayout == NVVM::MMALayout::row) {
    if (!loadedElType.isa<VectorType>()) {
      loadedElType = VectorType::get({1}, loadedElType);
    }

    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
          op.getLoc(), builder, *warpMatrixInfo);
      if (failed(coords))
        return failure();
      Value logicalValueId = builder.create<arith::ConstantOp>(
          loc, builder.getIndexType(),
          builder.getIndexAttr(i * regInfo->elementsPerRegister));
      SmallVector<Value, 4> newIndices;
      getXferIndices<vector::TransferReadOp>(
          builder, op, *coords, {laneId, logicalValueId}, newIndices);

      Value el = builder.create<vector::LoadOp>(loc, loadedElType,
                                                op.getSource(), newIndices);
      result = builder.create<vector::InsertOp>(loc, el, result,
                                                builder.getI64ArrayAttr(i));
    }
  } else if (isTransposeLoad && targetLayout == NVVM::MMALayout::col) {
    if (auto vecType = loadedElType.dyn_cast<VectorType>()) {
      loadedElType = vecType.getElementType();
    }
    // Load each element individually.
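    // For a transposed (column-major) operand, the elements a thread needs
    // are not contiguous in memory, so a scalar memref.load is issued per
    // element instead of a wide vector load.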
    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
           innerIdx++) {

        Value logicalValueId = builder.create<arith::ConstantOp>(
            loc, builder.getIndexType(),
            builder.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
        FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
            op.getLoc(), builder, *warpMatrixInfo);
        if (failed(coords))
          return failure();

        SmallVector<Value, 4> newIndices;
        getXferIndices<vector::TransferReadOp>(
            builder, op, *coords, {laneId, logicalValueId}, newIndices);
        Value el = builder.create<memref::LoadOp>(op.getLoc(), loadedElType,
                                                  op.getSource(), newIndices);
        result = builder.create<vector::InsertOp>(
            op.getLoc(), el, result, builder.getI64ArrayAttr({i, innerIdx}));
      }
    }
  } else {
    return failure();
  }

  valueMapping[op.getResult()] = result;
  return success();
}

/// Converts a `vector.transfer_read` operation directly to either a
/// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
/// used when converting to `nvgpu.mma.sync` operations.
static LogicalResult
convertTransferReadToLoads(vector::TransferReadOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();

  // ldmatrix only reads from shared memory (NVVM address space 3) and
  // requires a 128-bit wide tile.
  bool isLdMatrixCompatible =
      op.getSource().getType().cast<MemRefType>().getMemorySpaceAsInt() == 3 &&
      nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;

  VectorType vecTy = op.getVectorType();
  int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();

  // When we are transposing the B operand, ldmatrix will only work if we have
  // at least 8 rows to read and the width to read for the transpose is 128
  // bits.
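  // Illustrative numbers (assuming f16 elements): a transposed read of
  // vector<4x8xf16> provides only 4 * 16 = 64 bits along the transposed
  // dimension, so it fails the check below and falls back to the
  // non-ldmatrix path.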
  if (!op.getPermutationMap().isMinorIdentity() &&
      (vecTy.getDimSize(1) < 8 || vecTy.getDimSize(0) * bitWidth < 128))
    isLdMatrixCompatible = false;

  if (!isLdMatrixCompatible)
    return createNonLdMatrixLoads(op, b, valueMapping);

  return createLdMatrixCompatibleLoads(op, b, valueMapping);
}

static LogicalResult
convertTransferWriteToStores(vector::TransferWriteOp op,
                             llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  Location loc = op->getLoc();
  Value matrix = valueMapping.find(op.getVector())->second;

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo))
    return failure();

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
  Value laneId = b.create<gpu::LaneIdOp>(loc);

  for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
    Value logicalValueId = b.create<arith::ConstantOp>(
        loc, b.getIndexType(),
        b.getIndexAttr(i * regInfo->elementsPerRegister));
    FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
        op.getLoc(), b, *warpMatrixInfo);
    if (failed(coords))
      return failure();

    Value el = b.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
    SmallVector<Value, 4> newIndices;
    getXferIndices<vector::TransferWriteOp>(
        b, op, *coords, {laneId, logicalValueId}, newIndices);
    b.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
  }
  op->erase();
  return success();
}

static void convertContractOp(vector::ContractionOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  Value opA = valueMapping.find(op.getLhs())->second;
  Value opB = valueMapping.find(op.getRhs())->second;
  Value opC = valueMapping.find(op.getAcc())->second;
  Value matmul = b.create<gpu::SubgroupMmaComputeOp>(op.getLoc(), opC.getType(),
                                                     opA, opB, opC);
  valueMapping[op.getResult()] = matmul;
}

static LogicalResult
convertContractOpToMmaSync(vector::ContractionOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  Value opA = valueMapping.find(op.getLhs())->second;
  Value opB = valueMapping.find(op.getRhs())->second;
  Value opC = valueMapping.find(op.getAcc())->second;
  int64_t m = op.getLhs().getType().cast<VectorType>().getShape()[0];
  int64_t n = op.getRhs().getType().cast<VectorType>().getShape()[0];
  int64_t k = op.getLhs().getType().cast<VectorType>().getShape()[1];
  Value matmul = b.create<nvgpu::MmaSyncOp>(
      op.getLoc(), opC.getType(), opA, opB, opC, b.getI64ArrayAttr({m, n, k}));
  valueMapping[op.getResult()] = matmul;
  return success();
}

/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static void convertConstantOp(arith::ConstantOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  assert(constantSupportsMMAMatrixType(op));
  OpBuilder b(op);
  Attribute splat =
      op.getValue().cast<SplatElementsAttr>().getSplatValue<Attribute>();
  auto scalarConstant =
      b.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
  const char *fragType = inferFragType(op);
  auto vecType = op.getType().cast<VectorType>();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
                                                           scalarConstant);
  valueMapping[op.getResult()] = matrix;
}

/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static void convertBroadcastOp(vector::BroadcastOp op,
                               llvm::DenseMap<Value, Value> &valueMapping) {
  assert(broadcastSupportsMMAMatrixType(op));
  OpBuilder b(op);
  const char *fragType = inferFragType(op);
  auto vecType = op.getVectorType();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
                                                           op.getSource());
  valueMapping[op.getResult()] = matrix;
}

// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
// updated and needs to be updated separately for the loop to be correct.
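// In this file the extra operands carry the MMA-mapped values: convertForOp
// appends them as iter_args and convertYieldOp then fixes up the terminator
// (a usage note inferred from the callers below).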
static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
                                               ValueRange newIterOperands) {
  // Create a new loop before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(loop);
  auto operands = llvm::to_vector<4>(loop.getIterOperands());
  operands.append(newIterOperands.begin(), newIterOperands.end());
  scf::ForOp newLoop =
      b.create<scf::ForOp>(loop.getLoc(), loop.getLowerBound(),
                           loop.getUpperBound(), loop.getStep(), operands);
  newLoop.getBody()->erase();
  newLoop.getLoopBody().getBlocks().splice(
      newLoop.getLoopBody().getBlocks().begin(),
      loop.getLoopBody().getBlocks());
  for (Value operand : newIterOperands)
    newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());

  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
                                                  loop.getNumResults())))
    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
  loop.erase();
  return newLoop;
}

static void convertForOp(scf::ForOp op,
                         llvm::DenseMap<Value, Value> &valueMapping) {
  SmallVector<Value> newOperands;
  SmallVector<std::pair<size_t, size_t>> argMapping;
  for (const auto &operand : llvm::enumerate(op.getIterOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;
    argMapping.push_back(std::make_pair(
        operand.index(), op.getNumIterOperands() + newOperands.size()));
    newOperands.push_back(it->second);
  }
  OpBuilder b(op);
  scf::ForOp newForOp = replaceForOpWithNewSignature(b, op, newOperands);
  Block &loopBody = *newForOp.getBody();
  for (auto mapping : argMapping) {
    valueMapping[newForOp.getResult(mapping.first)] =
        newForOp.getResult(mapping.second);
    valueMapping[loopBody.getArgument(mapping.first +
                                      newForOp.getNumInductionVars())] =
        loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
  }
}

static void convertYieldOp(scf::YieldOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  auto loop = cast<scf::ForOp>(op->getParentOp());
  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
  for (const auto &operand : llvm::enumerate(op.getOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;
    // Replace the yield of the old value with the for op argument to make it
    // easier to remove the dead code.
    yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
    yieldOperands.push_back(it->second);
  }
  b.create<scf::YieldOp>(op.getLoc(), yieldOperands);
  op.erase();
}

/// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static void convertElementwiseOp(Operation *op, gpu::MMAElementwiseOp opType,
                                 llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  SmallVector<Value> matrixOperands;
  for (Value operand : op->getOperands())
    matrixOperands.push_back(valueMapping.find(operand)->second);
  Value newOp = b.create<gpu::SubgroupMmaElementwiseOp>(
      op->getLoc(), matrixOperands[0].getType(), matrixOperands, opType);
  valueMapping[op->getResult(0)] = newOp;
}

void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
                                              bool useNvGpu) {
  if (!useNvGpu) {
    patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
        patterns.getContext());
    return;
  }
  patterns
      .add<nvgpu::PrepareContractToGPUMMASync, CombineTransferReadOpTranspose>(
          patterns.getContext());
}

void mlir::convertVectorToMMAOps(Operation *rootOp) {
  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
  llvm::DenseMap<Value, Value> valueMapping;
  for (Operation *op : ops) {
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
      convertTransferReadOp(transferRead, valueMapping);
    } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
      convertTransferWriteOp(transferWrite, valueMapping);
    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
      convertContractOp(contractOp, valueMapping);
    } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
      convertConstantOp(constantOp, valueMapping);
    } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
      convertBroadcastOp(broadcastOp, valueMapping);
    } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      convertForOp(forOp, valueMapping);
    } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
      convertYieldOp(yieldOp, valueMapping);
    } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
      convertElementwiseOp(op, *elementwiseType, valueMapping);
    }
  }
}

LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(Operation *rootOp) {
  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
  llvm::DenseMap<Value, Value> valueMapping;
  for (Operation *op : ops) {
    if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
            .Case([&](vector::TransferReadOp transferReadOp) {
              return convertTransferReadToLoads(transferReadOp, valueMapping);
            })
            .Case([&](vector::TransferWriteOp transferWriteOp) {
              return convertTransferWriteToStores(transferWriteOp,
                                                  valueMapping);
            })
            .Case([&](vector::ContractionOp contractionOp) {
              return convertContractOpToMmaSync(contractionOp, valueMapping);
            })
            .Case([&](scf::ForOp forOp) {
              convertForOp(forOp, valueMapping);
              return success();
            })
            .Case([&](scf::YieldOp yieldOp) {
              convertYieldOp(yieldOp, valueMapping);
              return success();
            })
            .Case([&](arith::ConstantOp constOp) {
              return convertConstantOpMmaSync(constOp, valueMapping);
            })
            .Default([&](Operation *op) {
              op->emitError() << "unhandled vector to mma type: " << *op;
              return failure();
            })
            .failed()) {
      op->emitError() << "failed to convert op " << *op;
      return failure();
    }
  }
  return success();
}

namespace {

struct ConvertVectorToGPUPass
    : public ConvertVectorToGPUBase<ConvertVectorToGPUPass> {

  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
    useNvGpu.setValue(useNvGpu_);
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
    if (failed(
            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();

    if (useNvGpu.getValue()) {
      if (failed(convertVectorToNVVMCompatibleMMASync(getOperation())))
        return signalPassFailure();
    }

    convertVectorToMMAOps(getOperation());
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
}