//===- NvGpuSupport.cpp - MLIR Vector to GPU lowering support --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides utilities to assist in the lowering of Vector operations
// to NvGPU dialect MMA operations.
//
//===----------------------------------------------------------------------===//

#include "NvGpuSupport.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/NVGPU/NVGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

namespace mlir {
namespace nvgpu {
namespace {

/// There are always 4 threads per [128|256|512] bit row.
constexpr int64_t kThreadsPerRow = 4;

constexpr int64_t kNumRowsPerTile = 8;

bool isAccumulatorOrResult(MatMulOperandRole operandType) {
  return operandType == MatMulOperandRole::C;
}

/// Returns the number of registers which compose a matrix fragment held by a
/// single thread.
int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
  int64_t lineSize = inferTileWidthInBits(type);
  auto shape = type.vectorType.getShape();
  return (shape[0] / kNumRowsPerTile) *
         (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
         lineSize;
}

/// Returns the number of 8 x [128|256|512] bit tiles that compose the given
/// operand shape.
std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
                                    Type elementType, int64_t lineSizeBits) {
  // For each 8x128bit square, a thread is responsible for one 32bit register.
  return {operandShape[0] / kNumRowsPerTile,
          (operandShape[1] * elementType.getIntOrFloatBitWidth()) /
              lineSizeBits};
}

} // namespace

FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op) {
  WarpMatrixInfo info;

  // Determine the vector type.
  if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    info.vectorType = writeOp.getVectorType();
  } else if (isa<vector::TransferReadOp, vector::ContractionOp,
                 arith::ConstantOp>(op)) {
    info.vectorType = op->getResult(0).getType().cast<VectorType>();
  } else {
    return op->emitError()
           << "unhandled operation type in nvgpu.mma.sync conversion path";
  }

  // Determine the operand role. We assume it is an accumulator/result unless
  // it is directly consumed by a `vector.contract` op.
  info.operandRole = MatMulOperandRole::C;
  for (Operation *user : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(user);
    if (!contract)
      continue;
    if (contract.getLhs() == op->getResult(0)) {
      info.operandRole = MatMulOperandRole::A;
      break;
    }
    if (contract.getRhs() == op->getResult(0)) {
      info.operandRole = MatMulOperandRole::B;
      break;
    }
  }
  return info;
}

int64_t inferTileWidthInBits(const WarpMatrixInfo &type) {
  bool isAcc = isAccumulatorOrResult(type.operandRole);
  Type elType = type.vectorType.getElementType();
  if (isAcc && elType.getIntOrFloatBitWidth() == 32) {
    return 256;
  }
  if (elType.getIntOrFloatBitWidth() == 64) {
    return isAcc ? 512 : 256;
  }
  return 128;
}

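/// Returns the per-thread register description for one fragment of the given
/// operand when lowering to `nvgpu.mma.sync`: the register's LLVM type, the
/// number of elements per register, the register width in bits, and the
/// number of registers per fragment. For example (illustrative; the values
/// follow from the cases below), a 16x16 f16 A operand is held as 4 registers
/// of `vector<2xf16>` (32 bits each) per thread.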
FailureOr<FragmentElementInfo>
getMmaSyncRegisterType(const WarpMatrixInfo &type) {
  MLIRContext *ctx = type.vectorType.getContext();
  const bool isAccum = isAccumulatorOrResult(type.operandRole);

  Type elType = type.vectorType.getElementType();
  if (elType.isF16()) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // f64 operand
  Type f64Ty = Float64Type::get(ctx);
  if (elType.isF64()) {
    return isAccum
               ? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f64Ty, 1, 64,
                                     inferNumRegistersPerMatrixFragment(type)};
  }

  // int8 operand
  if (elType.isInteger(8)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }
  // Integer 32bit acc operands
  if (elType.isInteger(32)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // Floating point 32bit operands
  if (elType.isF32()) {
    Type f32Ty = Float32Type::get(ctx);
    return isAccum
               ? FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f32Ty, 1, 32,
                                     inferNumRegistersPerMatrixFragment(type)};
  }
  return failure();
}

static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
                                                 Type elementType,
                                                 ArrayRef<int64_t> operandShape,
                                                 bool isAccumulator,
                                                 int64_t elementsPerRegister,
                                                 AffineExpr logicalValueId) {
  const int64_t elementsPerLine =
      lineSize / elementType.getIntOrFloatBitWidth();
  const std::array<int64_t, 2> num8x128bTiles =
      getTileShape(operandShape, elementType, lineSize);
  AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
  return AffineMap::get(
      2, 0,
      {(registerIdx % num8x128bTiles[0]) * 8,
       (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},
      elementType.getContext());
}

FailureOr<AffineMap>
getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
                                  const WarpMatrixInfo &fragmentType) {
  Type elementType = fragmentType.vectorType.getElementType();
  ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      getMmaSyncRegisterType(fragmentType);
  if (failed(regInfo))
    return failure();

  const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth();
  const int64_t elementsPerRegister =
      regInfo->registerWidthBits / elementBitWidth;
  const int64_t lineSize = inferTileWidthInBits(fragmentType);

  AffineExpr laneId, logicalValueIdDim;
  bindDims(builder.getContext(), laneId, logicalValueIdDim);

  // Determine what register logicalValueId corresponds to. Use that as a
  // linear index into the coordinate mapping `index -> (tile row, tile col)`.
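  // For example (illustrative): for a 16x16 f16 A operand, a thread's
  // registers 0..3 map to tile offsets (0, 0), (8, 0), (0, 8), and (8, 8),
  // i.e. registers walk the 8x128b tiles down the rows first, then across
  // the columns.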
  AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap(
      lineSize, elementType, operandShape,
      isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
      logicalValueIdDim);

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(2, 0, dimExprs, builder.getContext());
  };

  auto tileRow = registerIndexToTileCoord.getResult(0);
  auto tileCol = registerIndexToTileCoord.getResult(1);
  return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
                  tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
                      (logicalValueIdDim % elementsPerRegister)});
}

FailureOr<nvgpu::LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
                                                   bool transpose) {
  LdMatrixParams params;
  Type elType = type.vectorType.getElementType();
  params.fragmentType = type.vectorType;
  if (type.operandRole == MatMulOperandRole::A ||
      type.operandRole == MatMulOperandRole::C) {
    params.targetLayout = NVVM::MMALayout::row;
  } else {
    params.targetLayout = NVVM::MMALayout::col;
  }
  ArrayRef<int64_t> shape = type.vectorType.getShape();
  params.contiguousDimType =
      transpose ? IteratorType::Parallel : IteratorType::Reduction;

  if (params.targetLayout == NVVM::MMALayout::row) {
    params.numTiles = (shape[0] / kNumRowsPerTile) *
                      ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
  } else {
    params.numTiles = (shape[1] / kNumRowsPerTile) *
                      ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
  }

  if (params.numTiles == 0)
    return failure();

  return params;
}

FailureOr<AffineMap>
getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
                               const LdMatrixParams &params) {
  // One thread per 128b row.
  const int64_t kNumThreadsPerTile = kNumRowsPerTile;
  const int bitsPerElement = static_cast<int>(
      params.fragmentType.getElementType().getIntOrFloatBitWidth());
  const int kElementsPer128b = (128 / bitsPerElement);
  ArrayRef<int64_t> operandShape = params.fragmentType.getShape();
  AffineExpr d0 = getAffineDimExpr(0, builder.getContext());

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(1, 0, dimExprs, builder.getContext());
  };

  // This case corresponds to row-major A|C or col-major B operands.
  if (params.contiguousDimType == IteratorType::Reduction) {
    AffineExpr row = d0 % (operandShape[0]);
    AffineExpr col = d0.floorDiv(operandShape[0]) * (kElementsPer128b);
    return makeMap({row, col});
  }

  // This case corresponds to col-major A|C or row-major B operands. The
  // operandShape given is already pre-transposed (e.g. 8x16 = KxN).
  if (params.contiguousDimType == IteratorType::Parallel) {
    const int64_t num8x128bCols = (operandShape[0] * bitsPerElement) / 128;
    // Threads are assigned in groups of 8 first across columns, then to rows.
    // This is the transpose of what `ldmatrix` expects, but when `ldmatrix`
    // gets the `.trans` qualifier, the final effect will be to transpose just
    // the blocks.
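    // For example (illustrative, f16 with two 8x128b columns, i.e. a 16xN
    // pre-transposed shape): lane 9 is in group 1, so it maps to block
    // column 1 and block row 0, giving the coordinate (8, 1).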
    auto groupIdx = d0.floorDiv(kNumThreadsPerTile);
    auto tileCol = (groupIdx % num8x128bCols);
    auto tileRow = groupIdx.floorDiv(num8x128bCols);
    return makeMap({tileCol * kElementsPer128b,
                    tileRow * kNumRowsPerTile + (d0 % kNumRowsPerTile)});
  }
  return failure();
}

LogicalResult
PrepareContractToGPUMMASync::matchAndRewrite(vector::ContractionOp op,
                                             PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  Value lhs = op.getLhs();
  Value rhs = op.getRhs();
  Value res = op.getAcc();

  // Set up the parallel/reduction structure in the right form.
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m;
  AffineExpr n;
  AffineExpr k;
  bindDims(rewriter.getContext(), m, n, k);
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  auto iteratorTypes = op.getIteratorTypes().getValue();
  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
  if (iteratorTypes.size() != 3)
    return failure();
  if (!(isParallelIterator(iteratorTypes[0]) &&
        isParallelIterator(iteratorTypes[1]) &&
        isReductionIterator(iteratorTypes[2])))
    return failure();

  // The canonical form is "TNT" = A row-major, B col-major, C row-major.
  const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
  if (maps == canonicalForm) {
    return failure();
  }
  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
    std::swap(rhs, lhs);
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
    std::swap(rhs, lhs);
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
    std::swap(lhs, rhs);
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
    std::swap(lhs, rhs);
  } else {
    return failure();
  }
  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
      op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
      op.getIteratorTypes());
  return success();
}

} // namespace nvgpu
} // namespace mlir