//===- NvGpuSupport.cpp - MLIR Vector to GPU lowering support --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides utilities to assist in the lowering of Vector operations
// to NvGPU dialect MMA operations.
//
//===----------------------------------------------------------------------===//

#include "NvGpuSupport.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

namespace mlir {
namespace nvgpu {
namespace {

/// There are always 4 threads per [128|256|512] bit row.
constexpr int64_t kThreadsPerRow = 4;

constexpr int64_t kNumRowsPerTile = 8;

bool isAccumulatorOrResult(MatMulOperandRole operandType) {
  return operandType == MatMulOperandRole::C;
}

/// Returns the number of registers which compose a matrix fragment held by a
/// single thread.
int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
  int64_t lineSize = inferTileWidthInBits(type);
  auto shape = type.vectorType.getShape();
  return (shape[0] / kNumRowsPerTile) *
         (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
         lineSize;
}

/// Returns the number of 8 x [128|256|512] bit tiles that compose the given
/// operand shape.
std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
                                    Type elementType, int64_t lineSizeBits) {
  // For each 8x128-bit square, a thread is responsible for one 32-bit
  // register.
  return {operandShape[0] / kNumRowsPerTile,
          (operandShape[1] * elementType.getIntOrFloatBitWidth()) /
              lineSizeBits};
}
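
// Worked example (illustrative): for a 16x16 vector of f16 used as the A
// operand the tile width is 128 bits, so getTileShape returns {2, 2} 8x128b
// tiles and inferNumRegistersPerMatrixFragment returns
// (16 / 8) * ((16 * 16) / 128) = 4 registers per thread.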

} // namespace

FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op) {
  WarpMatrixInfo info;

  // Determine the vector type.
  if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    info.vectorType = writeOp.getVectorType();
  } else if (isa<vector::TransferReadOp, vector::ContractionOp,
                 arith::ConstantOp>(op)) {
    info.vectorType = op->getResult(0).getType().cast<VectorType>();
  } else {
    return op->emitError()
           << "unhandled operation type in nvgpu.mma.sync conversion path";
  }

  // Determine the operand role. We assume the operand is an accumulator/result
  // (C) unless it is used as the LHS or RHS of a `vector.contract` op.
  info.operandRole = MatMulOperandRole::C;
  for (Operation *user : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(user);
    if (!contract)
      continue;
    if (contract.getLhs() == op->getResult(0)) {
      info.operandRole = MatMulOperandRole::A;
      break;
    }
    if (contract.getRhs() == op->getResult(0)) {
      info.operandRole = MatMulOperandRole::B;
      break;
    }
  }
  return info;
}

int64_t inferTileWidthInBits(const WarpMatrixInfo &type) {
  bool isAcc = isAccumulatorOrResult(type.operandRole);
  Type elType = type.vectorType.getElementType();
  if (isAcc && elType.getIntOrFloatBitWidth() == 32) {
    return 256;
  }
  if (elType.getIntOrFloatBitWidth() == 64) {
    return isAcc ? 512 : 256;
  }
  return 128;
}

FailureOr<FragmentElementInfo>
getMmaSyncRegisterType(const WarpMatrixInfo &type) {
  MLIRContext *ctx = type.vectorType.getContext();
  const bool isAccum = isAccumulatorOrResult(type.operandRole);

  Type elType = type.vectorType.getElementType();
  if (elType.isF16()) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // f64 operand
  Type f64Ty = Float64Type::get(ctx);
  if (elType.isF64()) {
    return isAccum
               ? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f64Ty, 1, 64,
                                     inferNumRegistersPerMatrixFragment(type)};
  }

  // int8 operand
  if (elType.isInteger(8)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // int4 operand
  if (elType.isInteger(4)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 4), 8), 8, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // 32-bit integer accumulator operands.
  if (elType.isInteger(32)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // 32-bit float operands.
  if (elType.isF32()) {
    Type f32Ty = Float32Type::get(ctx);
    return isAccum
               ? FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f32Ty, 1, 32,
                                     inferNumRegistersPerMatrixFragment(type)};
  }
  return failure();
}
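
// Example of getMmaSyncRegisterType (illustrative): an f16 operand is held as
// vector<2xf16> values packed into 32-bit registers, so a 16x16 f16 `A`
// fragment occupies 4 such registers per thread. This lines up with the four
// 32-bit registers PTX `mma.sync` (shape m16n8k16) expects for its A operand.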

static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
                                                 Type elementType,
                                                 ArrayRef<int64_t> operandShape,
                                                 bool isAccumulator,
                                                 int64_t elementsPerRegister,
                                                 AffineExpr logicalValueId) {
  const int64_t elementsPerLine =
      lineSize / elementType.getIntOrFloatBitWidth();
  const std::array<int64_t, 2> num8x128bTiles =
      getTileShape(operandShape, elementType, lineSize);
  AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
  return AffineMap::get(
      2, 0,
      {(registerIdx % num8x128bTiles[0]) * 8,
       (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},
      elementType.getContext());
}

FailureOr<AffineMap>
getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
                                  const WarpMatrixInfo &fragmentType) {
  Type elementType = fragmentType.vectorType.getElementType();
  ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      getMmaSyncRegisterType(fragmentType);
  if (failed(regInfo))
    return failure();

  const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth();
  const int64_t elementsPerRegister =
      regInfo->registerWidthBits / elementBitWidth;
  const int64_t lineSize = inferTileWidthInBits(fragmentType);

  AffineExpr laneId, logicalValueIdDim;
  bindDims(builder.getContext(), laneId, logicalValueIdDim);

  // Determine what register logicalValueId corresponds to. Use that as a
  // linear index into the coordinate mapping `index -> (tile row, tile col)`.
  AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap(
      lineSize, elementType, operandShape,
      isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
      logicalValueIdDim);

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(2, 0, dimExprs, builder.getContext());
  };

  auto tileRow = registerIndexToTileCoord.getResult(0);
  auto tileCol = registerIndexToTileCoord.getResult(1);
  return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
                  tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
                      (logicalValueIdDim % elementsPerRegister)});
}
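
// Example of the map produced by getLaneIdAndValueIdToOperandCoord
// (illustrative): for a 16x16 f16 `A` operand (2 elements per register,
// 128-bit lines) the result is
//   (laneId, valueId) -> (tileRow + laneId floordiv 4,
//                         tileCol + (laneId mod 4) * 2 + valueId mod 2),
// so lane 0 owns elements (0,0), (0,1), (8,0), (8,1), (0,8), (0,9), (8,8),
// and (8,9).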

FailureOr<nvgpu::LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
                                                   bool transpose) {
  LdMatrixParams params;
  Type elType = type.vectorType.getElementType();
  params.fragmentType = type.vectorType;
  if (type.operandRole == MatMulOperandRole::A ||
      type.operandRole == MatMulOperandRole::C) {
    params.targetLayout = NVVM::MMALayout::row;
  } else {
    params.targetLayout = NVVM::MMALayout::col;
  }
  ArrayRef<int64_t> shape = type.vectorType.getShape();
  params.contiguousDimType =
      transpose ? IteratorType::Parallel : IteratorType::Reduction;

  if (params.contiguousDimType == IteratorType::Reduction) {
    params.numTiles = (shape[0] / kNumRowsPerTile) *
                      ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
  } else {
    params.numTiles = (shape[1] / kNumRowsPerTile) *
                      ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
  }

  if (params.numTiles == 0)
    return failure();

  return params;
}

FailureOr<AffineMap>
getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
                               const LdMatrixParams &params) {
  // One thread per 128b row.
  const int64_t kNumThreadsPerTile = kNumRowsPerTile;
  const int bitsPerElement = static_cast<int>(
      params.fragmentType.getElementType().getIntOrFloatBitWidth());
  const int kElementsPer128b = (128 / bitsPerElement);
  ArrayRef<int64_t> operandShape = params.fragmentType.getShape();
  AffineExpr d0 = getAffineDimExpr(0, builder.getContext());

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(1, 0, dimExprs, builder.getContext());
  };

  // This case corresponds to row-major A|C or col-major B operands.
  if (params.contiguousDimType == IteratorType::Reduction) {
    AffineExpr row = d0 % (operandShape[0]);
    AffineExpr col = d0.floorDiv(operandShape[0]) * (kElementsPer128b);
    return makeMap({row, col});
  }

  // This case corresponds to col-major A|C or row-major B operands. The
  // operandShape given is already pre-transposed (e.g. 8x16 = KxN).
  if (params.contiguousDimType == IteratorType::Parallel) {
    const int64_t num8x128bCols = (operandShape[0] * bitsPerElement) / 128;
    // Threads are assigned in groups of 8 first across columns, then to
    // rows. This is the transpose of what `ldmatrix` expects, but when
    // `ldmatrix` gets the `.trans` qualifier the final effect will be to
    // transpose just the blocks.
    auto groupIdx = d0.floorDiv(kNumThreadsPerTile);
    auto tileCol = (groupIdx % num8x128bCols);
    auto tileRow = groupIdx.floorDiv(num8x128bCols);
    return makeMap({tileCol * kElementsPer128b,
                    tileRow * kNumRowsPerTile + (d0 % kNumRowsPerTile)});
  }
  return failure();
}

LogicalResult
PrepareContractToGPUMMASync::matchAndRewrite(vector::ContractionOp op,
                                             PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  Value lhs = op.getLhs();
  Value rhs = op.getRhs();
  Value res = op.getAcc();

  // Set up the parallel/reduction structure in the right form.
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m;
  AffineExpr n;
  AffineExpr k;
  bindDims(rewriter.getContext(), m, n, k);
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  auto iteratorTypes = op.getIteratorTypes().getValue();
  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
  if (iteratorTypes.size() != 3)
    return failure();
  if (!(isParallelIterator(iteratorTypes[0]) &&
        isParallelIterator(iteratorTypes[1]) &&
        isReductionIterator(iteratorTypes[2])))
    return failure();

  // The canonical form is "TNT" = A row-major, B col-major, C row-major.
  const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
  if (maps == canonicalForm) {
    return failure();
  }
  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
    std::swap(rhs, lhs);
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
    std::swap(rhs, lhs);
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
    std::swap(lhs, rhs);
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
    std::swap(lhs, rhs);
  } else {
    return failure();
  }
  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
      op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
      op.getIteratorTypes());
  return success();
}

} // namespace nvgpu
} // namespace mlir