//===- NvGpuSupport.h - MLIR Vector to GPU lowering support ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides utilities to assist in the lowering of Vector operations
// to GPU dialect MMA operations.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H
#define MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"

namespace mlir {
namespace nvgpu {

enum class MatMulOperandRole : int32_t { A = 0, B, C };

/// Collects information about a warp-level matrix operand represented by a
/// VectorType.
struct WarpMatrixInfo {
  VectorType vectorType;
  MatMulOperandRole operandRole;
};

/// Given an op that operates on a VectorType representing a warp-level matrix
/// operand, the function returns a struct containing relevant type information.
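///
/// A minimal usage sketch (assuming `op` is a vector transfer or contraction
/// op that has already been matched as part of an MMA computation):
///
///   FailureOr<WarpMatrixInfo> info = getWarpMatrixInfo(op);
///   if (failed(info))
///     return failure();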
FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op);

/// Returns the number of bits in a single tile row. It is either 128, 256, or
/// 512 bits depending on the data type and whether the operand is an
/// accumulator/result operand.
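///
/// For example, continuing the sketch above:
///
///   int64_t tileBits = inferTileWidthInBits(*info);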
int64_t inferTileWidthInBits(const WarpMatrixInfo &type);

/// Specifies information about the registers which compose a matrix fragment
/// according to the PTX documentation.
struct FragmentElementInfo {
  Type registerLLVMType;
  int64_t elementsPerRegister;
  int64_t registerWidthBits;
  int64_t numRegistersPerFragment;
};

/// Returns a FragmentElementInfo struct describing the register types for the
/// given matrix fragment type.
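///
/// A usage sketch, continuing from the `info` value obtained above:
///
///   FailureOr<FragmentElementInfo> regInfo = getMmaSyncRegisterType(*info);
///   if (failed(regInfo))
///     return failure();
///   Type registerType = regInfo->registerLLVMType;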
FailureOr<FragmentElementInfo>
getMmaSyncRegisterType(const WarpMatrixInfo &type);

/// Returns an AffineMap which takes two dimensions, (laneId, logicalValueId),
/// and returns two results representing offsets within a matrix operand. The
/// offsets point to the values the thread is responsible for (AKA the matrix
/// fragment values) during a warp-collective matrix operation. For a visual
/// reference of this laneId -> (row, col) mapping, please see NVIDIA's PTX
/// documentation:
/// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma
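///
/// A sketch of materializing the mapped coordinates (assuming `laneId` and
/// `valueId` are index-typed Values already in scope):
///
///   FailureOr<AffineMap> map =
///       getLaneIdAndValueIdToOperandCoord(loc, builder, *info);
///   if (failed(map))
///     return failure();
///   SmallVector<Value, 2> coords;
///   for (AffineExpr expr : map->getResults())
///     coords.push_back(builder.create<AffineApplyOp>(
///         loc, AffineMap::get(map->getNumDims(), 0, expr),
///         ValueRange{laneId, valueId}));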
FailureOr<AffineMap>
getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
                                  const WarpMatrixInfo &fragmentType);

struct LdMatrixParams {
  VectorType fragmentType;
  bool isAccum;
  int64_t numTiles;
  IteratorType contiguousDimType;
  NVVM::MMALayout targetLayout;
};

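/// Returns the parameters to use when loading the given matrix fragment with
/// the ldmatrix instruction, taking `transpose` into account, or failure if
/// ldmatrix cannot be used for this operand.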
FailureOr<LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
                                            bool transpose);

/// Returns an AffineMap which maps a single dimension representing the laneId
/// to two results representing offsets within the matrix operand; these
/// offsets are the pointer locations the thread should pass to the ldmatrix
/// instruction.
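///
/// A sketch combining this with getLdMatrixParams (assuming `transpose` has
/// been determined from the surrounding transfer op):
///
///   FailureOr<LdMatrixParams> params = getLdMatrixParams(*info, transpose);
///   if (failed(params))
///     return failure();
///   FailureOr<AffineMap> ldMap =
///       getLaneIdToLdMatrixMatrixCoord(loc, builder, *params);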
FailureOr<AffineMap>
getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
                               const LdMatrixParams &params);

/// Transforms a vector.contract op into (m, k) x (n, k) x (m, n) form so that
/// it can be converted to an MMA matmul.
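///
/// A typical way to apply the pattern (a sketch; `module` and its
/// MLIRContext `ctx` are assumed to be in scope):
///
///   RewritePatternSet patterns(ctx);
///   patterns.add<nvgpu::PrepareContractToGPUMMASync>(ctx);
///   (void)applyPatternsAndFoldGreedily(module, std::move(patterns));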
struct PrepareContractToGPUMMASync
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;
};

} // namespace nvgpu
} // namespace mlir

#endif // MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H