//===------ WmmaOpsToNVVM.cpp - WMMA LD/ST/Compute to NVVM lowering -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
// NVVM Dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/TypeUtilities.h"

using namespace mlir;

namespace {

/// Checks if all the operands of the op being lowered are of LLVM types. The
/// types are expected to be converted by the `LLVMTypeConverter` before the op
/// is actually lowered. If the type of an operand is not already converted,
/// that hints at a missing type conversion, and failure is returned in that
/// case.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                     ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      })) {
    return rewriter.notifyMatchFailure(
        op, "cannot convert if operands aren't of LLVM type.");
  }

  return success();
}

/// Error string to emit when an unimplemented WMMA variant is encountered.
static constexpr StringRef kInvalidCaseStr = "Unsupported WMMA variant.";

static NVVM::MMAFrag convertOperand(StringRef operandName) {
  if (operandName.equals("AOp"))
    return NVVM::MMAFrag::a;
  if (operandName.equals("BOp"))
    return NVVM::MMAFrag::b;
  if (operandName.equals("COp"))
    return NVVM::MMAFrag::c;
  llvm_unreachable("Unknown operand name");
}

static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
  if (type.getElementType().isF16())
    return NVVM::MMATypes::f16;
  if (type.getElementType().isF32())
    return type.getOperand().equals("COp") ? NVVM::MMATypes::f32
                                           : NVVM::MMATypes::tf32;
  llvm_unreachable("Unsupported type");
}

/// This class implements the conversion of GPU MMA loadOp to wmma.load op
/// in the NVVM dialect. The loaded matrix fragments are returned as an LLVM
/// struct value, matching the result of the NVVM intrinsic.
struct WmmaLoadOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaLoadMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaLoadMatrixOp.getOperation();
    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
      return failure();

    // Get the shape of the MMAMatrix type being returned. The shape
    // determines which intrinsic this op will be lowered to.
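    // For example, a 16x16 f16 "AOp" operand gives m = 16 and k = 16; the
    // missing n dimension is inferred as 16, which selects the m16n16k16
    // wmma.load variant.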
    gpu::MMAMatrixType retType =
        subgroupMmaLoadMatrixOp.res().getType().cast<gpu::MMAMatrixType>();
    ArrayRef<int64_t> retTypeShape = retType.getShape();
    int64_t m = 0;
    int64_t n = 0;
    int64_t k = 0;
    NVVM::MMATypes eltype = getElementType(retType);
    // NVVM intrinsics expect the full m x n x k shape; infer the missing
    // dimension from the set of valid intrinsics available.
    if (retType.getOperand().equals("AOp")) {
      m = retTypeShape[0];
      k = retTypeShape[1];
      n = NVVM::WMMALoadOp::inferNDimension(m, k, eltype);
    } else if (retType.getOperand().equals("BOp")) {
      k = retTypeShape[0];
      n = retTypeShape[1];
      m = NVVM::WMMALoadOp::inferMDimension(k, n, eltype);
    } else if (retType.getOperand().equals("COp")) {
      m = retTypeShape[0];
      n = retTypeShape[1];
      k = NVVM::WMMALoadOp::inferKDimension(m, n, eltype);
    }
    NVVM::MMALayout layout = NVVM::MMALayout::row;
    NVVM::MMAFrag frag = convertOperand(retType.getOperand());
    // Check that there is an existing intrinsic for the combination we need.
    if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

    Type resType = convertMMAToLLVMType(retType);
    Location loc = op->getLoc();

    // Create the nvvm.wmma.load op according to the operand types.
    Value dataPtr = getStridedElementPtr(
        loc, subgroupMmaLoadMatrixOp.srcMemref().getType().cast<MemRefType>(),
        adaptor.srcMemref(), adaptor.indices(), rewriter);

    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
        loc, rewriter.getI32Type(),
        subgroupMmaLoadMatrixOp.leadDimensionAttr());
    rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
        op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
    return success();
  }
};

/// This class implements the conversion of GPU MMA storeOp to wmma.store op
/// in the NVVM dialect. The conversion not only emits the NVVM op but also
/// emits the code necessary to unpack the data from the source operand and
/// convert it into the format needed by the NVVM op.
struct WmmaStoreOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaStoreMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaStoreMatrixOp.getOperation();
    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
      return failure();

    Location loc = op->getLoc();

    SmallVector<Value, 4> storeOpOperands;
    // Get the shape of the MMAMatrix type being stored. The shape determines
    // which intrinsic this op will be lowered to.
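    // Only the accumulator ("COp") fragment can be stored, so m and n are
    // read directly off the source shape and only k has to be inferred.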
    gpu::MMAMatrixType srcType =
        subgroupMmaStoreMatrixOp.src().getType().cast<gpu::MMAMatrixType>();
    ArrayRef<int64_t> srcTypeShape = srcType.getShape();
    NVVM::MMALayout layout = NVVM::MMALayout::row;
    NVVM::MMATypes eltype = getElementType(srcType);
    int64_t m = srcTypeShape[0];
    int64_t n = srcTypeShape[1];
    int64_t k = NVVM::WMMAStoreOp::inferKDimension(m, n, eltype);
    if (NVVM::WMMAStoreOp::getIntrinsicID(m, n, k, layout, eltype) == 0)
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

    auto matrixType = adaptor.src().getType().cast<LLVM::LLVMStructType>();
    for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
      Value toUse = rewriter.create<LLVM::ExtractValueOp>(
          loc, matrixType.getBody()[i], adaptor.src(),
          rewriter.getI32ArrayAttr(i));
      storeOpOperands.push_back(toUse);
    }

    Value dataPtr = getStridedElementPtr(
        loc, subgroupMmaStoreMatrixOp.dstMemref().getType().cast<MemRefType>(),
        adaptor.dstMemref(), adaptor.indices(), rewriter);
    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
        loc, rewriter.getI32Type(),
        subgroupMmaStoreMatrixOp.leadDimensionAttr());
    rewriter.replaceOpWithNewOp<NVVM::WMMAStoreOp>(
        op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim);
    return success();
  }
};

/// This class implements the conversion of GPU MMA computeOp to wmma.mma op
/// in the NVVM dialect.
struct WmmaMmaOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaComputeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaComputeOp.getOperation();
    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
      return failure();

    Location loc = op->getLoc();

    // The wmma.mma intrinsic in LLVM requires its operands as individual
    // values, so each element of the lowered struct operands needs to be
    // extracted and passed on to the intrinsic call. Emit LLVM ops to
    // extract the individual values from the lowered operands.
    SmallVector<Value> unpackedOps;

    auto unpackOp = [&](Value operand) {
      auto structType = operand.getType().cast<LLVM::LLVMStructType>();
      for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
        Value toUse = rewriter.create<LLVM::ExtractValueOp>(
            loc, structType.getBody()[i], operand,
            rewriter.getI32ArrayAttr(i));
        unpackedOps.push_back(toUse);
      }
    };

    // Get the shapes of the MMAMatrix types being used. The shapes determine
    // which intrinsic this op will be lowered to.
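    // Following GEMM conventions, C is m x n and A is m x k, so m and n are
    // taken from the accumulator shape and k from the second dimension of A.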
    gpu::MMAMatrixType aType =
        subgroupMmaComputeOp.opA().getType().cast<gpu::MMAMatrixType>();
    ArrayRef<int64_t> aTypeShape = aType.getShape();
    gpu::MMAMatrixType cType =
        subgroupMmaComputeOp.opC().getType().cast<gpu::MMAMatrixType>();
    ArrayRef<int64_t> cTypeShape = cType.getShape();
    int64_t m = cTypeShape[0];
    int64_t n = cTypeShape[1];
    int64_t k = aTypeShape[1];
    NVVM::MMALayout layout = NVVM::MMALayout::row;
    NVVM::MMATypes sourceType = getElementType(aType);
    NVVM::MMATypes destType = getElementType(cType);
    if (NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, layout, layout, sourceType,
                                        destType) == 0)
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

    unpackOp(adaptor.opA());
    unpackOp(adaptor.opB());
    unpackOp(adaptor.opC());

    rewriter.replaceOpWithNewOp<NVVM::WMMAMmaOp>(
        op, adaptor.opC().getType(), m, n, k, layout, layout, sourceType,
        destType, unpackedOps);
    return success();
  }
};

/// Convert GPU MMA ConstantMatrixOp to a chain of InsertValueOp.
struct WmmaConstantOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaConstantMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaConstantMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(areAllLLVMTypes(subgroupMmaConstantOp.getOperation(),
                               adaptor.getOperands(), rewriter)))
      return failure();
    Location loc = subgroupMmaConstantOp.getLoc();
    Value cst = adaptor.getOperands()[0];
    LLVM::LLVMStructType type = convertMMAToLLVMType(
        subgroupMmaConstantOp.getType().cast<gpu::MMAMatrixType>());
    // If the element type is a vector, create a vector from the operand by
    // broadcasting the scalar into every lane.
    if (auto vecType = type.getBody()[0].dyn_cast<VectorType>()) {
      Value vecCst = rewriter.create<LLVM::UndefOp>(loc, vecType);
      for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
        Value idx = rewriter.create<LLVM::ConstantOp>(
            loc, typeConverter->convertType(rewriter.getIntegerType(32)),
            rewriter.getI32IntegerAttr(vecEl));
        vecCst = rewriter.create<LLVM::InsertElementOp>(loc, vecType, vecCst,
                                                        cst, idx);
      }
      cst = vecCst;
    }
    Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, type);
    for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
      matrixStruct = rewriter.create<LLVM::InsertValueOp>(
          loc, matrixStruct, cst, rewriter.getI32ArrayAttr(i));
    }
    rewriter.replaceOp(subgroupMmaConstantOp, matrixStruct);
    return success();
  }
};

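/// Creates an LLVM compare-and-select sequence implementing a NaN-propagating
/// minimum or maximum: if either operand is NaN, the result is a quiet NaN.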
static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
                           Value rhs, bool isMin) {
  auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
  Type i1Type = builder.getI1Type();
  if (auto vecType = lhs.getType().dyn_cast<VectorType>())
    i1Type = VectorType::get(vecType.getShape(), i1Type);
  Value cmp = builder.create<LLVM::FCmpOp>(
      loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
      lhs, rhs);
  Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
  Value isNan = builder.create<LLVM::FCmpOp>(
      loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
  Value nan = builder.create<LLVM::ConstantOp>(
      loc, lhs.getType(),
      builder.getFloatAttr(floatType,
                           APFloat::getQNaN(floatType.getFloatSemantics())));
  // Select the NaN constant when either input was NaN; otherwise take the
  // compare-and-select result.
  return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
}

static Value createScalarOp(OpBuilder &builder, Location loc,
                            gpu::MMAElementwiseOp op,
                            ArrayRef<Value> operands) {
  switch (op) {
  case gpu::MMAElementwiseOp::ADDF:
    return builder.create<LLVM::FAddOp>(loc, operands[0].getType(), operands);
  case gpu::MMAElementwiseOp::MULF:
    return builder.create<LLVM::FMulOp>(loc, operands[0].getType(), operands);
  case gpu::MMAElementwiseOp::DIVF:
    return builder.create<LLVM::FDivOp>(loc, operands[0].getType(), operands);
  case gpu::MMAElementwiseOp::MAXF:
    return createMinMaxF(builder, loc, operands[0], operands[1],
                         /*isMin=*/false);
  case gpu::MMAElementwiseOp::MINF:
    return createMinMaxF(builder, loc, operands[0], operands[1],
                         /*isMin=*/true);
  }
  llvm_unreachable("unknown op");
}

/// Convert GPU MMA elementwise ops to extract + op + insert.
struct WmmaElementwiseOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaElementwiseOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaElementwiseOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(areAllLLVMTypes(subgroupMmaElementwiseOp.getOperation(),
                               adaptor.getOperands(), rewriter)))
      return failure();
    Location loc = subgroupMmaElementwiseOp.getLoc();
    size_t numOperands = adaptor.getOperands().size();
    LLVM::LLVMStructType destType = convertMMAToLLVMType(
        subgroupMmaElementwiseOp.getType().cast<gpu::MMAMatrixType>());
    Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, destType);
    for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
      SmallVector<Value> extractedOperands;
      for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
        Type elementType = adaptor.getOperands()[opIdx]
                               .getType()
                               .cast<LLVM::LLVMStructType>()
                               .getBody()[i];
        extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
            loc, elementType, adaptor.getOperands()[opIdx],
            rewriter.getI32ArrayAttr(i)));
      }
      Value element =
          createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.operation(),
                         extractedOperands);
      matrixStruct = rewriter.create<LLVM::InsertValueOp>(
          loc, matrixStruct, element, rewriter.getI32ArrayAttr(i));
    }
    rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct);
    return success();
  }
};

} // namespace

namespace mlir {

/// Return the LLVMStructType corresponding to the MMAMatrixType `type`.
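/// For example, a 16x16xf16 "COp" matrix is converted to a literal struct of
/// four vector<2xf16> values, one per fragment register held by each thread.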
LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type) {
  NVVM::MMAFrag frag = convertOperand(type.getOperand());
  NVVM::MMATypes eltType = getElementType(type);
  std::pair<Type, unsigned> typeInfo =
      inferMMAType(eltType, frag, type.getContext());
  return LLVM::LLVMStructType::getLiteral(
      type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
}

void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                             RewritePatternSet &patterns) {
  patterns.insert<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
                  WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
                  WmmaElementwiseOpToNVVMLowering>(converter);
}
} // namespace mlir