1 //===- NvvmMMASupport.h - MLIR Vector to GPU lowering support --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file provides utilities to assist in the lowering of Vector operations 10 // to GPU dialect MMA operations. 11 // 12 //===----------------------------------------------------------------------===// 13 #ifndef MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H 14 #define MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H 15 16 #include "mlir/Dialect/Affine/IR/AffineOps.h" 17 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 18 #include "mlir/Dialect/LLVMIR/NVVMDialect.h" 19 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 20 #include "mlir/Dialect/Vector/IR/VectorOps.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/IR/Types.h" 23 24 namespace mlir { 25 namespace nvgpu { 26 27 enum class MatMulOperandRole : int32_t { A = 0, B, C }; 28 29 /// Collects information about a warp-level matrix operand represented by a 30 /// VectorType. 31 struct WarpMatrixInfo { 32 VectorType vectorType; 33 MatMulOperandRole operandRole; 34 }; 35 36 /// Given an op that operates on a VectorType representing a warp-level matrix 37 /// operand, the function returns a struct containing relevant type information. 38 FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op); 39 40 /// Returns the number of bits in a single tile row. It is either 128, 256, or 41 /// 512 bits depending on the data type and` whether the operand is an 42 /// accumulator/result operand 43 int64_t inferTileWidthInBits(const WarpMatrixInfo &type); 44 45 /// Specifies information about the registers which compose a matrix fragment 46 /// according to the PTX documentation. 47 struct FragmentElementInfo { 48 Type registerLLVMType; 49 int64_t elementsPerRegister; 50 int64_t registerWidthBits; 51 int64_t numRegistersPerFragment; 52 }; 53 54 /// Returns a FragmentElementInfo struct describing the register types for the 55 /// given matrix fragment type. 56 FailureOr<FragmentElementInfo> 57 getMmaSyncRegisterType(const WarpMatrixInfo &type); 58 59 /// Returns an AffineMap which maps a two dimensions representing (laneId, 60 /// logicalValueId) and returns two results representing offsets within a 61 /// matrix operand. The offsets point to the values the thread is responsible 62 /// for (AKA the matrix fragment values) during a warp-collective matrix 63 /// operation. For a visual reference of this LaneId -> (row, col) mapping, 64 /// please see NVIDIA's PTX documentation: 65 /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma 66 FailureOr<AffineMap> 67 getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder, 68 const WarpMatrixInfo &fragmentType); 69 70 struct LdMatrixParams { 71 VectorType fragmentType; 72 bool isAccum; 73 int64_t numTiles; 74 IteratorType contiguousDimType; 75 NVVM::MMALayout targetLayout; 76 }; 77 78 FailureOr<LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type, 79 bool transpose); 80 /// Returns an AffineMap which maps a single dimension representing the laneId 81 /// to two results representing offsets within the matrix operand that should 82 /// be the pointer locations a thread should pass to the ldmatrix instruction. 83 FailureOr<AffineMap> 84 getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder, 85 const LdMatrixParams ¶ms); 86 87 // Transform contract into (m, k)x(n, k)x(m, n) form so that it can be converted 88 // to MMA matmul. 89 struct PrepareContractToGPUMMASync 90 : public OpRewritePattern<vector::ContractionOp> { 91 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern; 92 93 LogicalResult matchAndRewrite(vector::ContractionOp op, 94 PatternRewriter &rewriter) const override; 95 }; 96 97 } // namespace nvgpu 98 } // namespace mlir 99 100 #endif // MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H 101