1 //===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12 #include "mlir/Conversion/LLVMCommon/Pattern.h"
13 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
14 #include "mlir/Dialect/NVGPU/NVGPUDialect.h"
15 
16 using namespace mlir;
17 
18 /// Returns the type for the intrinsic given the vectorResultType of the
19 /// `gpu.mma.sync` operation.
20 static Type inferIntrinsicResultType(Type vectorResultType) {
21   MLIRContext *ctx = vectorResultType.getContext();
22   auto a = vectorResultType.cast<LLVM::LLVMArrayType>();
23   auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
24   auto i32Ty = IntegerType::get(ctx, 32);
25   auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
26   Type f64Ty = Float64Type::get(ctx);
27   Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
28   if (a.getElementType() == f16x2Ty) {
29     return LLVM::LLVMStructType::getLiteral(
30         ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
31   }
32   if (a.getElementType() == i32x2Ty) {
33     return LLVM::LLVMStructType::getLiteral(
34         ctx,
35         SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
36   }
37   if (a.getElementType() == f64x2Ty) {
38     return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
39   }
40   return vectorResultType;
41 }
42 
43 /// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
44 /// always an LLVM struct) into a fragment that is compatible with the vector
45 /// type of this operation. This involves extracting elements from the struct
46 /// and inserting them into an LLVM array. These extra data-movement
47 /// operations should be canonicalized away by the LLVM backend.
48 static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
49                                     Type resultType, Value intrinsicResult,
50                                     RewriterBase &rewriter) {
51   MLIRContext *ctx = rewriter.getContext();
52   auto structType = intrinsicResultType.dyn_cast<LLVM::LLVMStructType>();
53   auto arrayType = resultType.dyn_cast<LLVM::LLVMArrayType>();
54   Type i32Ty = rewriter.getI32Type();
55   Type f64Ty = rewriter.getF64Type();
56   Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
57   Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
58   Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
59 
60   auto makeConst = [&](int32_t index) -> Value {
61     return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
62                                              rewriter.getI32IntegerAttr(index));
63   };
64 
65   if (arrayType) {
66     SmallVector<Value, 4> elements;
67 
68     if (arrayType.getElementType() == f16x2Ty) {
69       for (unsigned i = 0; i < structType.getBody().size(); i++) {
70         elements.push_back(rewriter.create<LLVM::ExtractValueOp>(
71             loc, structType.getBody()[i], intrinsicResult,
72             rewriter.getI64ArrayAttr(i)));
73       }
74     }
75 
76     // The intrinsic returns i32 and f64 values as individual scalars. We need
77     // to extract them from the struct and pack them into vectors.
78     if (arrayType.getElementType() == i32x2Ty ||
79         arrayType.getElementType() == f64x2Ty) {
80       Value vec =
81           rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
82       for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
83         Value x1 = rewriter.create<LLVM::ExtractValueOp>(
84             loc, structType.getBody()[i * 2], intrinsicResult,
85             rewriter.getI64ArrayAttr(i * 2));
86         Value x2 = rewriter.create<LLVM::ExtractValueOp>(
87             loc, structType.getBody()[i * 2 + 1], intrinsicResult,
88             rewriter.getI64ArrayAttr(i * 2 + 1));
89         vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
90                                                      x1, makeConst(0));
91         vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
92                                                      x2, makeConst(1));
93       }
94       elements.push_back(vec);
95     }
96 
97     // Create the final vectorized result.
98     Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
99     for (const auto &el : llvm::enumerate(elements)) {
100       result = rewriter.create<LLVM::InsertValueOp>(
101           loc, arrayType, result, el.value(),
102           rewriter.getI64ArrayAttr(el.index()));
103     }
104     return result;
105   }
106 
107   return intrinsicResult;
108 }
109 
110 /// The `gpu.mma.sync` converter below expects matrix fragment operands to be
111 /// given as 2D `vectors` where the rows are 32b or 64b wide. The
112 /// `nvvm.mma.sync` op expects these argments to be a given in a long list of
113 /// scalars of certain types. This function helps unpack the `vector` arguments
114 /// and cast them to the types expected by `nvvm.mma.sync`.
115 static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
116                                               Location loc, Value operand) {
117   SmallVector<Value> result;
118   Type i32Ty = rewriter.getI32Type();
119   Type f64Ty = rewriter.getF64Type();
120   Type i8Ty = rewriter.getI8Type();
121   Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
122   auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>();
123 
124   for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
125     Value toUse = rewriter.create<LLVM::ExtractValueOp>(
126         loc, arrayTy.getElementType(), operand, rewriter.getI64ArrayAttr(i));
127 
128     // For 4xi8 vectors, the intrinsic expects these to be provided as i32
129     // scalar types.
130     if (arrayTy.getElementType() == i8x4Ty) {
131       result.push_back(
132           rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI32Type(), toUse));
133       continue;
134     }
135 
136     // For some element types (i32, f64), we need to unpack the inner
137     // vector/array type as well because the intrinsic expects individual
138     // scalars to be provided.
139     VectorType innerArrayTy = arrayTy.getElementType().dyn_cast<VectorType>();
140     if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
141                          innerArrayTy.getElementType() == f64Ty)) {
142       for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
143            idx < innerSize; idx++) {
144         result.push_back(rewriter.create<LLVM::ExtractElementOp>(
145             loc, toUse,
146             rewriter.create<LLVM::ConstantOp>(
147                 loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx))));
148       }
149       continue;
150     }
151     result.push_back(toUse);
152   }
153   return result;
154 }
155 
156 namespace {
157 
158 struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
159   using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;
160 
161   LogicalResult
162   matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
163                   ConversionPatternRewriter &rewriter) const override {
164     MLIRContext *ctx = getContext();
165     Location loc = op->getLoc();
166 
167     // The result type of ldmatrix will always be a struct of 32bit integer
168     // registers if more than one 32bit value is returned. Otherwise, the result
169     // is a single i32. The result type of the GPU operation is always a vector
170     // of shape (NumRegisters, VectorRegister) where VectorRegister is the
171     // vector type of the result and always 32 bits long. We bitcast the result
172     // of the NVVM::LdMatrix to this vector type.
173     auto vectorResultType = op->getResultTypes()[0].dyn_cast<VectorType>();
174     if (!vectorResultType) {
175       return failure();
176     }
177     Type innerVectorType = LLVM::getFixedVectorType(
178         vectorResultType.getElementType(), vectorResultType.getDimSize(1));
179 
180     int64_t num32BitRegs = vectorResultType.getDimSize(0);
181 
182     Type ldMatrixResultType;
183     if (num32BitRegs > 1) {
184       ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
185           ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
186     } else {
187       ldMatrixResultType = rewriter.getI32Type();
188     }
189 
190     auto srcMemrefType = op.srcMemref().getType().cast<MemRefType>();
191     Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.srcMemref(),
192                                         adaptor.indices(), rewriter);
193     Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>(
194         loc, ldMatrixResultType, srcPtr,
195         /*num=*/op.numTiles(),
196         /*layout=*/op.transpose() ? NVVM::MMALayout::col
197                                   : NVVM::MMALayout::row);
198 
199     // The ldmatrix operation returns either a single i32 value or a struct of
200     // i32 values. Here we unpack those values and cast them back to their
201     // actual vector type (still of width 32b) and repack them into a result
202     // struct.
203     Type finalResultType = typeConverter->convertType(vectorResultType);
204     Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType);
205     for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
206       Value i32Register = num32BitRegs > 1
207                               ? rewriter.create<LLVM::ExtractValueOp>(
208                                     loc, rewriter.getI32Type(), ldMatrixResult,
209                                     rewriter.getI64ArrayAttr(i))
210                               : ldMatrixResult;
211       Value casted =
212           rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register);
213       result = rewriter.create<LLVM::InsertValueOp>(
214           loc, finalResultType, result, casted, rewriter.getI64ArrayAttr(i));
215     }
216 
217     rewriter.replaceOp(op, result);
218     return success();
219   }
220 };
221 
222 struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
223   using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;
224 
225   LogicalResult
226   matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
227                   ConversionPatternRewriter &rewriter) const override {
228     Location loc = op->getLoc();
229     // Get the shapes of the MMAMatrix type being used. The shapes will
230     // choose which intrinsic this op will be lowered to.
231     auto aType = op.matrixA().getType().cast<VectorType>();
232 
233     int64_t m = op.mmaShape()[0].cast<IntegerAttr>().getInt();
234     int64_t n = op.mmaShape()[1].cast<IntegerAttr>().getInt();
235     int64_t k = op.mmaShape()[2].cast<IntegerAttr>().getInt();
236     std::array<int64_t, 3> gemmShape{m, n, k};
237 
238     SmallVector<Value> matA =
239         unpackOperandVector(rewriter, loc, adaptor.matrixA());
240     SmallVector<Value> matB =
241         unpackOperandVector(rewriter, loc, adaptor.matrixB());
242     SmallVector<Value> matC =
243         unpackOperandVector(rewriter, loc, adaptor.matrixC());
244 
245     NVVM::MMATypes ptxTypeA;
246     NVVM::MMATypes ptxTypeB;
247     Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
248     if (aType.getElementType().isInteger(8)) {
249       ptxTypeA = NVVM::MMATypes::s8;
250       ptxTypeB = NVVM::MMATypes::s8;
251       overflow = NVVM::MMAIntOverflow::satfinite;
252 
253     } else if (aType.getElementType().isF16()) {
254       ptxTypeA = NVVM::MMATypes::f16;
255       ptxTypeB = NVVM::MMATypes::f16;
256     } else if (aType.getElementType().isF64()) {
257       ptxTypeA = NVVM::MMATypes::f64;
258       ptxTypeB = NVVM::MMATypes::f64;
259     } else {
260       return op->emitError("could not deduce operand PTX types");
261     }
262 
263     Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
264     Type intrinsicResTy = inferIntrinsicResultType(
265         typeConverter->convertType(op->getResultTypes()[0]));
266     Value intrinsicResult = rewriter.create<NVVM::MmaOp>(
267         op.getLoc(), intrinsicResTy, matA, matB, matC,
268         /*shape=*/gemmShape,
269         /*b1Op=*/llvm::None,
270         /*intOverflow=*/overflow,
271         /*multiplicandPtxTypes=*/
272         std::array<NVVM::MMATypes, 2>{ptxTypeA, ptxTypeB},
273         /*multiplicandLayouts=*/
274         std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
275                                        NVVM::MMALayout::col});
276     rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
277                                                   desiredRetTy, intrinsicResult,
278                                                   rewriter));
279     return success();
280   }
281 };
282 
283 struct ConvertNVGPUToNVVMPass
284     : public ConvertNVGPUToNVVMBase<ConvertNVGPUToNVVMPass> {
285   ConvertNVGPUToNVVMPass() = default;
286 
287   void runOnOperation() override {
288     RewritePatternSet patterns(&getContext());
289     LLVMTypeConverter converter(&getContext());
290     populateNVGPUToNVVMConversionPatterns(converter, patterns);
291     LLVMConversionTarget target(getContext());
292     target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
293     target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
294     if (failed(applyPartialConversion(getOperation(), target,
295                                       std::move(patterns))))
296       signalPassFailure();
297   }
298 };
299 
300 } // namespace
301 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
302                                                  RewritePatternSet &patterns) {
303   patterns.add<MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM>(converter);
304 }
305 
306 std::unique_ptr<Pass> mlir::createConvertNVGPUToNVVMPass() {
307   return std::make_unique<ConvertNVGPUToNVVMPass>();
308 }
309