//===- NVGPUDialect.cpp - MLIR NVGPU ops implementation -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the NVGPU dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::nvgpu;

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"

void nvgpu::NVGPUDialect::initialize() {
  addTypes<DeviceAsyncTokenType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"
      >();
}

Type NVGPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'device.async.token' types.
  if (keyword == "device.async.token")
    return DeviceAsyncTokenType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown nvgpu type: " + keyword);
  return Type();
}

void NVGPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<DeviceAsyncTokenType>([&](Type) { os << "device.async.token"; })
      .Default([](Type) { llvm_unreachable("unexpected 'nvgpu' type kind"); });
}

//===----------------------------------------------------------------------===//
// NVGPU_DeviceAsyncCopyOp
//===----------------------------------------------------------------------===//
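
// Illustrative example (not part of the upstream source; the exact custom
// assembly syntax is an assumption): the verifier below accepts a copy whose
// destination lives in workgroup (shared) memory and whose innermost
// dimension has unit stride, roughly:
//
//   %token = nvgpu.device_async_copy %src[%i, %j], %dst[%k, %l, %m], 4
//       : memref<4x5xf32> to memref<2x7x5xf32, 3>
//
// Both memrefs are contiguous in their last dimension, the destination is in
// address space 3 (gpu workgroup memory), the element types match, and the
// number of indices equals each memref's rank, so DeviceAsyncCopyOp::verify()
// succeeds.
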
/// Return true if the last dimension of the MemRefType has unit stride. Also
/// return true for memrefs with no strides.
static bool isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(getStridesAndOffset(type, strides, offset))) {
    return false;
  }
  return strides.back() == 1;
}

LogicalResult DeviceAsyncCopyOp::verify() {
  auto srcMemref = getSrc().getType().cast<MemRefType>();
  auto dstMemref = getDst().getType().cast<MemRefType>();

  unsigned workgroupAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace();
  if (!isLastMemrefDimUnitStride(srcMemref))
    return emitError("source memref most minor dim must have unit stride");
  if (!isLastMemrefDimUnitStride(dstMemref))
    return emitError("destination memref most minor dim must have unit stride");
  if (dstMemref.getMemorySpaceAsInt() != workgroupAddressSpace)
    return emitError("destination memref must have memory space ")
           << workgroupAddressSpace;
  if (dstMemref.getElementType() != srcMemref.getElementType())
    return emitError("source and destination must have the same element type");
  if (size_t(srcMemref.getRank()) != getSrcIndices().size())
    return emitOpError() << "expected " << srcMemref.getRank()
                         << " source indices, got " << getSrcIndices().size();
  if (size_t(dstMemref.getRank()) != getDstIndices().size())
    return emitOpError() << "expected " << dstMemref.getRank()
                         << " destination indices, got "
                         << getDstIndices().size();
  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_MmaSyncOp
//===----------------------------------------------------------------------===//
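
// Worked example (illustrative; the numbers follow directly from the
// arithmetic in the verifier below, while the op syntax itself is an
// assumption): for f16 operands with mmaShape = [16, 8, 16], shapeK = 128/16
// = 8 and numElementA = numElementB = 32/16 = 2, giving mTile = 16/8 = 2,
// nTile = 8/8 = 1, kTile = 16/8 = 2. The per-thread operands must therefore
// be shaped A: 4x2, B: 2x2, C: 2x2, e.g.
//
//   %d = nvgpu.mma.sync (%a, %b, %c) {mmaShape = [16, 8, 16]}
//       : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>)
//       -> vector<2x2xf16>
//
// The warp-wide checks also hold: 4*2*32 = 256 = 16*16 elements of A,
// 2*2*32 = 128 = 16*8 elements of B, and 2*2*32 = 128 = 16*8 elements of C.
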
LogicalResult MmaSyncOp::verify() {

  // Fundamental tensor core mma.sync op.
  // For F32 (TF32), F16, S8, and S4 data types, the fundamental tensor core
  // operation has shape 8-by-8-by-128b. F64 is an exception. The verification
  // for mma.sync covering various shapes and data types is based on this
  // fundamental tensor core operation.
  constexpr int kThreads = 32; // 32 threads per warp

  int64_t shapeM = 8;
  int64_t shapeN = 8;
  int64_t shapeK; // set based on data type (128b for all data types except F64)

  // Number of elements A, B, and C per thread per fundamental tensor core tile
  int64_t numElementA;    // set based on data type (32b except F64)
  int64_t numElementB;    // set based on data type (32b except F64)
  int64_t numElementC{2}; // two accumulator elements per fundamental tile

  // nvgpu.mma.sync vector operands (per thread)
  auto aVector = getMatrixA().getType().cast<VectorType>();
  auto bVector = getMatrixB().getType().cast<VectorType>();
  auto cVector = getMatrixC().getType().cast<VectorType>();

  // vector shapes
  ArrayRef<int64_t> aShape = aVector.getShape();
  ArrayRef<int64_t> bShape = bVector.getShape();
  ArrayRef<int64_t> cShape = cVector.getShape();

  // vector element type
  Type aType = aVector.getElementType();

  // nvgpu.mma.sync shape (per 32 threads or per warp)
  int64_t m = getMmaShape()[0].cast<IntegerAttr>().getInt();
  int64_t n = getMmaShape()[1].cast<IntegerAttr>().getInt();
  int64_t k = getMmaShape()[2].cast<IntegerAttr>().getInt();

  if (aType.isF64()) {
    // exception to the 8-by-8-128b fundamental tensor core tile size
    shapeK = 4;
    numElementA = 1;
    numElementB = 1;
  } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
             aType.isInteger(8) || aType.isInteger(4)) {
    // 8-by-8-128b fundamental tensor core tile size
    int operandBitwidth = aType.getIntOrFloatBitWidth();
    shapeK = 128 / operandBitwidth;     // 128b wide shapeK
    numElementA = 32 / operandBitwidth; // 32b wide operand A
    numElementB = 32 / operandBitwidth; // 32b wide operand B
  } else {
    return emitError() << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
                          "supported by nvgpu.mma.sync";
  }

  //
  // Basic verification
  //

  // verify warp-wide size for vector a
  if (aShape[0] * aShape[1] * kThreads != m * k)
    return emitOpError() << "expected " << m * k
                         << " warp-wide matrix A elements";

  // verify warp-wide size for vector b
  if (bShape[0] * bShape[1] * kThreads != k * n)
    return emitOpError() << "expected " << k * n
                         << " warp-wide matrix B elements";

  // verify warp-wide size for vector c
  if (cShape[0] * cShape[1] * kThreads != m * n)
    return emitOpError() << "expected " << m * n
                         << " warp-wide matrix C elements";

  //
  // Extended verification
  //

  // tiles of fundamental tensor core operations
  int64_t mTile = m / shapeM;
  int64_t nTile = n / shapeN;
  int64_t kTile = k / shapeK;

  // verify shape of aVector
  if (!((aShape[0] == mTile * kTile) && (aShape[1] == numElementA)))
    return emitOpError() << "expected matrix A to be shaped (" << mTile * kTile
                         << " x " << numElementA << ")";

  // verify shape of bVector
  if (!((bShape[0] == kTile * nTile) && (bShape[1] == numElementB)))
    return emitOpError() << "expected matrix B to be shaped (" << kTile * nTile
                         << " x " << numElementB << ")";

  // verify shape of cVector
  if (!((cShape[0] == mTile * nTile) && (cShape[1] == numElementC)))
    return emitOpError() << "expected matrix C to be shaped (" << mTile * nTile
                         << " x " << numElementC << ")";

  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_LdMatrixOp
//===----------------------------------------------------------------------===//
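
// Worked example (illustrative; the op syntax is an assumption, the numbers
// follow from the checks in the verifier below): loading numTiles = 4 tiles of
// f16 from shared memory gives numElementsPer32b = 32/16 = 2, so the result
// must be shaped 4x2, roughly:
//
//   %a = nvgpu.ldmatrix %sm[%c0, %c0] {numTiles = 4 : i32, transpose = false}
//       : memref<128x128xf16, 3> -> vector<4x2xf16>
//
// The source memref must be in address space 3 (gpu workgroup memory), the
// element type can be at most 32 bits wide, and transpose is only legal for
// 16-bit elements.
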
LogicalResult LdMatrixOp::verify() {

  // ldmatrix reads data from source in shared memory
  auto srcMemref = getSrcMemref().getType().cast<MemRefType>();

  // ldmatrix writes data to result/destination in vector registers
  auto resVector = getRes().getType().cast<VectorType>();

  // vector register shape, element type, and bitwidth
  ArrayRef<int64_t> resShape = resVector.getShape();
  Type resType = resVector.getElementType();
  int64_t elementBitWidth = resType.getIntOrFloatBitWidth();

  // ldmatrix loads 32 bits into vector registers per 8-by-8 tile per thread
  int64_t numElementsPer32b = 32 / elementBitWidth;

  // number of 8-by-8 tiles
  int64_t numTiles = getNumTiles();

  // transpose elements in vector registers at 16b granularity when true
  bool isTranspose = getTranspose();

  // address space id for shared memory
  unsigned smemAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace();

  //
  // verification
  //

  if (!(srcMemref.getMemorySpaceAsInt() == smemAddressSpace))
    return emitError()
           << "expected nvgpu.ldmatrix srcMemref must have memory space "
           << smemAddressSpace;
  if (elementBitWidth > 32)
    return emitError() << "nvgpu.ldmatrix works for 32b or lower";
  if (isTranspose && !(elementBitWidth == 16))
    return emitError()
           << "nvgpu.ldmatrix transpose works only at 16b granularity";
  if (!(resShape[1] == numElementsPer32b))
    return emitError() << "expected vector register shape[1] = "
                       << numElementsPer32b;
  if (!(resShape[0] == numTiles))
    return emitError()
           << "expected vector register shape[0] and numTiles to match";

  return success();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"