1 //===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h" 10 #include "../PassDetail.h" 11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 12 #include "mlir/Conversion/LLVMCommon/Pattern.h" 13 #include "mlir/Dialect/LLVMIR/NVVMDialect.h" 14 #include "mlir/Dialect/NVGPU/NVGPUDialect.h" 15 16 using namespace mlir; 17 18 /// Returns the type for the intrinsic given the vectorResultType of the 19 /// `gpu.mma.sync` operation. 20 static Type inferIntrinsicResultType(Type vectorResultType) { 21 MLIRContext *ctx = vectorResultType.getContext(); 22 auto a = vectorResultType.cast<LLVM::LLVMArrayType>(); 23 auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2); 24 auto i32Ty = IntegerType::get(ctx, 32); 25 auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2); 26 Type f64Ty = Float64Type::get(ctx); 27 Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2); 28 if (a.getElementType() == f16x2Ty) { 29 return LLVM::LLVMStructType::getLiteral( 30 ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty)); 31 } 32 if (a.getElementType() == i32x2Ty) { 33 return LLVM::LLVMStructType::getLiteral( 34 ctx, 35 SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty)); 36 } 37 if (a.getElementType() == f64x2Ty) { 38 return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty}); 39 } 40 return vectorResultType; 41 } 42 43 /// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is 44 /// always an LLVM struct) into a fragment that is compatible with the vector 45 /// type of this operation. This involves extracting elements from the struct 46 /// and inserting them into an LLVM array. These extra data-movement 47 /// operations should be canonicalized away by the LLVM backend. 48 static Value convertIntrinsicResult(Location loc, Type intrinsicResultType, 49 Type resultType, Value intrinsicResult, 50 RewriterBase &rewriter) { 51 MLIRContext *ctx = rewriter.getContext(); 52 auto structType = intrinsicResultType.dyn_cast<LLVM::LLVMStructType>(); 53 auto arrayType = resultType.dyn_cast<LLVM::LLVMArrayType>(); 54 Type i32Ty = rewriter.getI32Type(); 55 Type f64Ty = rewriter.getF64Type(); 56 Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2); 57 Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2); 58 Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2); 59 60 auto makeConst = [&](int32_t index) -> Value { 61 return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32), 62 rewriter.getI32IntegerAttr(index)); 63 }; 64 65 if (arrayType) { 66 SmallVector<Value, 4> elements; 67 68 if (arrayType.getElementType() == f16x2Ty) { 69 for (unsigned i = 0; i < structType.getBody().size(); i++) { 70 elements.push_back(rewriter.create<LLVM::ExtractValueOp>( 71 loc, structType.getBody()[i], intrinsicResult, 72 rewriter.getI64ArrayAttr(i))); 73 } 74 } 75 76 // The intrinsic returns i32 and f64 values as individual scalars. We need 77 // to extract them from the struct and pack them into vectors. 78 if (arrayType.getElementType() == i32x2Ty || 79 arrayType.getElementType() == f64x2Ty) { 80 Value vec = 81 rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType()); 82 for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) { 83 Value x1 = rewriter.create<LLVM::ExtractValueOp>( 84 loc, structType.getBody()[i * 2], intrinsicResult, 85 rewriter.getI64ArrayAttr(i * 2)); 86 Value x2 = rewriter.create<LLVM::ExtractValueOp>( 87 loc, structType.getBody()[i * 2 + 1], intrinsicResult, 88 rewriter.getI64ArrayAttr(i * 2 + 1)); 89 vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec, 90 x1, makeConst(0)); 91 vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec, 92 x2, makeConst(1)); 93 } 94 elements.push_back(vec); 95 } 96 97 // Create the final vectorized result. 98 Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType); 99 for (const auto &el : llvm::enumerate(elements)) { 100 result = rewriter.create<LLVM::InsertValueOp>( 101 loc, arrayType, result, el.value(), 102 rewriter.getI64ArrayAttr(el.index())); 103 } 104 return result; 105 } 106 107 return intrinsicResult; 108 } 109 110 /// The `gpu.mma.sync` converter below expects matrix fragment operands to be 111 /// given as 2D `vectors` where the rows are 32b or 64b wide. The 112 /// `nvvm.mma.sync` op expects these argments to be a given in a long list of 113 /// scalars of certain types. This function helps unpack the `vector` arguments 114 /// and cast them to the types expected by `nvvm.mma.sync`. 115 static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter, 116 Location loc, Value operand) { 117 SmallVector<Value> result; 118 Type i32Ty = rewriter.getI32Type(); 119 Type f64Ty = rewriter.getF64Type(); 120 Type i8Ty = rewriter.getI8Type(); 121 Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4); 122 auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>(); 123 124 for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) { 125 Value toUse = rewriter.create<LLVM::ExtractValueOp>( 126 loc, arrayTy.getElementType(), operand, rewriter.getI64ArrayAttr(i)); 127 128 // For 4xi8 vectors, the intrinsic expects these to be provided as i32 129 // scalar types. 130 if (arrayTy.getElementType() == i8x4Ty) { 131 result.push_back( 132 rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI32Type(), toUse)); 133 continue; 134 } 135 136 // For some element types (i32, f64), we need to unpack the inner 137 // vector/array type as well because the intrinsic expects individual 138 // scalars to be provided. 139 VectorType innerArrayTy = arrayTy.getElementType().dyn_cast<VectorType>(); 140 if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty || 141 innerArrayTy.getElementType() == f64Ty)) { 142 for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements(); 143 idx < innerSize; idx++) { 144 result.push_back(rewriter.create<LLVM::ExtractElementOp>( 145 loc, toUse, 146 rewriter.create<LLVM::ConstantOp>( 147 loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx)))); 148 } 149 continue; 150 } 151 result.push_back(toUse); 152 } 153 return result; 154 } 155 156 namespace { 157 158 struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> { 159 using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern; 160 161 LogicalResult 162 matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor, 163 ConversionPatternRewriter &rewriter) const override { 164 MLIRContext *ctx = getContext(); 165 Location loc = op->getLoc(); 166 167 // The result type of ldmatrix will always be a struct of 32bit integer 168 // registers if more than one 32bit value is returned. Otherwise, the result 169 // is a single i32. The result type of the GPU operation is always a vector 170 // of shape (NumRegisters, VectorRegister) where VectorRegister is the 171 // vector type of the result and always 32 bits long. We bitcast the result 172 // of the NVVM::LdMatrix to this vector type. 173 auto vectorResultType = op->getResultTypes()[0].dyn_cast<VectorType>(); 174 if (!vectorResultType) { 175 return failure(); 176 } 177 Type innerVectorType = LLVM::getFixedVectorType( 178 vectorResultType.getElementType(), vectorResultType.getDimSize(1)); 179 180 int64_t num32BitRegs = vectorResultType.getDimSize(0); 181 182 Type ldMatrixResultType; 183 if (num32BitRegs > 1) { 184 ldMatrixResultType = LLVM::LLVMStructType::getLiteral( 185 ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type())); 186 } else { 187 ldMatrixResultType = rewriter.getI32Type(); 188 } 189 190 auto srcMemrefType = op.srcMemref().getType().cast<MemRefType>(); 191 Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.srcMemref(), 192 adaptor.indices(), rewriter); 193 Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>( 194 loc, ldMatrixResultType, srcPtr, 195 /*num=*/op.numTiles(), 196 /*layout=*/op.transpose() ? NVVM::MMALayout::col 197 : NVVM::MMALayout::row); 198 199 // The ldmatrix operation returns either a single i32 value or a struct of 200 // i32 values. Here we unpack those values and cast them back to their 201 // actual vector type (still of width 32b) and repack them into a result 202 // struct. 203 Type finalResultType = typeConverter->convertType(vectorResultType); 204 Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType); 205 for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) { 206 Value i32Register = num32BitRegs > 1 207 ? rewriter.create<LLVM::ExtractValueOp>( 208 loc, rewriter.getI32Type(), ldMatrixResult, 209 rewriter.getI64ArrayAttr(i)) 210 : ldMatrixResult; 211 Value casted = 212 rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register); 213 result = rewriter.create<LLVM::InsertValueOp>( 214 loc, finalResultType, result, casted, rewriter.getI64ArrayAttr(i)); 215 } 216 217 rewriter.replaceOp(op, result); 218 return success(); 219 } 220 }; 221 222 struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> { 223 using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern; 224 225 LogicalResult 226 matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor, 227 ConversionPatternRewriter &rewriter) const override { 228 Location loc = op->getLoc(); 229 // Get the shapes of the MMAMatrix type being used. The shapes will 230 // choose which intrinsic this op will be lowered to. 231 auto aType = op.matrixA().getType().cast<VectorType>(); 232 233 int64_t m = op.mmaShape()[0].cast<IntegerAttr>().getInt(); 234 int64_t n = op.mmaShape()[1].cast<IntegerAttr>().getInt(); 235 int64_t k = op.mmaShape()[2].cast<IntegerAttr>().getInt(); 236 std::array<int64_t, 3> gemmShape{m, n, k}; 237 238 SmallVector<Value> matA = 239 unpackOperandVector(rewriter, loc, adaptor.matrixA()); 240 SmallVector<Value> matB = 241 unpackOperandVector(rewriter, loc, adaptor.matrixB()); 242 SmallVector<Value> matC = 243 unpackOperandVector(rewriter, loc, adaptor.matrixC()); 244 245 NVVM::MMATypes ptxTypeA; 246 NVVM::MMATypes ptxTypeB; 247 Optional<NVVM::MMAIntOverflow> overflow(llvm::None); 248 if (aType.getElementType().isInteger(8)) { 249 ptxTypeA = NVVM::MMATypes::s8; 250 ptxTypeB = NVVM::MMATypes::s8; 251 overflow = NVVM::MMAIntOverflow::satfinite; 252 253 } else if (aType.getElementType().isF16()) { 254 ptxTypeA = NVVM::MMATypes::f16; 255 ptxTypeB = NVVM::MMATypes::f16; 256 } else if (aType.getElementType().isF64()) { 257 ptxTypeA = NVVM::MMATypes::f64; 258 ptxTypeB = NVVM::MMATypes::f64; 259 } else { 260 return op->emitError("could not deduce operand PTX types"); 261 } 262 263 Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]); 264 Type intrinsicResTy = inferIntrinsicResultType( 265 typeConverter->convertType(op->getResultTypes()[0])); 266 Value intrinsicResult = rewriter.create<NVVM::MmaOp>( 267 op.getLoc(), intrinsicResTy, matA, matB, matC, 268 /*shape=*/gemmShape, 269 /*b1Op=*/llvm::None, 270 /*intOverflow=*/overflow, 271 /*multiplicandPtxTypes=*/ 272 std::array<NVVM::MMATypes, 2>{ptxTypeA, ptxTypeB}, 273 /*multiplicandLayouts=*/ 274 std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row, 275 NVVM::MMALayout::col}); 276 rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, 277 desiredRetTy, intrinsicResult, 278 rewriter)); 279 return success(); 280 } 281 }; 282 283 struct ConvertNVGPUToNVVMPass 284 : public ConvertNVGPUToNVVMBase<ConvertNVGPUToNVVMPass> { 285 ConvertNVGPUToNVVMPass() = default; 286 287 void runOnOperation() override { 288 RewritePatternSet patterns(&getContext()); 289 LLVMTypeConverter converter(&getContext()); 290 populateNVGPUToNVVMConversionPatterns(converter, patterns); 291 LLVMConversionTarget target(getContext()); 292 target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); 293 target.addLegalDialect<::mlir::NVVM::NVVMDialect>(); 294 if (failed(applyPartialConversion(getOperation(), target, 295 std::move(patterns)))) 296 signalPassFailure(); 297 } 298 }; 299 300 } // namespace 301 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter, 302 RewritePatternSet &patterns) { 303 patterns.add<MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM>(converter); 304 } 305 306 std::unique_ptr<Pass> mlir::createConvertNVGPUToNVVMPass() { 307 return std::make_unique<ConvertNVGPUToNVVMPass>(); 308 } 309