//===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
#include "../PassDetail.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/NVGPU/NVGPUDialect.h"

using namespace mlir;

/// Returns the type for the intrinsic given the vectorResultType of the
/// `nvgpu.mma.sync` operation.
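/// For example (mappings derived from the cases handled below), an input of
/// `!llvm.array<2 x vector<2xf16>>` maps to
/// `!llvm.struct<(vector<2xf16>, vector<2xf16>)>`, while
/// `!llvm.array<2 x vector<2xf32>>` maps to `!llvm.struct<(f32, f32, f32, f32)>`
/// because the intrinsic returns f32 accumulator rows as individual scalars.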
static Type inferIntrinsicResultType(Type vectorResultType) {
  MLIRContext *ctx = vectorResultType.getContext();
  auto a = vectorResultType.cast<LLVM::LLVMArrayType>();
  auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
  auto i32Ty = IntegerType::get(ctx, 32);
  auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
  Type f64Ty = Float64Type::get(ctx);
  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
  Type f32Ty = Float32Type::get(ctx);
  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
  if (a.getElementType() == f16x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
  }
  if (a.getElementType() == i32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx,
        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
  }
  if (a.getElementType() == f64x2Ty) {
    return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
  }
  if (a.getElementType() == f32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx,
        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
  }
  if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
  }
  return vectorResultType;
}

/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
/// always an LLVM struct) into a fragment that is compatible with the vector
/// type of this operation. This involves extracting elements from the struct
/// and inserting them into an LLVM array. These extra data-movement
/// operations should be canonicalized away by the LLVM backend.
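/// For example, when `resultType` is `!llvm.array<2 x vector<2xf32>>` and the
/// intrinsic returned `!llvm.struct<(f32, f32, f32, f32)>`, the four scalars
/// are extracted pairwise, packed into `vector<2xf32>` values, and inserted
/// into the two rows of the result array.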
static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
                                    Type resultType, Value intrinsicResult,
                                    RewriterBase &rewriter) {
  MLIRContext *ctx = rewriter.getContext();
  auto structType = intrinsicResultType.dyn_cast<LLVM::LLVMStructType>();
  auto arrayType = resultType.dyn_cast<LLVM::LLVMArrayType>();
  Type i32Ty = rewriter.getI32Type();
  Type f32Ty = rewriter.getF32Type();
  Type f64Ty = rewriter.getF64Type();
  Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
  Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);

  auto makeConst = [&](int32_t index) -> Value {
    return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
                                             rewriter.getI32IntegerAttr(index));
  };

  if (arrayType) {
    SmallVector<Value, 4> elements;

    // The intrinsic returns 32-bit wide elements in a form which can be
    // directly bitcasted and inserted into the result vector.
    if (arrayType.getElementType() == f16x2Ty ||
        arrayType.getElementType() == f32x1Ty) {
      for (unsigned i = 0; i < structType.getBody().size(); i++) {
        Value el = rewriter.create<LLVM::ExtractValueOp>(
            loc, structType.getBody()[i], intrinsicResult,
            rewriter.getI64ArrayAttr(i));
        el = rewriter.createOrFold<LLVM::BitcastOp>(
            loc, arrayType.getElementType(), el);
        elements.push_back(el);
      }
    }

    // The intrinsic returns i32, f64, and f32 values as individual scalars,
    // even when the result is notionally a 64-bit wide element (e.g. f32x2). We
    // need to extract them from the struct and pack them into the 64-bit wide
    // rows of the vector result.
    if (arrayType.getElementType() == i32x2Ty ||
        arrayType.getElementType() == f64x2Ty ||
        arrayType.getElementType() == f32x2Ty) {

      for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
        Value vec =
            rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
        Value x1 = rewriter.create<LLVM::ExtractValueOp>(
            loc, structType.getBody()[i * 2], intrinsicResult,
            rewriter.getI64ArrayAttr(i * 2));
        Value x2 = rewriter.create<LLVM::ExtractValueOp>(
            loc, structType.getBody()[i * 2 + 1], intrinsicResult,
            rewriter.getI64ArrayAttr(i * 2 + 1));
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x1, makeConst(0));
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x2, makeConst(1));
        elements.push_back(vec);
      }
    }

    // Create the final vectorized result.
    Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
    for (const auto &el : llvm::enumerate(elements)) {
      result = rewriter.create<LLVM::InsertValueOp>(
          loc, arrayType, result, el.value(),
          rewriter.getI64ArrayAttr(el.index()));
    }
    return result;
  }

  return intrinsicResult;
}

/// The `nvgpu.mma.sync` converter below expects matrix fragment operands to be
/// given as 2D `vector`s where the rows are 32b or 64b wide. The
/// `nvvm.mma.sync` op expects these arguments to be given as a flat list of
/// scalars of certain types. This function helps unpack the `vector` arguments
/// and cast them to the types expected by `nvvm.mma.sync`.
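/// For example, an operand of type `!llvm.array<2 x vector<4xi8>>` is unpacked
/// into two i32 values (each `vector<4xi8>` row is bitcast to an i32), while
/// an operand of type `!llvm.array<2 x vector<2xf64>>` is unpacked into four
/// individual f64 scalars.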
static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
                                              Location loc, Value operand,
                                              NVVM::MMATypes operandPtxType) {
  SmallVector<Value> result;
  Type i32Ty = rewriter.getI32Type();
  Type f64Ty = rewriter.getF64Type();
  Type f32Ty = rewriter.getF32Type();
  Type i8Ty = rewriter.getI8Type();
  Type i4Ty = rewriter.getIntegerType(4);
  Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
  Type i4x8Ty = LLVM::getFixedVectorType(i4Ty, 8);
  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
  auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>();

  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
    Value toUse = rewriter.create<LLVM::ExtractValueOp>(
        loc, arrayTy.getElementType(), operand, rewriter.getI64ArrayAttr(i));

    // 4xi8 vectors, 8xi4 vectors, and single-element tf32 (f32x1) rows are all
    // provided to the intrinsic as i32 scalars, so bitcast them to i32.
    if (arrayTy.getElementType() == i8x4Ty ||
        arrayTy.getElementType() == i4x8Ty ||
        (arrayTy.getElementType() == f32x1Ty &&
         operandPtxType == NVVM::MMATypes::tf32)) {
      result.push_back(
          rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI32Type(), toUse));
      continue;
    }

    // For some element types (i32, f32, f64), we need to unpack the inner
    // vector/array type as well because the intrinsic expects individual
    // scalars to be provided.
    VectorType innerArrayTy = arrayTy.getElementType().dyn_cast<VectorType>();
    if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
                         innerArrayTy.getElementType() == f64Ty ||
                         innerArrayTy.getElementType() == f32Ty)) {
      for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
           idx < innerSize; idx++) {
        result.push_back(rewriter.create<LLVM::ExtractElementOp>(
            loc, toUse,
            rewriter.create<LLVM::ConstantOp>(
                loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx))));
      }
      continue;
    }
    result.push_back(toUse);
  }
  return result;
}

namespace {

struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
  using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = getContext();
    Location loc = op->getLoc();

    // The result type of ldmatrix will always be a struct of 32-bit integer
    // registers if more than one 32-bit value is returned. Otherwise, the
    // result is a single i32. The result type of the GPU operation is always a
    // vector of shape (NumRegisters, VectorRegister), where VectorRegister is
    // the vector type of the result and is always 32 bits wide. We bitcast the
    // result of the NVVM::LdMatrixOp to this vector type.
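    // For example, a `vector<4x2xf16>` result (four 32-bit registers) yields
    // an `!llvm.struct<(i32, i32, i32, i32)>` from ldmatrix; each member is
    // then bitcast back to `vector<2xf16>` and inserted into the final
    // `!llvm.array<4 x vector<2xf16>>`.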
    auto vectorResultType = op->getResultTypes()[0].dyn_cast<VectorType>();
    if (!vectorResultType) {
      return failure();
    }
    Type innerVectorType = LLVM::getFixedVectorType(
        vectorResultType.getElementType(), vectorResultType.getDimSize(1));

    int64_t num32BitRegs = vectorResultType.getDimSize(0);

    Type ldMatrixResultType;
    if (num32BitRegs > 1) {
      ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
          ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
    } else {
      ldMatrixResultType = rewriter.getI32Type();
    }

    auto srcMemrefType = op.srcMemref().getType().cast<MemRefType>();
    Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.srcMemref(),
                                        adaptor.indices(), rewriter);
    Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>(
        loc, ldMatrixResultType, srcPtr,
        /*num=*/op.numTiles(),
        /*layout=*/op.transpose() ? NVVM::MMALayout::col
                                  : NVVM::MMALayout::row);

    // The ldmatrix operation returns either a single i32 value or a struct of
    // i32 values. Here we unpack those values and cast them back to their
    // actual vector type (still of width 32b) and repack them into a result
    // struct.
    Type finalResultType = typeConverter->convertType(vectorResultType);
    Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType);
    for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
      Value i32Register = num32BitRegs > 1
                              ? rewriter.create<LLVM::ExtractValueOp>(
                                    loc, rewriter.getI32Type(), ldMatrixResult,
                                    rewriter.getI64ArrayAttr(i))
                              : ldMatrixResult;
      Value casted =
          rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register);
      result = rewriter.create<LLVM::InsertValueOp>(
          loc, finalResultType, result, casted, rewriter.getI64ArrayAttr(i));
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Get the shape of the MMA operation being performed. The shape, together
    // with the operand element types, selects which intrinsic this op will be
    // lowered to.
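    // For example, an f16 operation with mmaShape = [16, 8, 16] maps to the
    // m16n8k16 flavor of `nvvm.mma.sync`.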
    auto aType = op.matrixA().getType().cast<VectorType>();
    auto cType = op.matrixC().getType().cast<VectorType>();

    int64_t m = op.mmaShape()[0].cast<IntegerAttr>().getInt();
    int64_t n = op.mmaShape()[1].cast<IntegerAttr>().getInt();
    int64_t k = op.mmaShape()[2].cast<IntegerAttr>().getInt();
    std::array<int64_t, 3> gemmShape{m, n, k};

    NVVM::MMATypes ptxTypeA;
    NVVM::MMATypes ptxTypeB;
    Optional<NVVM::MMATypes> ptxTypeC = NVVM::MmaOp::inferOperandMMAType(
        cType.getElementType(), /*isAccumulator=*/true);
    if (!ptxTypeC) {
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");
    }

    Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
    if (aType.getElementType().isInteger(8)) {
      ptxTypeA = NVVM::MMATypes::s8;
      ptxTypeB = NVVM::MMATypes::s8;
      overflow = NVVM::MMAIntOverflow::satfinite;
    } else if (aType.getElementType().isInteger(4)) {
      ptxTypeA = NVVM::MMATypes::s4;
      ptxTypeB = NVVM::MMATypes::s4;
      overflow = NVVM::MMAIntOverflow::satfinite;
    } else if (aType.getElementType().isF16()) {
      ptxTypeA = NVVM::MMATypes::f16;
      ptxTypeB = NVVM::MMATypes::f16;
    } else if (aType.getElementType().isF64()) {
      ptxTypeA = NVVM::MMATypes::f64;
      ptxTypeB = NVVM::MMATypes::f64;
    } else if (aType.getElementType().isF32()) {
      ptxTypeA = NVVM::MMATypes::tf32;
      ptxTypeB = NVVM::MMATypes::tf32;
    } else {
      return op->emitError("could not deduce operand PTX types");
    }

    SmallVector<Value> matA =
        unpackOperandVector(rewriter, loc, adaptor.matrixA(), ptxTypeA);
    SmallVector<Value> matB =
        unpackOperandVector(rewriter, loc, adaptor.matrixB(), ptxTypeB);
    SmallVector<Value> matC =
        unpackOperandVector(rewriter, loc, adaptor.matrixC(), *ptxTypeC);

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));
    Value intrinsicResult = rewriter.create<NVVM::MmaOp>(
        op.getLoc(), intrinsicResTy, matA, matB, matC,
        /*shape=*/gemmShape,
        /*b1Op=*/llvm::None,
        /*intOverflow=*/overflow,
        /*multiplicandPtxTypes=*/
        std::array<NVVM::MMATypes, 2>{ptxTypeA, ptxTypeB},
        /*multiplicandLayouts=*/
        std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
                                       NVVM::MMALayout::col});
    rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
                                                  desiredRetTy, intrinsicResult,
                                                  rewriter));
    return success();
  }
};

struct ConvertNVGPUToNVVMPass
    : public ConvertNVGPUToNVVMBase<ConvertNVGPUToNVVMPass> {
  ConvertNVGPUToNVVMPass() = default;

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    LLVMTypeConverter converter(&getContext());
    /// Device-side async tokens cannot be materialized in NVVM. We just
    /// convert them to a dummy i32 type in order to easily drop them during
    /// conversion.
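    /// (The lowering patterns below materialize each dropped token as an i32
    /// constant 0, which is expected to be dead and removed later.)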
    converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
      return converter.convertType(IntegerType::get(type.getContext(), 32));
    });
    populateNVGPUToNVVMConversionPatterns(converter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

struct NVGPUAsyncCopyLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto dstMemrefType = op.dst().getType().cast<MemRefType>();
    Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.dst(),
                                        adaptor.dstIndices(), rewriter);
    auto i8Ty = IntegerType::get(op.getContext(), 8);
    auto dstPointerType =
        LLVM::LLVMPointerType::get(i8Ty, dstMemrefType.getMemorySpaceAsInt());
    dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);

    auto srcMemrefType = op.src().getType().cast<MemRefType>();

    Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.src(),
                                        adaptor.srcIndices(), rewriter);
    auto srcPointerType =
        LLVM::LLVMPointerType::get(i8Ty, srcMemrefType.getMemorySpaceAsInt());
    srcPtr = rewriter.create<LLVM::BitcastOp>(loc, srcPointerType, srcPtr);
    // The intrinsic takes a global pointer, so we need an address space cast.
    auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
        i8Ty, NVVM::NVVMMemorySpace::kGlobalMemorySpace);
    srcPtr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, srcPointerGlobalType,
                                                    srcPtr);
    int64_t numElements = adaptor.numElements().getZExtValue();
    int64_t sizeInBytes =
        (dstMemrefType.getElementTypeBitWidth() * numElements) / 8;
    // Bypassing L1 is only supported for copies of 16 bytes; drop the hint
    // otherwise.
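    // For example, copying 4 f32 elements (or 8 f16 elements) is a 16-byte
    // copy, so the hint is preserved; an 8-byte copy would silently lose it.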
    UnitAttr bypassL1 = sizeInBytes == 16 ? adaptor.bypassL1Attr() : UnitAttr();
    rewriter.create<NVVM::CpAsyncOp>(
        loc, dstPtr, srcPtr, rewriter.getI32IntegerAttr(sizeInBytes), bypassL1);

    // Drop the result token.
    Value zero = rewriter.create<LLVM::ConstantOp>(
        op->getLoc(), IntegerType::get(op.getContext(), 32),
        rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};

struct NVGPUAsyncCreateGroupLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
    // Drop the result token.
    Value zero = rewriter.create<LLVM::ConstantOp>(
        op->getLoc(), IntegerType::get(op.getContext(), 32),
        rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};

struct NVGPUAsyncWaitLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // If numGroups is not present, pick 0 as a conservative correct value:
    // waiting on group 0 waits until all outstanding copy groups complete.
    int32_t numGroups = adaptor.numGroups() ? *adaptor.numGroups() : 0;
    rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                 RewritePatternSet &patterns) {
  patterns.add<MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
               NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering>(
      converter);
}

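/// Factory for the conversion pass. An illustrative way to exercise this
/// conversion from the command line (assuming the tablegen'd pass
/// registration) is `mlir-opt --convert-nvgpu-to-nvvm`.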
std::unique_ptr<Pass> mlir::createConvertNVGPUToNVVMPass() {
  return std::make_unique<ConvertNVGPUToNVVMPass>();
}