//===- NVGPUDialect.cpp - MLIR NVGPU ops implementation -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the NVGPU dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::nvgpu;

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"

void nvgpu::NVGPUDialect::initialize() {
  addTypes<DeviceAsyncTokenType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"
      >();
}

Type NVGPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();
  // Handle 'device.async.token' types.
  if (keyword == "device.async.token")
    return DeviceAsyncTokenType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown nvgpu type: " + keyword);
  return Type();
}

void NVGPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<DeviceAsyncTokenType>([&](Type) { os << "device.async.token"; })
      .Default([](Type) { llvm_unreachable("unexpected 'nvgpu' type kind"); });
}

//===----------------------------------------------------------------------===//
// NVGPU_DeviceAsyncCopyOp
//===----------------------------------------------------------------------===//
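
// Illustrative IR accepted by the verifier below (a sketch assembled from the
// constraints it checks; see the op definition in NVGPU.td for the
// authoritative assembly format): the destination memref must live in
// workgroup (shared) memory, both memrefs must have a unit-stride innermost
// dimension, and each index list must match the rank of its memref, e.g.
//
//   %token = nvgpu.device_async_copy %src[%c0, %c0], %dst[%c0, %c0, %c0], 4
//       : memref<4x5xf32> to memref<2x7x5xf32, 3>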

/// Return true if the last dimension of the MemRefType has unit stride. Also
/// return true for memrefs with no strides.
static bool isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(getStridesAndOffset(type, strides, offset))) {
    return false;
  }
  return strides.empty() || strides.back() == 1;
}

LogicalResult DeviceAsyncCopyOp::verify() {
  auto srcMemref = getSrc().getType().cast<MemRefType>();
  auto dstMemref = getDst().getType().cast<MemRefType>();
  unsigned workgroupAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace();
  if (!isLastMemrefDimUnitStride(srcMemref))
    return emitError("source memref most minor dim must have unit stride");
  if (!isLastMemrefDimUnitStride(dstMemref))
    return emitError("destination memref most minor dim must have unit stride");
  if (dstMemref.getMemorySpaceAsInt() != workgroupAddressSpace)
    return emitError("destination memref must have memory space ")
           << workgroupAddressSpace;
  if (dstMemref.getElementType() != srcMemref.getElementType())
    return emitError("source and destination must have the same element type");
  if (size_t(srcMemref.getRank()) != getSrcIndices().size())
    return emitOpError() << "expected " << srcMemref.getRank()
                         << " source indices, got " << getSrcIndices().size();
  if (size_t(dstMemref.getRank()) != getDstIndices().size())
    return emitOpError() << "expected " << dstMemref.getRank()
                         << " destination indices, got "
                         << getDstIndices().size();
  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_MmaSyncOp
//===----------------------------------------------------------------------===//

LogicalResult MmaSyncOp::verify() {

  // Fundamental tensor core mma.sync op
  // For F32 (TF32), F16, S8, and S4 data types, the fundamental tensor core
  // operation is of shape 8-by-8-by-128b. F64 is an exception. The
  // verification for mma.sync covering various shapes and data types is based
  // on the fundamental tensor core operation.
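  //
  // Worked example (follows directly from the arithmetic below; the concrete
  // vector types are an illustration, not an additional constraint): for f16
  // operands, shapeK = 128 / 16 = 8, so an m16n8k16 mma.sync decomposes into
  // mTile = 2, nTile = 1, kTile = 2 fundamental 8x8x8 tiles. Per thread, that
  // makes matrix A a vector<4x2xf16> (mTile * kTile = 4 by numElementA = 2),
  // matrix B a vector<2x2xf16>, and matrix C a vector<2x2xf16>.
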
  constexpr int kThreads = 32; // 32 threads per warp
  int64_t shapeM = 8;
  int64_t shapeN = 8;
  int64_t shapeK; // set based on data type (128b for all data types except F64)

  // Number of elements A, B, and C per thread per fundamental tensor core tile
  int64_t numElementA;    // set based on data type (32b except F64)
  int64_t numElementB;    // set based on data type (32b except F64)
  int64_t numElementC{2}; // two accumulator elements per fundamental tile

  // nvgpu.mma.sync vector operands (per thread)
  auto aVector = getMatrixA().getType().cast<VectorType>();
  auto bVector = getMatrixB().getType().cast<VectorType>();
  auto cVector = getMatrixC().getType().cast<VectorType>();

  // vector shapes
  ArrayRef<int64_t> aShape = aVector.getShape();
  ArrayRef<int64_t> bShape = bVector.getShape();
  ArrayRef<int64_t> cShape = cVector.getShape();

  // vector element type
  Type aType = aVector.getElementType();

  // nvgpu.mma.sync shape (per 32 threads or per warp)
  int64_t m = getMmaShape()[0].cast<IntegerAttr>().getInt();
  int64_t n = getMmaShape()[1].cast<IntegerAttr>().getInt();
  int64_t k = getMmaShape()[2].cast<IntegerAttr>().getInt();

  if (aType.isF64()) {
    // exception to the 8-by-8-by-128b fundamental tensor core tile size
    shapeK = 4;
    numElementA = 1;
    numElementB = 1;
  } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
             aType.isInteger(8) || aType.isInteger(4)) {
    // 8-by-8-by-128b fundamental tensor core tile size
    int operandBitwidth = aType.getIntOrFloatBitWidth();
    shapeK = 128 / operandBitwidth;     // 128b wide shapeK
    numElementA = 32 / operandBitwidth; // 32b wide operand A
    numElementB = 32 / operandBitwidth; // 32b wide operand B
  } else {
    return emitError() << "expected an input data type supported by "
                          "nvgpu.mma.sync (i4, i8, f16, bf16, tf32, f64)";
  }

  //
  // Basic verification
  //

  // verify warp-wide size for vector a
  if (aShape[0] * aShape[1] * kThreads != m * k)
    return emitOpError() << "expected " << m * k
                         << " warp-wide matrix A elements";

  // verify warp-wide size for vector b
  if (bShape[0] * bShape[1] * kThreads != k * n)
    return emitOpError() << "expected " << k * n
                         << " warp-wide matrix B elements";

  // verify warp-wide size for vector c
  if (cShape[0] * cShape[1] * kThreads != m * n)
    return emitOpError() << "expected " << m * n
                         << " warp-wide matrix C elements";

  //
  // Extended verification
  //

  // number of fundamental tensor core tiles along each mma.sync dimension
  int64_t mTile = m / shapeM;
  int64_t nTile = n / shapeN;
  int64_t kTile = k / shapeK;

  // verify shape of aVector
  if (!((aShape[0] == mTile * kTile) && (aShape[1] == numElementA)))
    return emitOpError() << "expected matrix A to be shaped (" << mTile * kTile
                         << " x " << numElementA << ")";

  // verify shape of bVector
  if (!((bShape[0] == kTile * nTile) && (bShape[1] == numElementB)))
    return emitOpError() << "expected matrix B to be shaped (" << kTile * nTile
                         << " x " << numElementB << ")";

  // verify shape of cVector
  if (!((cShape[0] == mTile * nTile) && (cShape[1] == numElementC)))
    return emitOpError() << "expected matrix C to be shaped (" << mTile * nTile
                         << " x " << numElementC << ")";

  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_LdMatrixOp
//===----------------------------------------------------------------------===//

LogicalResult LdMatrixOp::verify() {

  // ldmatrix reads data from source in shared memory
  auto srcMemref = getSrcMemref().getType().cast<MemRefType>();

  // ldmatrix writes data to result/destination in vector registers
  auto resVector = getRes().getType().cast<VectorType>();

  // vector register shape, element type, and bitwidth
  ArrayRef<int64_t> resShape = resVector.getShape();
  Type resType = resVector.getElementType();
  int64_t elementBitWidth = resType.getIntOrFloatBitWidth();

  // ldmatrix loads 32 bits into vector registers per 8-by-8 tile per thread
  // (e.g. two f16 elements or four i8 elements per tile per thread)
  int64_t numElementsPer32b = 32 / elementBitWidth;

  // number of 8-by-8 tiles
  int64_t numTiles = getNumTiles();

  // transpose elements in vector registers at 16b granularity when true
  bool isTranspose = getTranspose();

  // address space id for shared memory
  unsigned smemAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace();

  //
  // verification
  //

  if (srcMemref.getMemorySpaceAsInt() != smemAddressSpace)
    return emitError()
           << "expected nvgpu.ldmatrix srcMemref to have memory space "
           << smemAddressSpace;
  if (elementBitWidth > 32)
    return emitError() << "nvgpu.ldmatrix works for 32b or lower";
  if (isTranspose && elementBitWidth != 16)
    return emitError()
           << "nvgpu.ldmatrix transpose works only at 16b granularity";
  if (resShape[1] != numElementsPer32b)
    return emitError() << "expected vector register shape[1] = "
                       << numElementsPer32b;
  if (resShape[0] != numTiles)
    return emitError()
           << "expected vector register shape[0] and numTiles to match";

  return success();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"