1 //===------ WmmaOpsToNVVM.cpp - WMMA LD/ST/Compute to NVVM lowering -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains definitions of patterns to lower GPU Subgroup MMA ops to
10 // NVVM Dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
15 #include "mlir/Conversion/LLVMCommon/Pattern.h"
16 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
19 #include "mlir/IR/TypeUtilities.h"
20 
21 using namespace mlir;
22 
23 namespace {
24 
25 /// Checks if all the operands of the op being lowered are of LLVM Types. The
26 /// types are expected to be converted by the `LLVMTypeConverter` before the op
27 /// is actually lowered. If the type of an operands is not already converted it
28 /// hints a missing typeConversion and failure is returned in that case.
areAllLLVMTypes(Operation * op,ValueRange operands,ConversionPatternRewriter & rewriter)29 static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
30                                      ConversionPatternRewriter &rewriter) {
31   if (!llvm::all_of(operands, [](Value value) {
32         return LLVM::isCompatibleType(value.getType());
33       })) {
34     return rewriter.notifyMatchFailure(
35         op, "cannot convert if operands aren't of LLVM type.");
36   }
37 
38   return success();
39 }
40 
/// Error string to emit when an unimplemented WMMA variant is encountered.
static constexpr StringRef kInvalidCaseStr = "Unsupported WMMA variant.";
43 
convertOperand(StringRef operandName)44 static NVVM::MMAFrag convertOperand(StringRef operandName) {
45   if (operandName.equals("AOp"))
46     return NVVM::MMAFrag::a;
47   if (operandName.equals("BOp"))
48     return NVVM::MMAFrag::b;
49   if (operandName.equals("COp"))
50     return NVVM::MMAFrag::c;
51   llvm_unreachable("Unknown operand name");
52 }
53 
getElementType(gpu::MMAMatrixType type)54 static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
55   if (type.getElementType().isF16())
56     return NVVM::MMATypes::f16;
57   if (type.getElementType().isF32())
58     return type.getOperand().equals("COp") ? NVVM::MMATypes::f32
59                                            : NVVM::MMATypes::tf32;
60   llvm_unreachable("Unsupported type");
61 }
62 
63 /// This class implements the conversion of GPU MMA loadOp to wmma.load op
64 /// in the NVVM dialect. The conversion not only emits the NVVM op but also
65 /// emits code that is necessary to store the data in the destination memref
66 /// after it has been loaded.
67 struct WmmaLoadOpToNVVMLowering
68     : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp> {
69   using ConvertOpToLLVMPattern<
70       gpu::SubgroupMmaLoadMatrixOp>::ConvertOpToLLVMPattern;
71 
72   LogicalResult
matchAndRewrite__anonc5aca7e10111::WmmaLoadOpToNVVMLowering73   matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
74                   OpAdaptor adaptor,
75                   ConversionPatternRewriter &rewriter) const override {
76     Operation *op = subgroupMmaLoadMatrixOp.getOperation();
77     if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
78       return failure();
79 
80     // Get the shape of the MMAMatrix type being returned. The shape will
81     // choose which intrinsic this op will be lowered to.
82     gpu::MMAMatrixType retType =
83         subgroupMmaLoadMatrixOp.res().getType().cast<gpu::MMAMatrixType>();
84     ArrayRef<int64_t> retTypeShape = retType.getShape();
85     int64_t m = 0;
86     int64_t n = 0;
87     int64_t k = 0;
88     NVVM::MMATypes eltype = getElementType(retType);
89     // NVVM intrinsics require to give mxnxk dimensions, infer the missing
90     // dimension based on the valid intrinsics available.
91     if (retType.getOperand().equals("AOp")) {
92       m = retTypeShape[0];
93       k = retTypeShape[1];
94       n = NVVM::WMMALoadOp::inferNDimension(m, k, eltype);
95     } else if (retType.getOperand().equals("BOp")) {
96       k = retTypeShape[0];
97       n = retTypeShape[1];
98       m = NVVM::WMMALoadOp::inferMDimension(k, n, eltype);
99     } else if (retType.getOperand().equals("COp")) {
100       m = retTypeShape[0];
101       n = retTypeShape[1];
102       k = NVVM::WMMALoadOp::inferKDimension(m, n, eltype);
103     }
104     NVVM::MMALayout layout = NVVM::MMALayout::row;
105     NVVM::MMAFrag frag = convertOperand(retType.getOperand());
106     // Check that there is an exisiting instruction for the combination we need.
107     if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
108       return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
109 
110     Type resType = convertMMAToLLVMType(retType);
111     Location loc = op->getLoc();
112 
113     // Create nvvm.mma_load op according to the operand types.
114     Value dataPtr = getStridedElementPtr(
115         loc, subgroupMmaLoadMatrixOp.srcMemref().getType().cast<MemRefType>(),
116         adaptor.srcMemref(), adaptor.indices(), rewriter);
117 
118     Value leadingDim = rewriter.create<LLVM::ConstantOp>(
119         loc, rewriter.getI32Type(),
120         subgroupMmaLoadMatrixOp.leadDimensionAttr());
121     rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
122         op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
123     return success();
124   }
125 };
126 
127 /// This class implements the conversion of GPU MMA storeOp to wmma.store op
128 /// in the NVVM dialect. The conversion not only emits the NVVM op but also
129 /// emits code that is necessary to unpack the data in the source and
130 /// convert the data in the format that is needed by the NVVM op.
131 struct WmmaStoreOpToNVVMLowering
132     : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp> {
133   using ConvertOpToLLVMPattern<
134       gpu::SubgroupMmaStoreMatrixOp>::ConvertOpToLLVMPattern;
135 
136   LogicalResult
matchAndRewrite__anonc5aca7e10111::WmmaStoreOpToNVVMLowering137   matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
138                   OpAdaptor adaptor,
139                   ConversionPatternRewriter &rewriter) const override {
140     Operation *op = subgroupMmaStoreMatrixOp.getOperation();
141     if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
142       return failure();
143 
144     Location loc = op->getLoc();
145 
146     SmallVector<Value, 4> storeOpOperands;
147     // Get the shape of the MMAMatrix type being stored. The shape will
148     // choose which intrinsic this op will be lowered to.
149     gpu::MMAMatrixType srcType =
150         subgroupMmaStoreMatrixOp.src().getType().cast<gpu::MMAMatrixType>();
151     ArrayRef<int64_t> srcTypeShape = srcType.getShape();
152     NVVM::MMALayout layout = NVVM::MMALayout::row;
153     NVVM::MMATypes eltype = getElementType(srcType);
154     int64_t m = srcTypeShape[0];
155     int64_t n = srcTypeShape[1];
156     int64_t k = NVVM::WMMAStoreOp::inferKDimension(m, n, eltype);
157     if (NVVM::WMMAStoreOp::getIntrinsicID(m, n, k, layout, eltype) == 0)
158       return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
159 
160     auto matrixType = adaptor.src().getType().cast<LLVM::LLVMStructType>();
161     for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
162       Value toUse = rewriter.create<LLVM::ExtractValueOp>(
163           loc, matrixType.getBody()[i], adaptor.src(),
164           rewriter.getI32ArrayAttr(i));
165       storeOpOperands.push_back(toUse);
166     }
167 
168     Value dataPtr = getStridedElementPtr(
169         loc, subgroupMmaStoreMatrixOp.dstMemref().getType().cast<MemRefType>(),
170         adaptor.dstMemref(), adaptor.indices(), rewriter);
171     Value leadingDim = rewriter.create<LLVM::ConstantOp>(
172         loc, rewriter.getI32Type(),
173         subgroupMmaStoreMatrixOp.leadDimensionAttr());
174     rewriter.replaceOpWithNewOp<NVVM::WMMAStoreOp>(
175         op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim);
176     return success();
177   }
178 };
179 
180 /// This class implements the conversion of GPU MMA computeOp to wmma.mma op
181 /// in the NVVM dialect.
182 struct WmmaMmaOpToNVVMLowering
183     : public ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp> {
184   using ConvertOpToLLVMPattern<
185       gpu::SubgroupMmaComputeOp>::ConvertOpToLLVMPattern;
186 
187   LogicalResult
matchAndRewrite__anonc5aca7e10111::WmmaMmaOpToNVVMLowering188   matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
189                   OpAdaptor adaptor,
190                   ConversionPatternRewriter &rewriter) const override {
191     Operation *op = subgroupMmaComputeOp.getOperation();
192     if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
193       return failure();
194 
195     Location loc = op->getLoc();
196 
197     // The wmma.mma intrinsic in llvm requires the operands as individual
198     // values. So individual elements from the memrefs need to be extracted and
199     // then passed on to the intrinsic call. Emit llvm ops to extract individual
200     // values form lowered memrefs.
201     SmallVector<Value> unpackedOps;
202 
203     auto unpackOp = [&](Value operand) {
204       auto structType = operand.getType().cast<LLVM::LLVMStructType>();
205       for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
206         Value toUse = rewriter.create<LLVM::ExtractValueOp>(
207             loc, structType.getBody()[i], operand, rewriter.getI32ArrayAttr(i));
208         unpackedOps.push_back(toUse);
209       }
210     };
211 
212     // Get the shapes of the MMAMatrix type being used. The shapes will
213     // choose which intrinsic this op will be lowered to.
214     gpu::MMAMatrixType aType =
215         subgroupMmaComputeOp.opA().getType().cast<gpu::MMAMatrixType>();
216     ArrayRef<int64_t> aTypeShape = aType.getShape();
217     gpu::MMAMatrixType cType =
218         subgroupMmaComputeOp.opC().getType().cast<gpu::MMAMatrixType>();
219     ArrayRef<int64_t> cTypeShape = cType.getShape();
220     int64_t m = cTypeShape[0];
221     int64_t n = cTypeShape[1];
222     int64_t k = aTypeShape[1];
223     NVVM::MMALayout layout = NVVM::MMALayout::row;
224     NVVM::MMATypes sourceType = getElementType(aType);
225     NVVM::MMATypes destType = getElementType(cType);
226     if (NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, layout, layout, sourceType,
227                                         destType) == 0)
228       return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
229 
230     unpackOp(adaptor.opA());
231     unpackOp(adaptor.opB());
232     unpackOp(adaptor.opC());
233 
234     rewriter.replaceOpWithNewOp<NVVM::WMMAMmaOp>(
235         op, adaptor.opC().getType(), m, n, k, layout, layout, sourceType,
236         destType, unpackedOps);
237     return success();
238   }
239 };
240 
241 /// Convert GPU MMA ConstantMatrixOp to a chain of InsertValueOp.
242 struct WmmaConstantOpToNVVMLowering
243     : public ConvertOpToLLVMPattern<gpu::SubgroupMmaConstantMatrixOp> {
244   using ConvertOpToLLVMPattern<
245       gpu::SubgroupMmaConstantMatrixOp>::ConvertOpToLLVMPattern;
246 
247   LogicalResult
matchAndRewrite__anonc5aca7e10111::WmmaConstantOpToNVVMLowering248   matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantOp,
249                   OpAdaptor adaptor,
250                   ConversionPatternRewriter &rewriter) const override {
251     if (failed(areAllLLVMTypes(subgroupMmaConstantOp.getOperation(),
252                                adaptor.getOperands(), rewriter)))
253       return failure();
254     Location loc = subgroupMmaConstantOp.getLoc();
255     Value cst = adaptor.getOperands()[0];
256     LLVM::LLVMStructType type = convertMMAToLLVMType(
257         subgroupMmaConstantOp.getType().cast<gpu::MMAMatrixType>());
258     // If the element type is a vector create a vector from the operand.
259     if (auto vecType = type.getBody()[0].dyn_cast<VectorType>()) {
260       Value vecCst = rewriter.create<LLVM::UndefOp>(loc, vecType);
261       for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
262         Value idx = rewriter.create<LLVM::ConstantOp>(
263             loc, typeConverter->convertType(rewriter.getIntegerType(32)),
264             rewriter.getI32IntegerAttr(vecEl));
265         vecCst = rewriter.create<LLVM::InsertElementOp>(loc, vecType, vecCst,
266                                                         cst, idx);
267       }
268       cst = vecCst;
269     }
270     Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, type);
271     for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
272       matrixStruct = rewriter.create<LLVM::InsertValueOp>(
273           loc, matrixStruct, cst, rewriter.getI32ArrayAttr(i));
274     }
275     rewriter.replaceOp(subgroupMmaConstantOp, matrixStruct);
276     return success();
277   }
278 };
279 
createMinMaxF(OpBuilder & builder,Location loc,Value lhs,Value rhs,bool isMin)280 static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
281                            Value rhs, bool isMin) {
282   auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
283   Type i1Type = builder.getI1Type();
284   if (auto vecType = lhs.getType().dyn_cast<VectorType>())
285     i1Type = VectorType::get(vecType.getShape(), i1Type);
286   Value cmp = builder.create<LLVM::FCmpOp>(
287       loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
288       lhs, rhs);
289   Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
290   Value isNan = builder.create<LLVM::FCmpOp>(
291       loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
292   Value nan = builder.create<LLVM::ConstantOp>(
293       loc, lhs.getType(),
294       builder.getFloatAttr(floatType,
295                            APFloat::getQNaN(floatType.getFloatSemantics())));
296   return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
297 }
298 
createScalarOp(OpBuilder & builder,Location loc,gpu::MMAElementwiseOp op,ArrayRef<Value> operands)299 static Value createScalarOp(OpBuilder &builder, Location loc,
300                             gpu::MMAElementwiseOp op,
301                             ArrayRef<Value> operands) {
302   switch (op) {
303   case gpu::MMAElementwiseOp::ADDF:
304     return builder.create<LLVM::FAddOp>(loc, operands[0].getType(), operands);
305   case gpu::MMAElementwiseOp::MULF:
306     return builder.create<LLVM::FMulOp>(loc, operands[0].getType(), operands);
307   case gpu::MMAElementwiseOp::DIVF:
308     return builder.create<LLVM::FDivOp>(loc, operands[0].getType(), operands);
309   case gpu::MMAElementwiseOp::MAXF:
310     return createMinMaxF(builder, loc, operands[0], operands[1],
311                          /*isMin=*/false);
312   case gpu::MMAElementwiseOp::MINF:
313     return createMinMaxF(builder, loc, operands[0], operands[1],
314                          /*isMin=*/true);
315   }
316   llvm_unreachable("unknown op");
317 }
318 
319 /// Convert GPU MMA elementwise ops to extract + op + insert.
320 struct WmmaElementwiseOpToNVVMLowering
321     : public ConvertOpToLLVMPattern<gpu::SubgroupMmaElementwiseOp> {
322   using ConvertOpToLLVMPattern<
323       gpu::SubgroupMmaElementwiseOp>::ConvertOpToLLVMPattern;
324 
325   LogicalResult
matchAndRewrite__anonc5aca7e10111::WmmaElementwiseOpToNVVMLowering326   matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
327                   OpAdaptor adaptor,
328                   ConversionPatternRewriter &rewriter) const override {
329     if (failed(areAllLLVMTypes(subgroupMmaElementwiseOp.getOperation(),
330                                adaptor.getOperands(), rewriter)))
331       return failure();
332     Location loc = subgroupMmaElementwiseOp.getLoc();
333     size_t numOperands = adaptor.getOperands().size();
334     LLVM::LLVMStructType destType = convertMMAToLLVMType(
335         subgroupMmaElementwiseOp.getType().cast<gpu::MMAMatrixType>());
336     Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, destType);
337     for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
338       SmallVector<Value> extractedOperands;
339       for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
340         Type elementType = adaptor.getOperands()[opIdx]
341                                .getType()
342                                .cast<LLVM::LLVMStructType>()
343                                .getBody()[i];
344         extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
345             loc, elementType, adaptor.getOperands()[opIdx],
346             rewriter.getI32ArrayAttr(i)));
347       }
348       Value element =
349           createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.operation(),
350                          extractedOperands);
351       matrixStruct = rewriter.create<LLVM::InsertValueOp>(
352           loc, matrixStruct, element, rewriter.getI32ArrayAttr(i));
353     }
354     rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct);
355     return success();
356   }
357 };
358 
359 } // namespace
360 
361 /// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
convertMMAToLLVMType(gpu::MMAMatrixType type)362 LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
363   NVVM::MMAFrag frag = convertOperand(type.getOperand());
364   NVVM::MMATypes eltType = getElementType(type);
365   std::pair<Type, unsigned> typeInfo =
366       NVVM::inferMMAType(eltType, frag, type.getContext());
367   return LLVM::LLVMStructType::getLiteral(
368       type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
369 }
370 
populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)371 void mlir::populateGpuWMMAToNVVMConversionPatterns(
372     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
373   patterns.add<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
374                WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
375                WmmaElementwiseOpToNVVMLowering>(converter);
376 }
377