1 //===- NVGPUDialect.cpp - MLIR NVGPU ops implementation -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the NVGPU dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" 14 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 15 #include "mlir/IR/Builders.h" 16 #include "mlir/IR/DialectImplementation.h" 17 #include "mlir/IR/OpImplementation.h" 18 #include "mlir/IR/TypeUtilities.h" 19 #include "llvm/ADT/TypeSwitch.h" 20 21 using namespace mlir; 22 using namespace mlir::nvgpu; 23 24 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc" 25 26 void nvgpu::NVGPUDialect::initialize() { 27 addTypes<DeviceAsyncTokenType>(); 28 addOperations< 29 #define GET_OP_LIST 30 #include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc" 31 >(); 32 } 33 34 Type NVGPUDialect::parseType(DialectAsmParser &parser) const { 35 // Parse the main keyword for the type. 36 StringRef keyword; 37 if (parser.parseKeyword(&keyword)) 38 return Type(); 39 MLIRContext *context = getContext(); 40 // Handle 'device async token' types. 41 if (keyword == "device.async.token") 42 return DeviceAsyncTokenType::get(context); 43 44 parser.emitError(parser.getNameLoc(), "unknown nvgpu type: " + keyword); 45 return Type(); 46 } 47 48 void NVGPUDialect::printType(Type type, DialectAsmPrinter &os) const { 49 TypeSwitch<Type>(type) 50 .Case<DeviceAsyncTokenType>([&](Type) { os << "device.async.token"; }) 51 .Default([](Type) { llvm_unreachable("unexpected 'nvgpu' type kind"); }); 52 } 53 //===----------------------------------------------------------------------===// 54 // NVGPU_DeviceAsyncCopyOp 55 //===----------------------------------------------------------------------===// 56 57 /// Return true if the last dimension of the MemRefType has unit stride. Also 58 /// return true for memrefs with no strides. 59 static bool isLastMemrefDimUnitStride(MemRefType type) { 60 int64_t offset; 61 SmallVector<int64_t> strides; 62 if (failed(getStridesAndOffset(type, strides, offset))) { 63 return false; 64 } 65 return strides.back() == 1; 66 } 67 68 LogicalResult DeviceAsyncCopyOp::verify() { 69 auto srcMemref = getSrc().getType().cast<MemRefType>(); 70 auto dstMemref = getDst().getType().cast<MemRefType>(); 71 unsigned workgroupAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace(); 72 if (!isLastMemrefDimUnitStride(srcMemref)) 73 return emitError("source memref most minor dim must have unit stride"); 74 if (!isLastMemrefDimUnitStride(dstMemref)) 75 return emitError("destination memref most minor dim must have unit stride"); 76 if (dstMemref.getMemorySpaceAsInt() != workgroupAddressSpace) 77 return emitError("destination memref must have memory space ") 78 << workgroupAddressSpace; 79 if (dstMemref.getElementType() != srcMemref.getElementType()) 80 return emitError("source and destination must have the same element type"); 81 if (size_t(srcMemref.getRank()) != getSrcIndices().size()) 82 return emitOpError() << "expected " << srcMemref.getRank() 83 << " source indices, got " << getSrcIndices().size(); 84 if (size_t(dstMemref.getRank()) != getDstIndices().size()) 85 return emitOpError() << "expected " << dstMemref.getRank() 86 << " destination indices, got " 87 << getDstIndices().size(); 88 return success(); 89 } 90 91 LogicalResult MmaSyncOp::verify() { 92 93 // Fundamental tensor core mma.sync op 94 // For F32 (TF32), F16, S8, and S4 data types fundamental tensor core 95 // operation is of shape: 8-by-8-by-128b. F64 is an exception. The 96 // verification for mma.sync covering various shapes and data types is based 97 // on the fundamental tensor core operionation. 98 constexpr int kThreads = 32; // 32 threads per warp 99 int64_t shapeM = 8; 100 int64_t shapeN = 8; 101 int64_t shapeK; // set based on data type (128b for all data types except F64) 102 103 // Number of elements A, B, and C per thread per fundamental tensor core tile 104 int64_t numElementA; // set based on data type (32b except F64) 105 int64_t numElementB; // set based on data type (32b except F64) 106 int64_t numElementC{2}; // two accumulator elements per fundamental tile 107 108 // nvgpu.mma.sync vector operands (per thread) 109 auto aVector = getMatrixA().getType().cast<VectorType>(); 110 auto bVector = getMatrixB().getType().cast<VectorType>(); 111 auto cVector = getMatrixC().getType().cast<VectorType>(); 112 113 // vector shapes 114 ArrayRef<int64_t> aShape = aVector.getShape(); 115 ArrayRef<int64_t> bShape = bVector.getShape(); 116 ArrayRef<int64_t> cShape = cVector.getShape(); 117 118 // vector element type 119 Type aType = aVector.getElementType(); 120 121 // nvgpu.mma.sync shape (per 32 threads or per warp) 122 int64_t m = getMmaShape()[0].cast<IntegerAttr>().getInt(); 123 int64_t n = getMmaShape()[1].cast<IntegerAttr>().getInt(); 124 int64_t k = getMmaShape()[2].cast<IntegerAttr>().getInt(); 125 126 if (aType.isF64()) { 127 // exception to 8-by-8-128b fundamental tensor core tile size 128 shapeK = 4; 129 numElementA = 1; 130 numElementB = 1; 131 } else if (aType.isF32() || aType.isBF16() || aType.isF16() || 132 aType.isInteger(8) || aType.isInteger(4)) { 133 // 8-by-8-128b fundamental tensor core tile size 134 int operandBitwidth = aType.getIntOrFloatBitWidth(); 135 shapeK = 128 / operandBitwidth; // 128b wide shapeK 136 numElementA = 32 / operandBitwidth; // 32b wide operand A 137 numElementB = 32 / operandBitwidth; // 32b wide operand B 138 } else { 139 return emitError() << "expected input data type (i4,i8,f16,bf16,tf32,f64) " 140 "supported by nvgpu.mma.sync"; 141 } 142 143 // 144 // Basic verification 145 // 146 147 // verify warp-wide size for vector a 148 if (aShape[0] * aShape[1] * kThreads != m * k) 149 return emitOpError() << "expected " << m * k 150 << " warp-wide matrix A elements"; 151 152 // verify warp-wide size for vector b 153 if (bShape[0] * bShape[1] * kThreads != k * n) 154 return emitOpError() << "expected " << k * n 155 << " warp-wide matrix B elements"; 156 157 // verify warp-wide size for vector c 158 if (cShape[0] * cShape[1] * kThreads != m * n) 159 return emitOpError() << "expected " << m * n 160 << " warp-wide matrix C elements"; 161 162 // 163 // Extended verification 164 // 165 166 // tiles of fundamental tensor core operations 167 int64_t mTile = m / shapeM; 168 int64_t nTile = n / shapeN; 169 int64_t kTile = k / shapeK; 170 171 // verify shape of aVector 172 if (!((aShape[0] == mTile * kTile) && (aShape[1] == numElementA))) 173 return emitOpError() << "expected matrix A to be shaped (" << mTile * kTile 174 << " x " << numElementA << ")"; 175 176 // verify shape of bVector 177 if (!((bShape[0] == kTile * nTile) && (bShape[1] == numElementB))) 178 return emitOpError() << "expected matrix B to be shaped (" << kTile * nTile 179 << " x " << numElementB << ")"; 180 181 // verify shape of cVector 182 if (!((cShape[0] == mTile * nTile) && (cShape[1] == numElementC))) 183 return emitOpError() << "expected matrix C to be shaped (" << mTile * nTile 184 << " x " << numElementC << ")"; 185 186 return success(); 187 } 188 189 #define GET_OP_CLASSES 190 #include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc" 191