1edd9515bSthomasraoux //===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===// 2edd9515bSthomasraoux // 3edd9515bSthomasraoux // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4edd9515bSthomasraoux // See https://llvm.org/LICENSE.txt for license information. 5edd9515bSthomasraoux // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6edd9515bSthomasraoux // 7edd9515bSthomasraoux //===----------------------------------------------------------------------===// 8edd9515bSthomasraoux // 9edd9515bSthomasraoux // This file implements lowering of vector operations to GPU dialect ops. 10edd9515bSthomasraoux // 11edd9515bSthomasraoux //===----------------------------------------------------------------------===// 12edd9515bSthomasraoux 13edd9515bSthomasraoux #include <type_traits> 14edd9515bSthomasraoux 151ca772edSChristopher Bate #include "NvGpuSupport.h" 16edd9515bSthomasraoux #include "mlir/Conversion/VectorToGPU/VectorToGPU.h" 17edd9515bSthomasraoux 18edd9515bSthomasraoux #include "../PassDetail.h" 19edd9515bSthomasraoux #include "mlir/Analysis/SliceAnalysis.h" 20a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 21d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUDialect.h" 2266f878ceSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 2351b925dfSChristopher Bate #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" 24*8b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 25edd9515bSthomasraoux #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 2699ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h" 2799ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 28edd9515bSthomasraoux #include "mlir/IR/Builders.h" 29edd9515bSthomasraoux #include "mlir/Pass/Pass.h" 30edd9515bSthomasraoux #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 31edd9515bSthomasraoux #include "mlir/Transforms/Passes.h" 321ca772edSChristopher Bate #include 
"llvm/ADT/TypeSwitch.h" 33edd9515bSthomasraoux 34edd9515bSthomasraoux using namespace mlir; 35edd9515bSthomasraoux 361ca772edSChristopher Bate /// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an 371ca772edSChristopher Bate /// AffineMap representing offsets to apply to indices, the function fills 381ca772edSChristopher Bate /// `indices` with the original indices plus the offsets. The offsets are 391ca772edSChristopher Bate /// applied by taking into account the permutation map of the transfer op. If 401ca772edSChristopher Bate /// the `offsetMap` has dimension placeholders, those should be provided in 411ca772edSChristopher Bate /// `dimValues`. 421ca772edSChristopher Bate template <typename TransferOpType> 431ca772edSChristopher Bate static void getXferIndices(OpBuilder &b, TransferOpType xferOp, 441ca772edSChristopher Bate AffineMap offsetMap, ArrayRef<Value> dimValues, 451ca772edSChristopher Bate SmallVector<Value, 4> &indices) { 461ca772edSChristopher Bate indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end()); 471ca772edSChristopher Bate Location loc = xferOp.getLoc(); 481ca772edSChristopher Bate unsigned offsetsIdx = 0; 491ca772edSChristopher Bate for (auto expr : xferOp.getPermutationMap().getResults()) { 501ca772edSChristopher Bate if (auto dim = expr.template dyn_cast<AffineDimExpr>()) { 511ca772edSChristopher Bate Value prevIdx = indices[dim.getPosition()]; 521ca772edSChristopher Bate SmallVector<Value, 3> dims(dimValues.begin(), dimValues.end()); 531ca772edSChristopher Bate dims.push_back(prevIdx); 541ca772edSChristopher Bate AffineExpr d0 = b.getAffineDimExpr(offsetMap.getNumDims()); 551ca772edSChristopher Bate indices[dim.getPosition()] = makeComposedAffineApply( 561ca772edSChristopher Bate b, loc, d0 + offsetMap.getResult(offsetsIdx++), dims); 571ca772edSChristopher Bate continue; 581ca772edSChristopher Bate } 591ca772edSChristopher Bate } 601ca772edSChristopher Bate } 611ca772edSChristopher Bate 
62edd9515bSthomasraoux // Return true if the contract op can be convert to MMA matmul. 631ca772edSChristopher Bate static bool contractSupportsMMAMatrixType(vector::ContractionOp contract, 641ca772edSChristopher Bate bool useNvGpu) { 657c38fd60SJacques Pienaar if (llvm::size(contract.getMasks()) != 0) 66edd9515bSthomasraoux return false; 67edd9515bSthomasraoux 68edd9515bSthomasraoux using MapList = ArrayRef<ArrayRef<AffineExpr>>; 69edd9515bSthomasraoux auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; 70edd9515bSthomasraoux AffineExpr m, n, k; 71edd9515bSthomasraoux bindDims(contract.getContext(), m, n, k); 727c38fd60SJacques Pienaar auto iteratorTypes = contract.getIteratorTypes().getValue(); 73edd9515bSthomasraoux if (!(isParallelIterator(iteratorTypes[0]) && 74edd9515bSthomasraoux isParallelIterator(iteratorTypes[1]) && 75edd9515bSthomasraoux isReductionIterator(iteratorTypes[2]))) 76edd9515bSthomasraoux return false; 77edd9515bSthomasraoux 78edd9515bSthomasraoux // The contract needs to represent a matmul to be able to convert to 79edd9515bSthomasraoux // MMAMatrix matmul. 801ca772edSChristopher Bate if (!useNvGpu && 811ca772edSChristopher Bate contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}})) 821ca772edSChristopher Bate return false; 831ca772edSChristopher Bate if (useNvGpu && contract.getIndexingMaps() != infer({{m, k}, {n, k}, {m, n}})) 84edd9515bSthomasraoux return false; 85edd9515bSthomasraoux 86edd9515bSthomasraoux return true; 87edd9515bSthomasraoux } 88edd9515bSthomasraoux 89edd9515bSthomasraoux // Return the stide for the dimension 0 of |type| if it is a memref and has a 90edd9515bSthomasraoux // constant stride. 
91edd9515bSthomasraoux static llvm::Optional<int64_t> 92edd9515bSthomasraoux getMemrefConstantHorizontalStride(ShapedType type) { 93edd9515bSthomasraoux auto memrefType = type.dyn_cast<MemRefType>(); 94edd9515bSthomasraoux if (!memrefType) 95edd9515bSthomasraoux return false; 96a57ccad5SThomas Raoux // If the memref is 0 or 1D the horizontal stride is 0. 97a57ccad5SThomas Raoux if (memrefType.getRank() < 2) 98a57ccad5SThomas Raoux return 0; 99edd9515bSthomasraoux int64_t offset = 0; 100edd9515bSthomasraoux SmallVector<int64_t, 2> strides; 101d77f4836SThomas Raoux if (failed(getStridesAndOffset(memrefType, strides, offset)) || 102d77f4836SThomas Raoux strides.back() != 1) 103edd9515bSthomasraoux return llvm::None; 104a57ccad5SThomas Raoux int64_t stride = strides[strides.size() - 2]; 105a57ccad5SThomas Raoux if (stride == ShapedType::kDynamicStrideOrOffset) 106edd9515bSthomasraoux return llvm::None; 107a57ccad5SThomas Raoux return stride; 108edd9515bSthomasraoux } 109edd9515bSthomasraoux 110edd9515bSthomasraoux // Return true if the transfer op can be converted to a MMA matrix load. 
/// Returns true if `readOp` can be lowered to an MMA matrix load: it must be
/// an unmasked, in-bounds, 2-D read from a memref with a constant row stride.
/// When targeting the generic GPU dialect (`useNvGpu == false`) only
/// minor-identity and inner-dimension-broadcast permutation maps are allowed;
/// the nvgpu path accepts any permutation map here.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
                                              bool useNvGpu) {
  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
      readOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
    return false;
  AffineMap map = readOp.getPermutationMap();
  OpBuilder b(readOp.getContext());
  // Build the map (d0, ..., dN) -> (0, dN): a broadcast along the rows.
  AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1);
  AffineExpr zero = b.getAffineConstantExpr(0);
  auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim},
                                          readOp.getContext());

  if (!useNvGpu) {
    // TODO: Support transpose once it is added to GPU dialect ops.
    // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1).
    return map.isMinorIdentity() || map == broadcastInnerDim;
  }

  return true;
}

// Return true if the transfer op can be converted to a MMA matrix store.
static bool
transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
  // TODO: support 0-d corner case.
  if (writeOp.getTransferRank() == 0)
    return false;

  // Same constraints as the load case: unmasked, in-bounds, 2-D result, and a
  // constant row stride on the destination memref.
  if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
      writeOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(writeOp.getShapedType()))
    return false;
  // TODO: Support transpose once it is added to GPU dialect ops.
  if (!writeOp.getPermutationMap().isMinorIdentity())
    return false;
  return true;
}

/// Return true if the constant is a splat to a 2D vector so that it can be
/// converted to a MMA constant matrix op.
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
  auto vecType = constantOp.getType().dyn_cast<VectorType>();
  if (!vecType || vecType.getRank() != 2)
    return false;
  return constantOp.getValue().isa<SplatElementsAttr>();
}

/// Return true if this is a broadcast from scalar to a 2D vector.
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
  return broadcastOp.getVectorType().getRank() == 2 &&
         broadcastOp.getSource().getType().isa<FloatType>();
}

/// Return the MMA elementwise enum associated with `op` if it is supported.
/// Return `llvm::None` otherwise.
static llvm::Optional<gpu::MMAElementwiseOp>
convertElementwiseOpToMMA(Operation *op) {
  if (isa<arith::AddFOp>(op))
    return gpu::MMAElementwiseOp::ADDF;
  if (isa<arith::MulFOp>(op))
    return gpu::MMAElementwiseOp::MULF;
  if (isa<arith::MaxFOp>(op))
    return gpu::MMAElementwiseOp::MAXF;
  if (isa<arith::MinFOp>(op))
    return gpu::MMAElementwiseOp::MINF;
  if (isa<arith::DivFOp>(op))
    return gpu::MMAElementwiseOp::DIVF;
  return llvm::None;
}

/// Return true if the op is supported as elementwise op on MMAMatrix type.
static bool elementwiseSupportsMMAMatrixType(Operation *op) {
  return convertElementwiseOpToMMA(op).hasValue();
}

/// Returns true if `op` can participate in a chain of ops lowered to MMA
/// matrix operations for the chosen target (`useNvGpu`).
static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
  // scf.for / scf.yield only forward values, so they are always supported.
  if (isa<scf::ForOp, scf::YieldOp>(op))
    return true;
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
    return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
    return transferWriteSupportsMMAMatrixType(transferWrite);
  if (auto contract = dyn_cast<vector::ContractionOp>(op))
    return contractSupportsMMAMatrixType(contract, useNvGpu);
  if (auto constant = dyn_cast<arith::ConstantOp>(op))
    return constantSupportsMMAMatrixType(constant);
  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
    return broadcastSupportsMMAMatrixType(broadcast);
  return elementwiseSupportsMMAMatrixType(op);
}

/// Return an unsorted slice handling scf.for region differently than
/// `getSlice`. In scf.for we only want to include as part of the slice elements
/// that are part of the use/def chain.
static SetVector<Operation *> getSliceContract(Operation *op,
                                               TransitiveFilter backwardFilter,
                                               TransitiveFilter forwardFilter) {
  SetVector<Operation *> slice;
  slice.insert(op);
  unsigned currentIndex = 0;
  SetVector<Operation *> backwardSlice;
  SetVector<Operation *> forwardSlice;
  // Fixed-point iteration: keep growing the slice until no new op is added.
  while (currentIndex != slice.size()) {
    auto *currentOp = (slice)[currentIndex];
    // Compute and insert the backwardSlice starting from currentOp.
    backwardSlice.clear();
    getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
    slice.insert(backwardSlice.begin(), backwardSlice.end());

    // Compute and insert the forwardSlice starting from currentOp.
    forwardSlice.clear();
    // Special case for ForOp, we don't want to include the whole region but
    // only the value using the region arguments.
    // TODO: We should refine this to only care about the region arguments being
    // converted to matrix type.
    if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
      for (Value forOpResult : forOp.getResults())
        getForwardSlice(forOpResult, &forwardSlice, forwardFilter);
      for (BlockArgument &arg : forOp.getRegionIterArgs())
        getForwardSlice(arg, &forwardSlice, forwardFilter);
    } else {
      getForwardSlice(currentOp, &forwardSlice, forwardFilter);
    }
    slice.insert(forwardSlice.begin(), forwardSlice.end());
    ++currentIndex;
  }
  return slice;
}

// Analyze slice of operations based on convert op to figure out if the whole
// slice can be converted to MMA operations.
static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
                                             bool useNvGpu) {
  auto hasVectorDest = [](Operation *op) {
    return llvm::any_of(op->getResultTypes(),
                        [](Type t) { return t.isa<VectorType>(); });
  };
  auto hasVectorSrc = [](Operation *op) {
    return llvm::any_of(op->getOperandTypes(),
                        [](Type t) { return t.isa<VectorType>(); });
  };
  SetVector<Operation *> opToConvert;
  op->walk([&](vector::ContractionOp contract) {
    if (opToConvert.contains(contract.getOperation()))
      return;
    SetVector<Operation *> dependentOps =
        getSliceContract(contract, hasVectorDest, hasVectorSrc);
    // If any instruction cannot use MMA matrix type drop the whole
    // chain. MMA matrix are stored in an opaque type so they cannot be used
    // by all operations.
    if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
          return !supportsMMaMatrixType(op, useNvGpu);
        }))
      return;
    opToConvert.insert(dependentOps.begin(), dependentOps.end());
  });
  // Sort the operations so that we can convert them in topological order.
  return topologicalSort(opToConvert);
}

namespace {
// Transform contract into (m, k)x(k, n)x(m, n) form so that it can be converted
// to MMA matmul.
struct PrepareContractToGPUMMA
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();

    // Set up the parallel/reduction structure in right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
    if (!(isParallelIterator(iteratorTypes[0]) &&
          isParallelIterator(iteratorTypes[1]) &&
          isReductionIterator(iteratorTypes[2])))
      return failure();
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    // Each case below normalizes the operands to (m, k) x (k, n) -> (m, n) by
    // transposing lhs/rhs and/or swapping them (when the result is (n, m)).
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      // This is the classical row-major matmul, nothing to do.
      return failure();
    }
    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return failure();
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res,
        rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
        op.getIteratorTypes());
    return success();
  }
};

// Merge transpose op into the
// transfer read op. Transpose are not supported on
// MMA types but MMA load can transpose the matrix when loading.
struct CombineTransferReadOpTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto transferReadOp =
        op.getVector().getDefiningOp<vector::TransferReadOp>();
    if (!transferReadOp)
      return failure();

    // TODO: support 0-d corner case.
    if (transferReadOp.getTransferRank() == 0)
      return failure();

    if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
      return failure();
    // Fold the transpose into the read by composing its permutation with the
    // read's permutation map, then replace the transpose with a new read.
    SmallVector<int64_t, 2> perm;
    op.getTransp(perm);
    SmallVector<unsigned, 2> permU;
    for (int64_t o : perm)
      permU.push_back(unsigned(o));
    AffineMap permutationMap =
        AffineMap::getPermutationMap(permU, op.getContext());
    AffineMap newMap =
        permutationMap.compose(transferReadOp.getPermutationMap());
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        op, op.getType(), transferReadOp.getSource(),
        transferReadOp.getIndices(), AffineMapAttr::get(newMap),
        transferReadOp.getPadding(), transferReadOp.getMask(),
        transferReadOp.getInBoundsAttr());
    return success();
  }
};

} // namespace

// MMA types have different layout based on how they are used in matmul ops.
// Figure the right layout to use by looking at op uses.
// TODO: Change the GPU dialect to abstract the layout at the this level and
// only care about it during lowering to NVVM.
template <typename OpTy>
static const char *inferFragType(OpTy op) {
  // A value feeding the LHS/RHS of a contraction is an A/B fragment; any
  // other value (including the accumulator) is a C fragment.
  for (Operation *users : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(users);
    if (!contract)
      continue;
    if (contract.getLhs() == op.getResult())
      return "AOp";
    if (contract.getRhs() == op.getResult())
      return "BOp";
  }
  return "COp";
}

/// Lower a supported vector.transfer_read to gpu.subgroup_mma_load_matrix and
/// record the produced matrix value in `valueMapping`.
static void convertTransferReadOp(vector::TransferReadOp op,
                                  llvm::DenseMap<Value, Value> &valueMapping) {
  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
  assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));
  Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  AffineMap map = op.getPermutationMap();
  // Handle broadcast by setting the stride to 0.
  if (map.getResult(0).isa<AffineConstantExpr>()) {
    assert(map.getResult(0).cast<AffineConstantExpr>().getValue() == 0);
    stride = 0;
  }
  assert(stride);
  const char *fragType = inferFragType(op);
  gpu::MMAMatrixType type =
      gpu::MMAMatrixType::get(op.getVectorType().getShape(),
                              op.getVectorType().getElementType(), fragType);
  OpBuilder b(op);
  Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
      op.getLoc(), type, op.getSource(), op.getIndices(),
      b.getIndexAttr(*stride));
  valueMapping[op.getResult()] = load;
}

/// Lower a supported vector.transfer_write to gpu.subgroup_mma_store_matrix,
/// taking the stored matrix from `valueMapping`, and erase the original op.
static void convertTransferWriteOp(vector::TransferWriteOp op,
                                   llvm::DenseMap<Value, Value> &valueMapping) {
  assert(transferWriteSupportsMMAMatrixType(op));
  Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  assert(stride);
  OpBuilder b(op);
  Value matrix = valueMapping.find(op.getVector())->second;
  b.create<gpu::SubgroupMmaStoreMatrixOp>(op.getLoc(), matrix, op.getSource(),
                                          op.getIndices(),
                                          b.getIndexAttr(*stride));
  op.erase();
}

/// Returns the vector type which represents a matrix fragment.
static VectorType
getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
  // One row per register in the fragment, one column per element packed in a
  // register.
  SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
                             regInfo.elementsPerRegister};
  Type elType = regInfo.registerLLVMType;
  if (auto vecType = elType.dyn_cast<VectorType>())
    elType = vecType.getElementType();
  return VectorType::get(shape, elType);
}

/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static LogicalResult
convertConstantOpMmaSync(arith::ConstantOp op,
                         llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();

  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo))
    return failure();

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
  // Only splat constants can be re-materialized with the fragment shape.
  auto dense = op.getValue().dyn_cast<SplatElementsAttr>();
  if (!dense)
    return failure();
  Value result = b.create<arith::ConstantOp>(
      op.getLoc(), vectorType,
      DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
  valueMapping[op.getResult()] = result;
  return success();
}

/// Lower a vector.transfer_read that satisfies the ldmatrix constraints to an
/// nvgpu.ldmatrix op distributed over the lanes of the warp.
/// NOTE(review): function name is missing an 'e' ("creat"); kept unchanged so
/// existing callers keep working.
static LogicalResult
creatLdMatrixCompatibleLoads(vector::TransferReadOp op, OpBuilder &builder,
                             llvm::DenseMap<Value, Value> &valueMapping) {
  Location loc = op->getLoc();

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();

  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo))
    return failure();

  FailureOr<nvgpu::LdMatrixParams> params = nvgpu::getLdMatrixParams(
      *warpMatrixInfo,
      /*transpose=*/!op.getPermutationMap().isMinorIdentity());
  if (failed(params)) {
    return op->emitError()
           << "failed to convert vector.transfer_read to ldmatrix; this op "
              "likely "
              "should not be converted to a nvgpu.ldmatrix call.";
  }

  // Adjust the load offset.
  auto laneId = builder.create<gpu::LaneIdOp>(loc);
  FailureOr<AffineMap> offsets =
      nvgpu::getLaneIdToLdMatrixMatrixCoord(loc, builder, *params);
  if (failed(offsets))
    return failure();

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);

  SmallVector<Value, 4> indices;
  getXferIndices<vector::TransferReadOp>(builder, op, *offsets, {laneId},
                                         indices);
  nvgpu::LdMatrixOp newOp = builder.create<nvgpu::LdMatrixOp>(
      loc, vectorType, op.getSource(), indices,
      !op.getPermutationMap().isMinorIdentity(), params->numTiles);
  valueMapping[op] = newOp->getResult(0);
  return success();
}

/// Lower a vector.transfer_read that cannot use ldmatrix to per-lane loads
/// distributed according to the mma.sync operand layout.
static LogicalResult
createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
                       llvm::DenseMap<Value, Value> &valueMapping) {
  Location loc = op.getLoc();
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    op->emitError() << "Failed to deduce register fragment type during "
                       "conversion to distributed non-ldmatrix compatible load";
    return failure();
  }

  // Operand B is consumed column-major by mma.sync; A and C are row-major.
  NVVM::MMALayout targetLayout =
      warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B
          ? NVVM::MMALayout::col
          : NVVM::MMALayout::row;

  Value laneId = builder.create<gpu::LaneIdOp>(loc);
  SmallVector<Value, 4> elements;

  // This is the individual element type.
  Type loadedElType = regInfo->registerLLVMType;
  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);

  // Start from a zero-filled fragment and insert loaded elements into it.
  Value fill = builder.create<arith::ConstantOp>(
      op.getLoc(), vectorType.getElementType(),
      builder.getZeroAttr(vectorType.getElementType()));
  Value result = builder.create<vector::SplatOp>(op.getLoc(), fill, vectorType);

  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();

  // Vectorized loads.
  if (!isTransposeLoad && targetLayout == NVVM::MMALayout::row) {
    if (!loadedElType.isa<VectorType>()) {
      loadedElType = VectorType::get({1}, loadedElType);
    }

    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
          op.getLoc(), builder, *warpMatrixInfo);
      if (failed(coords))
        return failure();
      Value logicalValueId = builder.create<arith::ConstantOp>(
          loc, builder.getIndexType(),
          builder.getIndexAttr(i * regInfo->elementsPerRegister));
      SmallVector<Value, 4> newIndices;
      getXferIndices<vector::TransferReadOp>(
          builder, op, *coords, {laneId, logicalValueId}, newIndices);

      Value el = builder.create<vector::LoadOp>(loc, loadedElType,
                                                op.getSource(), newIndices);
      result = builder.create<vector::InsertOp>(loc, el, result,
                                                builder.getI64ArrayAttr(i));
    }
  } else if (isTransposeLoad && targetLayout == NVVM::MMALayout::col) {
    if (auto vecType = loadedElType.dyn_cast<VectorType>()) {
      loadedElType = vecType.getElementType();
    }
    // Load each element individually.
5741ca772edSChristopher Bate for (int i = 0; i < vectorType.getShape()[0]; i++) { 5751ca772edSChristopher Bate for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1]; 5761ca772edSChristopher Bate innerIdx++) { 5771ca772edSChristopher Bate 5781ca772edSChristopher Bate Value logicalValueId = builder.create<arith::ConstantOp>( 5791ca772edSChristopher Bate loc, builder.getIndexType(), 5801ca772edSChristopher Bate builder.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx)); 5811ca772edSChristopher Bate FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord( 5821ca772edSChristopher Bate op.getLoc(), builder, *warpMatrixInfo); 5831ca772edSChristopher Bate if (failed(coords)) 5841ca772edSChristopher Bate return failure(); 5851ca772edSChristopher Bate 5861ca772edSChristopher Bate SmallVector<Value, 4> newIndices; 5871ca772edSChristopher Bate getXferIndices<vector::TransferReadOp>( 5881ca772edSChristopher Bate builder, op, *coords, {laneId, logicalValueId}, newIndices); 5891ca772edSChristopher Bate Value el = builder.create<memref::LoadOp>(op.getLoc(), loadedElType, 5901ca772edSChristopher Bate op.getSource(), newIndices); 5911ca772edSChristopher Bate result = builder.create<vector::InsertOp>( 5921ca772edSChristopher Bate op.getLoc(), el, result, builder.getI64ArrayAttr({i, innerIdx})); 5931ca772edSChristopher Bate } 5941ca772edSChristopher Bate } 5951ca772edSChristopher Bate } else { 5961ca772edSChristopher Bate return failure(); 5971ca772edSChristopher Bate } 5981ca772edSChristopher Bate 5991ca772edSChristopher Bate valueMapping[op.getResult()] = result; 6001ca772edSChristopher Bate return success(); 6011ca772edSChristopher Bate } 6021ca772edSChristopher Bate 6031ca772edSChristopher Bate /// Converts a `vector.transfer_read` operation directly to either a 6041ca772edSChristopher Bate /// `vector.load` or a `nvgpu.ldmatrix` operation. 
/// This function should only be
/// used when converting to `nvgpu.mma.sync` operations.
static LogicalResult
convertTransferReadToLoads(vector::TransferReadOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();

  // ldmatrix requires the source in shared memory (memory space 3) and a
  // 128-bit tile width.
  bool isLdMatrixCompatible =
      op.getSource().getType().cast<MemRefType>().getMemorySpaceAsInt() == 3 &&
      nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;

  VectorType vecTy = op.getVectorType();
  int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();

  // When we are transposing the B operand, ldmatrix will only work if we have
  // at least 8 rows to read and the width to read for the transpose is 128
  // bits.
  if (!op.getPermutationMap().isMinorIdentity() &&
      (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
       vecTy.getDimSize(0) * bitWidth < 128))
    isLdMatrixCompatible = false;

  if (!isLdMatrixCompatible)
    return createNonLdMatrixLoads(op, b, valueMapping);

  return creatLdMatrixCompatibleLoads(op, b, valueMapping);
}

/// Distributes a `vector.transfer_write` feeding an `nvgpu.mma.sync` across a
/// warp: each lane extracts its register rows from the mapped fragment value
/// and stores them with `vector.store` at lane-derived coordinates. Erases
/// the original op on success.
static LogicalResult
convertTransferWriteToStores(vector::TransferWriteOp op,
                             llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  Location loc = op->getLoc();
  // The written vector must already have a distributed equivalent.
  Value matrix = valueMapping.find(op.getVector())->second;

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo))
    return failure();

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
  Value laneId = b.create<gpu::LaneIdOp>(loc);

  // One store per register row of the fragment.
  for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
    Value logicalValueId = b.create<arith::ConstantOp>(
        loc, b.getIndexType(),
        b.getIndexAttr(i * regInfo->elementsPerRegister));
    FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
        op.getLoc(), b, *warpMatrixInfo);
    if (failed(coords))
      return failure();

    Value el = b.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
    SmallVector<Value, 4> newIndices;
    getXferIndices<vector::TransferWriteOp>(
        b, op, *coords, {laneId, logicalValueId}, newIndices);
    b.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
  }
  op->erase();
  return success();
}

/// Converts a `vector.contract` to a `gpu.subgroup_mma_compute` op, using
/// the previously-converted MMA-matrix values of its lhs/rhs/acc operands.
static void convertContractOp(vector::ContractionOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  Value opA = valueMapping.find(op.getLhs())->second;
  Value opB = valueMapping.find(op.getRhs())->second;
  Value opC = valueMapping.find(op.getAcc())->second;
  Value matmul = b.create<gpu::SubgroupMmaComputeOp>(op.getLoc(), opC.getType(),
                                                     opA, opB, opC);
  valueMapping[op.getResult()] = matmul;
}

/// Converts a `vector.contract` to an `nvgpu.mma.sync` op. The (m, n, k)
/// shape is read from the lhs/rhs vector types: m = lhs dim 0,
/// n = rhs dim 0, k = lhs dim 1.
static LogicalResult
convertContractOpToMmaSync(vector::ContractionOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  Value opA = valueMapping.find(op.getLhs())->second;
  Value opB = valueMapping.find(op.getRhs())->second;
  Value opC = valueMapping.find(op.getAcc())->second;
  int64_t m = op.getLhs().getType().cast<VectorType>().getShape()[0];
  int64_t n = op.getRhs().getType().cast<VectorType>().getShape()[0];
  int64_t k = op.getLhs().getType().cast<VectorType>().getShape()[1];
  Value matmul = b.create<nvgpu::MmaSyncOp>(
      op.getLoc(), opC.getType(), opA, opB, opC, b.getI64ArrayAttr({m, n, k}));
  valueMapping[op.getResult()] = matmul;
  return success();
}

/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static void convertConstantOp(arith::ConstantOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  assert(constantSupportsMMAMatrixType(op));
  OpBuilder b(op);
  // Extract the scalar splat value and rebuild it as a scalar constant.
  Attribute splat =
      op.getValue().cast<SplatElementsAttr>().getSplatValue<Attribute>();
  auto scalarConstant =
      b.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
  const char *fragType = inferFragType(op);
  auto vecType = op.getType().cast<VectorType>();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
                                                           scalarConstant);
  valueMapping[op.getResult()] = matrix;
}

/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static void convertBroadcastOp(vector::BroadcastOp op,
                               llvm::DenseMap<Value, Value> &valueMapping) {
  assert(broadcastSupportsMMAMatrixType(op));
  OpBuilder b(op);
  const char *fragType = inferFragType(op);
  auto vecType = op.getVectorType();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
                                                           op.getSource());
  valueMapping[op.getResult()] = matrix;
}

// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
// updated and needs to be updated separately for the loop to be correct.
static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
                                               ValueRange newIterOperands) {
  // Create a new loop before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(loop);
  auto operands = llvm::to_vector<4>(loop.getIterOperands());
  operands.append(newIterOperands.begin(), newIterOperands.end());
  scf::ForOp newLoop =
      b.create<scf::ForOp>(loop.getLoc(), loop.getLowerBound(),
                           loop.getUpperBound(), loop.getStep(), operands);
  // Steal the old loop's body wholesale rather than cloning it.
  newLoop.getBody()->erase();
  newLoop.getLoopBody().getBlocks().splice(
      newLoop.getLoopBody().getBlocks().begin(),
      loop.getLoopBody().getBlocks());
  for (Value operand : newIterOperands)
    newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());

  // Forward the original results to the new loop's leading results.
  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
                                                  loop.getNumResults())))
    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
  loop.erase();
  return newLoop;
}

/// Rebuilds an scf.for op so that every iter operand with a converted
/// (MMA-matrix / fragment) equivalent also carries that equivalent as an
/// extra iter operand, and records result/block-argument mappings for the
/// added operands in `valueMapping`.
static void convertForOp(scf::ForOp op,
                         llvm::DenseMap<Value, Value> &valueMapping) {
  SmallVector<Value> newOperands;
  // Pairs of (original iter-operand index, index of the appended operand).
  SmallVector<std::pair<size_t, size_t>> argMapping;
  for (const auto &operand : llvm::enumerate(op.getIterOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;
    argMapping.push_back(std::make_pair(
        operand.index(), op.getNumIterOperands() + newOperands.size()));
    newOperands.push_back(it->second);
  }
  OpBuilder b(op);
  scf::ForOp newForOp = replaceForOpWithNewSignature(b, op, newOperands);
  Block &loopBody = *newForOp.getBody();
  for (auto mapping : argMapping) {
    valueMapping[newForOp.getResult(mapping.first)] =
        newForOp.getResult(mapping.second);
    valueMapping[loopBody.getArgument(mapping.first +
                                      newForOp.getNumInductionVars())] =
        loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
  }
}

/// Rewrites the scf.yield of a converted loop: yields of values that have a
/// converted equivalent are neutralized (replaced by the loop's own iter
/// operand) and the converted value is appended, matching the extra iter
/// operands added by convertForOp.
static void convertYieldOp(scf::YieldOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  auto loop = cast<scf::ForOp>(op->getParentOp());
  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
  for (const auto &operand : llvm::enumerate(op.getOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;
    // Replace the yield of old value with the for op argument to make it easier
    // to remove the dead code.
    yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
    yieldOperands.push_back(it->second);
  }
  b.create<scf::YieldOp>(op.getLoc(), yieldOperands);
  op.erase();
}

/// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
8037fbb0678Sthomasraoux static void convertElementwiseOp(Operation *op, gpu::MMAElementwiseOp opType, 8047fbb0678Sthomasraoux llvm::DenseMap<Value, Value> &valueMapping) { 8057fbb0678Sthomasraoux OpBuilder b(op); 8067fbb0678Sthomasraoux SmallVector<Value> matrixOperands; 8077fbb0678Sthomasraoux for (Value operand : op->getOperands()) 8087fbb0678Sthomasraoux matrixOperands.push_back(valueMapping.find(operand)->second); 8097fbb0678Sthomasraoux Value newOp = b.create<gpu::SubgroupMmaElementwiseOp>( 8107fbb0678Sthomasraoux op->getLoc(), matrixOperands[0].getType(), matrixOperands, opType); 8117fbb0678Sthomasraoux valueMapping[op->getResult(0)] = newOp; 8127fbb0678Sthomasraoux } 8137fbb0678Sthomasraoux 8141ca772edSChristopher Bate void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, 8151ca772edSChristopher Bate bool useNvGpu) { 8161ca772edSChristopher Bate if (!useNvGpu) { 817edd9515bSthomasraoux patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>( 818edd9515bSthomasraoux patterns.getContext()); 8191ca772edSChristopher Bate return; 8201ca772edSChristopher Bate } 8211ca772edSChristopher Bate patterns 8221ca772edSChristopher Bate .add<nvgpu::PrepareContractToGPUMMASync, CombineTransferReadOpTranspose>( 8231ca772edSChristopher Bate patterns.getContext()); 824edd9515bSthomasraoux } 825edd9515bSthomasraoux 82647f175b0SRiver Riddle void mlir::convertVectorToMMAOps(Operation *rootOp) { 8271ca772edSChristopher Bate SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false); 828edd9515bSthomasraoux llvm::DenseMap<Value, Value> valueMapping; 829edd9515bSthomasraoux for (Operation *op : ops) { 830edd9515bSthomasraoux if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) { 831edd9515bSthomasraoux convertTransferReadOp(transferRead, valueMapping); 832edd9515bSthomasraoux } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) { 833edd9515bSthomasraoux convertTransferWriteOp(transferWrite, valueMapping); 
834edd9515bSthomasraoux } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) { 835edd9515bSthomasraoux convertContractOp(contractOp, valueMapping); 836a54f4eaeSMogball } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) { 8376413226dSthomasraoux convertConstantOp(constantOp, valueMapping); 83843928419Sthomasraoux } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) { 83943928419Sthomasraoux convertBroadcastOp(broadcastOp, valueMapping); 8401a865592Sthomasraoux } else if (auto forOp = dyn_cast<scf::ForOp>(op)) { 8411a865592Sthomasraoux convertForOp(forOp, valueMapping); 8421a865592Sthomasraoux } else if (auto yiledOp = dyn_cast<scf::YieldOp>(op)) { 8431a865592Sthomasraoux convertYieldOp(yiledOp, valueMapping); 8447fbb0678Sthomasraoux } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) { 8457fbb0678Sthomasraoux convertElementwiseOp(op, *elementwiseType, valueMapping); 846edd9515bSthomasraoux } 847edd9515bSthomasraoux } 848edd9515bSthomasraoux } 849edd9515bSthomasraoux 8501ca772edSChristopher Bate LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(Operation *rootOp) { 8511ca772edSChristopher Bate SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true); 8521ca772edSChristopher Bate llvm::DenseMap<Value, Value> valueMapping; 8531ca772edSChristopher Bate for (Operation *op : ops) { 8541ca772edSChristopher Bate if (llvm::TypeSwitch<Operation *, LogicalResult>(op) 8551ca772edSChristopher Bate .Case([&](vector::TransferReadOp transferReadOp) { 8561ca772edSChristopher Bate return convertTransferReadToLoads(transferReadOp, valueMapping); 8571ca772edSChristopher Bate }) 8581ca772edSChristopher Bate .Case([&](vector::TransferWriteOp transferWriteOp) { 8591ca772edSChristopher Bate return convertTransferWriteToStores(transferWriteOp, 8601ca772edSChristopher Bate valueMapping); 8611ca772edSChristopher Bate }) 8621ca772edSChristopher Bate .Case([&](vector::ContractionOp contractionOp) { 8631ca772edSChristopher 
Bate return convertContractOpToMmaSync(contractionOp, valueMapping); 8641ca772edSChristopher Bate }) 8651ca772edSChristopher Bate .Case([&](scf::ForOp forOp) { 8661ca772edSChristopher Bate convertForOp(forOp, valueMapping); 8671ca772edSChristopher Bate return success(); 8681ca772edSChristopher Bate }) 8691ca772edSChristopher Bate .Case([&](scf::YieldOp yieldOp) { 8701ca772edSChristopher Bate convertYieldOp(yieldOp, valueMapping); 8711ca772edSChristopher Bate return success(); 8721ca772edSChristopher Bate }) 8731ca772edSChristopher Bate .Case([&](arith::ConstantOp constOp) { 8741ca772edSChristopher Bate return convertConstantOpMmaSync(constOp, valueMapping); 8751ca772edSChristopher Bate }) 8761ca772edSChristopher Bate .Default([&](Operation *op) { 8771ca772edSChristopher Bate op->emitError() << "unhandled vector to mma type: " << *op; 8781ca772edSChristopher Bate return failure(); 8791ca772edSChristopher Bate }) 8801ca772edSChristopher Bate .failed()) { 8811ca772edSChristopher Bate op->emitError() << "Failed to convert op " << *op; 8821ca772edSChristopher Bate return failure(); 8831ca772edSChristopher Bate } 8841ca772edSChristopher Bate } 8851ca772edSChristopher Bate return success(); 8861ca772edSChristopher Bate } 8871ca772edSChristopher Bate 888edd9515bSthomasraoux namespace { 889edd9515bSthomasraoux 890edd9515bSthomasraoux struct ConvertVectorToGPUPass 891edd9515bSthomasraoux : public ConvertVectorToGPUBase<ConvertVectorToGPUPass> { 8921ca772edSChristopher Bate 8931ca772edSChristopher Bate explicit ConvertVectorToGPUPass(bool useNvGpu_) { 8941ca772edSChristopher Bate useNvGpu.setValue(useNvGpu_); 8951ca772edSChristopher Bate } 8961ca772edSChristopher Bate 89741574554SRiver Riddle void runOnOperation() override { 89847f175b0SRiver Riddle RewritePatternSet patterns(&getContext()); 8991ca772edSChristopher Bate populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue()); 9001ca772edSChristopher Bate if (failed( 9011ca772edSChristopher Bate 
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) 9021ca772edSChristopher Bate return signalPassFailure(); 903edd9515bSthomasraoux 9041ca772edSChristopher Bate if (useNvGpu.getValue()) { 9051ca772edSChristopher Bate if (failed(convertVectorToNVVMCompatibleMMASync(getOperation()))) 9061ca772edSChristopher Bate return signalPassFailure(); 9071ca772edSChristopher Bate } 9081ca772edSChristopher Bate 9091ca772edSChristopher Bate (void)convertVectorToMMAOps(getOperation()); 910edd9515bSthomasraoux } 911edd9515bSthomasraoux }; 912edd9515bSthomasraoux 913edd9515bSthomasraoux } // namespace 914edd9515bSthomasraoux 9151ca772edSChristopher Bate std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) { 9161ca772edSChristopher Bate return std::make_unique<ConvertVectorToGPUPass>(useNvGpu); 917edd9515bSthomasraoux } 918