//===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"

#include "../PassDetail.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"

using namespace mlir;

/// Returns the type for the intrinsic given the vectorResultType of the
/// `gpu.mma.sync` operation.
static Type inferIntrinsicResultType(Type vectorResultType) {
  MLIRContext *ctx = vectorResultType.getContext();
  auto a = vectorResultType.cast<LLVM::LLVMArrayType>();
  auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
  auto i32Ty = IntegerType::get(ctx, 32);
  auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
  Type f64Ty = Float64Type::get(ctx);
  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
  Type f32Ty = Float32Type::get(ctx);
  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
  if (a.getElementType() == f16x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
  }
  if (a.getElementType() == i32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx,
        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
  }
  if (a.getElementType() == f64x2Ty) {
    return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
  }
  if (a.getElementType() == f32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx,
        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
  }
  if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
  }
  return vectorResultType;
}

/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
/// always an LLVM struct) into a fragment that is compatible with the vector
/// type of this operation. This involves extracting elements from the struct
/// and inserting them into an LLVM array. These extra data-movement
/// operations should be canonicalized away by the LLVM backend.
static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
                                    Type resultType, Value intrinsicResult,
                                    RewriterBase &rewriter) {
  MLIRContext *ctx = rewriter.getContext();
  auto structType = intrinsicResultType.dyn_cast<LLVM::LLVMStructType>();
  auto arrayType = resultType.dyn_cast<LLVM::LLVMArrayType>();
  Type i32Ty = rewriter.getI32Type();
  Type f32Ty = rewriter.getF32Type();
  Type f64Ty = rewriter.getF64Type();
  Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
  Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);

  auto makeConst = [&](int32_t index) -> Value {
    return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
                                             rewriter.getI32IntegerAttr(index));
  };

  if (arrayType) {
    SmallVector<Value, 4> elements;

    // The intrinsic returns 32-bit wide elements in a form which can be
    // directly bitcasted and inserted into the result vector.
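    // For example (illustrative): an f16 accumulator converted to
    // !llvm.array<2 x vector<2xf16>> corresponds to an intrinsic result of
    // !llvm.struct<(vector<2xf16>, vector<2xf16>)>, so each struct member is
    // (trivially) bitcast to the row type and inserted into the array.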
    if (arrayType.getElementType() == f16x2Ty ||
        arrayType.getElementType() == f32x1Ty) {
      for (unsigned i = 0; i < structType.getBody().size(); i++) {
        Value el = rewriter.create<LLVM::ExtractValueOp>(
            loc, structType.getBody()[i], intrinsicResult,
            rewriter.getI64ArrayAttr(i));
        el = rewriter.createOrFold<LLVM::BitcastOp>(
            loc, arrayType.getElementType(), el);
        elements.push_back(el);
      }
    }

    // The intrinsic returns i32, f64, and f32 values as individual scalars,
    // even when the result is notionally a 64-bit wide element (e.g. f32x2).
    // We need to extract them from the struct and pack them into the 64-bit
    // wide rows of the vector result.
    if (arrayType.getElementType() == i32x2Ty ||
        arrayType.getElementType() == f64x2Ty ||
        arrayType.getElementType() == f32x2Ty) {
      for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
        Value vec =
            rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
        Value x1 = rewriter.create<LLVM::ExtractValueOp>(
            loc, structType.getBody()[i * 2], intrinsicResult,
            rewriter.getI64ArrayAttr(i * 2));
        Value x2 = rewriter.create<LLVM::ExtractValueOp>(
            loc, structType.getBody()[i * 2 + 1], intrinsicResult,
            rewriter.getI64ArrayAttr(i * 2 + 1));
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x1, makeConst(0));
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x2, makeConst(1));
        elements.push_back(vec);
      }
    }

    // Create the final vectorized result.
    Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
    for (const auto &el : llvm::enumerate(elements)) {
      result = rewriter.create<LLVM::InsertValueOp>(
          loc, arrayType, result, el.value(),
          rewriter.getI64ArrayAttr(el.index()));
    }
    return result;
  }

  return intrinsicResult;
}

/// The `gpu.mma.sync` converter below expects matrix fragment operands to be
/// given as 2D `vectors` where the rows are 32b or 64b wide. The
/// `nvvm.mma.sync` op expects these arguments to be given in a long list of
/// scalars of certain types. This function helps unpack the `vector` arguments
/// and cast them to the types expected by `nvvm.mma.sync`.
static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
                                              Location loc, Value operand,
                                              NVVM::MMATypes operandPtxType) {
  SmallVector<Value> result;
  Type i32Ty = rewriter.getI32Type();
  Type f64Ty = rewriter.getF64Type();
  Type f32Ty = rewriter.getF32Type();
  Type i8Ty = rewriter.getI8Type();
  Type i4Ty = rewriter.getIntegerType(4);
  Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
  Type i4x8Ty = LLVM::getFixedVectorType(i4Ty, 8);
  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
  auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>();

  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
    Value toUse = rewriter.create<LLVM::ExtractValueOp>(
        loc, arrayTy.getElementType(), operand, rewriter.getI64ArrayAttr(i));

    // For 4xi8 vectors, the intrinsic expects these to be provided as i32
    // scalar types.
    if (arrayTy.getElementType() == i8x4Ty ||
        arrayTy.getElementType() == i4x8Ty ||
        (arrayTy.getElementType() == f32x1Ty &&
         operandPtxType == NVVM::MMATypes::tf32)) {
      result.push_back(
          rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI32Type(), toUse));
      continue;
    }

    // For some element types (i32, f32, f64), we need to unpack the inner
    // vector/array type as well because the intrinsic expects individual
    // scalars to be provided.
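    // For example (illustrative): an f32 accumulator operand given as
    // !llvm.array<2 x vector<2xf32>> is flattened into four individual f32
    // scalars with llvm.extractelement before being handed to nvvm.mma.sync.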
    VectorType innerArrayTy = arrayTy.getElementType().dyn_cast<VectorType>();
    if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
                         innerArrayTy.getElementType() == f64Ty ||
                         innerArrayTy.getElementType() == f32Ty)) {
      for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
           idx < innerSize; idx++) {
        result.push_back(rewriter.create<LLVM::ExtractElementOp>(
            loc, toUse,
            rewriter.create<LLVM::ConstantOp>(
                loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx))));
      }
      continue;
    }
    result.push_back(toUse);
  }
  return result;
}

namespace {

struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
  using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = getContext();
    Location loc = op->getLoc();

    // The result type of ldmatrix will always be a struct of 32-bit integer
    // registers if more than one 32-bit value is returned. Otherwise, the
    // result is a single i32. The result type of the GPU operation is always
    // a vector of shape (NumRegisters, VectorRegister) where VectorRegister
    // is the vector type of the result and always 32 bits long. We bitcast
    // the result of the NVVM::LdMatrix to this vector type.
    auto vectorResultType = op->getResultTypes()[0].dyn_cast<VectorType>();
    if (!vectorResultType) {
      return failure();
    }
    Type innerVectorType = LLVM::getFixedVectorType(
        vectorResultType.getElementType(), vectorResultType.getDimSize(1));

    int64_t num32BitRegs = vectorResultType.getDimSize(0);

    Type ldMatrixResultType;
    if (num32BitRegs > 1) {
      ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
          ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
    } else {
      ldMatrixResultType = rewriter.getI32Type();
    }

    auto srcMemrefType = op.getSrcMemref().getType().cast<MemRefType>();
    Value srcPtr =
        getStridedElementPtr(loc, srcMemrefType, adaptor.getSrcMemref(),
                             adaptor.getIndices(), rewriter);
    Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>(
        loc, ldMatrixResultType, srcPtr,
        /*num=*/op.getNumTiles(),
        /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
                                     : NVVM::MMALayout::row);

    // The ldmatrix operation returns either a single i32 value or a struct of
    // i32 values. Here we unpack those values and cast them back to their
    // actual vector type (still of width 32b) and repack them into a result
    // struct.
    Type finalResultType = typeConverter->convertType(vectorResultType);
    Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType);
    for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
      Value i32Register = num32BitRegs > 1
                              ? rewriter.create<LLVM::ExtractValueOp>(
                                    loc, rewriter.getI32Type(), ldMatrixResult,
                                    rewriter.getI64ArrayAttr(i))
                              : ldMatrixResult;
      Value casted =
          rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register);
      result = rewriter.create<LLVM::InsertValueOp>(
          loc, finalResultType, result, casted, rewriter.getI64ArrayAttr(i));
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Get the shapes of the MMAMatrix type being used. The shapes will
    // choose which intrinsic this op will be lowered to.
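    // For example (illustrative): an `nvgpu.mma.sync` op with
    // mmaShape = [16, 8, 16] and f16 operands selects the m16n8k16 f16
    // variant of the `nvvm.mma.sync` intrinsic created below.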
    auto aType = op.getMatrixA().getType().cast<VectorType>();
    auto cType = op.getMatrixC().getType().cast<VectorType>();

    int64_t m = op.getMmaShape()[0].cast<IntegerAttr>().getInt();
    int64_t n = op.getMmaShape()[1].cast<IntegerAttr>().getInt();
    int64_t k = op.getMmaShape()[2].cast<IntegerAttr>().getInt();
    std::array<int64_t, 3> gemmShape{m, n, k};

    NVVM::MMATypes ptxTypeA;
    NVVM::MMATypes ptxTypeB;
    Optional<NVVM::MMATypes> ptxTypeC = NVVM::MmaOp::inferOperandMMAType(
        cType.getElementType(), /*isAccumulator=*/true);
    if (!ptxTypeC) {
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");
    }

    Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
    if (aType.getElementType().isInteger(8)) {
      ptxTypeA = NVVM::MMATypes::s8;
      ptxTypeB = NVVM::MMATypes::s8;
      overflow = NVVM::MMAIntOverflow::satfinite;
    } else if (aType.getElementType().isInteger(4)) {
      ptxTypeA = NVVM::MMATypes::s4;
      ptxTypeB = NVVM::MMATypes::s4;
      overflow = NVVM::MMAIntOverflow::satfinite;
    } else if (aType.getElementType().isF16()) {
      ptxTypeA = NVVM::MMATypes::f16;
      ptxTypeB = NVVM::MMATypes::f16;
    } else if (aType.getElementType().isF64()) {
      ptxTypeA = NVVM::MMATypes::f64;
      ptxTypeB = NVVM::MMATypes::f64;
    } else if (aType.getElementType().isF32()) {
      ptxTypeA = NVVM::MMATypes::tf32;
      ptxTypeB = NVVM::MMATypes::tf32;
    } else {
      return op->emitError("could not deduce operand PTX types");
    }

    SmallVector<Value> matA =
        unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), ptxTypeA);
    SmallVector<Value> matB =
        unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), ptxTypeB);
    SmallVector<Value> matC =
        unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC);

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));
    Value intrinsicResult = rewriter.create<NVVM::MmaOp>(
        op.getLoc(), intrinsicResTy, matA, matB, matC,
        /*shape=*/gemmShape,
        /*b1Op=*/llvm::None,
        /*intOverflow=*/overflow,
        /*multiplicandPtxTypes=*/
        std::array<NVVM::MMATypes, 2>{ptxTypeA, ptxTypeB},
        /*multiplicandLayouts=*/
        std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
                                       NVVM::MMALayout::col});
    rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
                                                  desiredRetTy, intrinsicResult,
                                                  rewriter));
    return success();
  }
};

struct ConvertNVGPUToNVVMPass
    : public ConvertNVGPUToNVVMBase<ConvertNVGPUToNVVMPass> {
  ConvertNVGPUToNVVMPass() = default;

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    LLVMTypeConverter converter(&getContext());
    /// device-side async tokens cannot be materialized in nvvm. We just
    /// convert them to a dummy i32 type in order to easily drop them during
    /// conversion.
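    /// Note: the async-copy lowerings below replace each token result with an
    /// i32 constant 0; once every token use has been rewritten, those
    /// placeholder constants are trivially dead.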
    converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
      return converter.convertType(IntegerType::get(type.getContext(), 32));
    });
    populateNVGPUToNVVMConversionPatterns(converter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

struct NVGPUAsyncCopyLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto dstMemrefType = op.getDst().getType().cast<MemRefType>();
    Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.getDst(),
                                        adaptor.getDstIndices(), rewriter);
    auto i8Ty = IntegerType::get(op.getContext(), 8);
    auto dstPointerType =
        LLVM::LLVMPointerType::get(i8Ty, dstMemrefType.getMemorySpaceAsInt());
    dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);

    auto srcMemrefType = op.getSrc().getType().cast<MemRefType>();
    Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
                                        adaptor.getSrcIndices(), rewriter);
    auto srcPointerType =
        LLVM::LLVMPointerType::get(i8Ty, srcMemrefType.getMemorySpaceAsInt());
    scrPtr = rewriter.create<LLVM::BitcastOp>(loc, srcPointerType, scrPtr);
    // The intrinsic takes a global pointer, so we need an address space cast.
    auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
        i8Ty, NVVM::NVVMMemorySpace::kGlobalMemorySpace);
    scrPtr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, srcPointerGlobalType,
                                                    scrPtr);
    int64_t numElements = adaptor.getNumElements().getZExtValue();
    int64_t sizeInBytes =
        (dstMemrefType.getElementTypeBitWidth() * numElements) / 8;
    // Bypassing L1 is only supported for 16-byte copies; drop the hint
    // otherwise.
    UnitAttr bypassL1 =
        sizeInBytes == 16 ? adaptor.getBypassL1Attr() : UnitAttr();
    rewriter.create<NVVM::CpAsyncOp>(
        loc, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes), bypassL1);

    // Drop the result token.
    Value zero = rewriter.create<LLVM::ConstantOp>(
        op->getLoc(), IntegerType::get(op.getContext(), 32),
        rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};

struct NVGPUAsyncCreateGroupLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
    // Drop the result token.
    Value zero = rewriter.create<LLVM::ConstantOp>(
        op->getLoc(), IntegerType::get(op.getContext(), 32),
        rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};

struct NVGPUAsyncWaitLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // If numGroups is not present, pick 0 as a conservative correct value.
    int32_t numGroups = adaptor.getNumGroups().value_or(0);
    rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                 RewritePatternSet &patterns) {
  patterns.add<MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
               NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering>(
      converter);
}

std::unique_ptr<Pass> mlir::createConvertNVGPUToNVVMPass() {
  return std::make_unique<ConvertNVGPUToNVVMPass>();
}
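// Usage sketch (illustrative, not part of this file): clients typically add
// the pass to a pass manager that lowers GPU device code, for example:
//
//   pm.addNestedPass<gpu::GPUModuleOp>(createConvertNVGPUToNVVMPass());
//
// or invoke it from `mlir-opt` via the registered `-convert-nvgpu-to-nvvm`
// pipeline option.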