1 //===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12 #include "mlir/Conversion/LLVMCommon/Pattern.h"
13 #include "mlir/Dialect/GPU/GPUDialect.h"
14 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
15 #include "mlir/Dialect/NVGPU/NVGPUDialect.h"
16 
17 using namespace mlir;
18 
19 /// Returns the type for the intrinsic given the vectorResultType of the
20 /// `gpu.mma.sync` operation.
21 static Type inferIntrinsicResultType(Type vectorResultType) {
22   MLIRContext *ctx = vectorResultType.getContext();
23   auto a = vectorResultType.cast<LLVM::LLVMArrayType>();
24   auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
25   auto i32Ty = IntegerType::get(ctx, 32);
26   auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
27   Type f64Ty = Float64Type::get(ctx);
28   Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
29   Type f32Ty = Float32Type::get(ctx);
30   Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
31   if (a.getElementType() == f16x2Ty) {
32     return LLVM::LLVMStructType::getLiteral(
33         ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
34   }
35   if (a.getElementType() == i32x2Ty) {
36     return LLVM::LLVMStructType::getLiteral(
37         ctx,
38         SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
39   }
40   if (a.getElementType() == f64x2Ty) {
41     return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
42   }
43   if (a.getElementType() == f32x2Ty) {
44     return LLVM::LLVMStructType::getLiteral(
45         ctx,
46         SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
47   }
48   if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
49     return LLVM::LLVMStructType::getLiteral(
50         ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
51   }
52   return vectorResultType;
53 }
54 
55 /// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
56 /// always an LLVM struct) into a fragment that is compatible with the vector
57 /// type of this operation. This involves extracting elements from the struct
58 /// and inserting them into an LLVM array. These extra data-movement
59 /// operations should be canonicalized away by the LLVM backend.
60 static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
61                                     Type resultType, Value intrinsicResult,
62                                     RewriterBase &rewriter) {
63   MLIRContext *ctx = rewriter.getContext();
64   auto structType = intrinsicResultType.dyn_cast<LLVM::LLVMStructType>();
65   auto arrayType = resultType.dyn_cast<LLVM::LLVMArrayType>();
66   Type i32Ty = rewriter.getI32Type();
67   Type f32Ty = rewriter.getF32Type();
68   Type f64Ty = rewriter.getF64Type();
69   Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
70   Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
71   Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
72   Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
73   Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
74 
75   auto makeConst = [&](int32_t index) -> Value {
76     return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
77                                              rewriter.getI32IntegerAttr(index));
78   };
79 
80   if (arrayType) {
81     SmallVector<Value, 4> elements;
82 
83     // The intrinsic returns 32-bit wide elements in a form which can be
84     // directly bitcasted and inserted into the result vector.
85     if (arrayType.getElementType() == f16x2Ty ||
86         arrayType.getElementType() == f32x1Ty) {
87       for (unsigned i = 0; i < structType.getBody().size(); i++) {
88         Value el = rewriter.create<LLVM::ExtractValueOp>(
89             loc, structType.getBody()[i], intrinsicResult,
90             rewriter.getI64ArrayAttr(i));
91         el = rewriter.createOrFold<LLVM::BitcastOp>(
92             loc, arrayType.getElementType(), el);
93         elements.push_back(el);
94       }
95     }
96 
97     // The intrinsic returns i32, f64, and f32 values as individual scalars,
98     // even when the result is notionally a 64-bit wide element (e.g. f32x2). We
99     // need to extract them from the struct and pack them into the 64-bit wide
100     // rows of the vector result.
101     if (arrayType.getElementType() == i32x2Ty ||
102         arrayType.getElementType() == f64x2Ty ||
103         arrayType.getElementType() == f32x2Ty) {
104 
105       for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
106         Value vec =
107             rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
108         Value x1 = rewriter.create<LLVM::ExtractValueOp>(
109             loc, structType.getBody()[i * 2], intrinsicResult,
110             rewriter.getI64ArrayAttr(i * 2));
111         Value x2 = rewriter.create<LLVM::ExtractValueOp>(
112             loc, structType.getBody()[i * 2 + 1], intrinsicResult,
113             rewriter.getI64ArrayAttr(i * 2 + 1));
114         vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
115                                                      x1, makeConst(0));
116         vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
117                                                      x2, makeConst(1));
118         elements.push_back(vec);
119       }
120     }
121 
122     // Create the final vectorized result.
123     Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
124     for (const auto &el : llvm::enumerate(elements)) {
125       result = rewriter.create<LLVM::InsertValueOp>(
126           loc, arrayType, result, el.value(),
127           rewriter.getI64ArrayAttr(el.index()));
128     }
129     return result;
130   }
131 
132   return intrinsicResult;
133 }
134 
135 /// The `gpu.mma.sync` converter below expects matrix fragment operands to be
136 /// given as 2D `vectors` where the rows are 32b or 64b wide. The
137 /// `nvvm.mma.sync` op expects these argments to be a given in a long list of
138 /// scalars of certain types. This function helps unpack the `vector` arguments
139 /// and cast them to the types expected by `nvvm.mma.sync`.
140 static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
141                                               Location loc, Value operand,
142                                               NVVM::MMATypes operandPtxType) {
143   SmallVector<Value> result;
144   Type i32Ty = rewriter.getI32Type();
145   Type f64Ty = rewriter.getF64Type();
146   Type f32Ty = rewriter.getF32Type();
147   Type i8Ty = rewriter.getI8Type();
148   Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
149   Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
150   auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>();
151 
152   for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
153     Value toUse = rewriter.create<LLVM::ExtractValueOp>(
154         loc, arrayTy.getElementType(), operand, rewriter.getI64ArrayAttr(i));
155 
156     // For 4xi8 vectors, the intrinsic expects these to be provided as i32
157     // scalar types.
158     if (arrayTy.getElementType() == i8x4Ty ||
159         (arrayTy.getElementType() == f32x1Ty &&
160          operandPtxType == NVVM::MMATypes::tf32)) {
161       result.push_back(
162           rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI32Type(), toUse));
163       continue;
164     }
165 
166     // For some element types (i32, f32, f64), we need to unpack the inner
167     // vector/array type as well because the intrinsic expects individual
168     // scalars to be provided.
169     VectorType innerArrayTy = arrayTy.getElementType().dyn_cast<VectorType>();
170     if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
171                          innerArrayTy.getElementType() == f64Ty ||
172                          innerArrayTy.getElementType() == f32Ty)) {
173       for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
174            idx < innerSize; idx++) {
175         result.push_back(rewriter.create<LLVM::ExtractElementOp>(
176             loc, toUse,
177             rewriter.create<LLVM::ConstantOp>(
178                 loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx))));
179       }
180       continue;
181     }
182     result.push_back(toUse);
183   }
184   return result;
185 }
186 
187 namespace {
188 
189 struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
190   using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;
191 
192   LogicalResult
193   matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
194                   ConversionPatternRewriter &rewriter) const override {
195     MLIRContext *ctx = getContext();
196     Location loc = op->getLoc();
197 
198     // The result type of ldmatrix will always be a struct of 32bit integer
199     // registers if more than one 32bit value is returned. Otherwise, the result
200     // is a single i32. The result type of the GPU operation is always a vector
201     // of shape (NumRegisters, VectorRegister) where VectorRegister is the
202     // vector type of the result and always 32 bits long. We bitcast the result
203     // of the NVVM::LdMatrix to this vector type.
204     auto vectorResultType = op->getResultTypes()[0].dyn_cast<VectorType>();
205     if (!vectorResultType) {
206       return failure();
207     }
208     Type innerVectorType = LLVM::getFixedVectorType(
209         vectorResultType.getElementType(), vectorResultType.getDimSize(1));
210 
211     int64_t num32BitRegs = vectorResultType.getDimSize(0);
212 
213     Type ldMatrixResultType;
214     if (num32BitRegs > 1) {
215       ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
216           ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
217     } else {
218       ldMatrixResultType = rewriter.getI32Type();
219     }
220 
221     auto srcMemrefType = op.srcMemref().getType().cast<MemRefType>();
222     Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.srcMemref(),
223                                         adaptor.indices(), rewriter);
224     Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>(
225         loc, ldMatrixResultType, srcPtr,
226         /*num=*/op.numTiles(),
227         /*layout=*/op.transpose() ? NVVM::MMALayout::col
228                                   : NVVM::MMALayout::row);
229 
230     // The ldmatrix operation returns either a single i32 value or a struct of
231     // i32 values. Here we unpack those values and cast them back to their
232     // actual vector type (still of width 32b) and repack them into a result
233     // struct.
234     Type finalResultType = typeConverter->convertType(vectorResultType);
235     Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType);
236     for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
237       Value i32Register = num32BitRegs > 1
238                               ? rewriter.create<LLVM::ExtractValueOp>(
239                                     loc, rewriter.getI32Type(), ldMatrixResult,
240                                     rewriter.getI64ArrayAttr(i))
241                               : ldMatrixResult;
242       Value casted =
243           rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register);
244       result = rewriter.create<LLVM::InsertValueOp>(
245           loc, finalResultType, result, casted, rewriter.getI64ArrayAttr(i));
246     }
247 
248     rewriter.replaceOp(op, result);
249     return success();
250   }
251 };
252 
253 struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
254   using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;
255 
256   LogicalResult
257   matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
258                   ConversionPatternRewriter &rewriter) const override {
259     Location loc = op->getLoc();
260     // Get the shapes of the MMAMatrix type being used. The shapes will
261     // choose which intrinsic this op will be lowered to.
262     auto aType = op.matrixA().getType().cast<VectorType>();
263     auto cType = op.matrixC().getType().cast<VectorType>();
264 
265     int64_t m = op.mmaShape()[0].cast<IntegerAttr>().getInt();
266     int64_t n = op.mmaShape()[1].cast<IntegerAttr>().getInt();
267     int64_t k = op.mmaShape()[2].cast<IntegerAttr>().getInt();
268     std::array<int64_t, 3> gemmShape{m, n, k};
269 
270     NVVM::MMATypes ptxTypeA;
271     NVVM::MMATypes ptxTypeB;
272     Optional<NVVM::MMATypes> ptxTypeC = NVVM::MmaOp::inferOperandMMAType(
273         cType.getElementType(), /*isAccumulator=*/true);
274     if (!ptxTypeC) {
275       return op->emitError(
276           "could not infer the PTX type for the accumulator/result");
277     }
278 
279     Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
280     if (aType.getElementType().isInteger(8)) {
281       ptxTypeA = NVVM::MMATypes::s8;
282       ptxTypeB = NVVM::MMATypes::s8;
283       overflow = NVVM::MMAIntOverflow::satfinite;
284     } else if (aType.getElementType().isF16()) {
285       ptxTypeA = NVVM::MMATypes::f16;
286       ptxTypeB = NVVM::MMATypes::f16;
287     } else if (aType.getElementType().isF64()) {
288       ptxTypeA = NVVM::MMATypes::f64;
289       ptxTypeB = NVVM::MMATypes::f64;
290     } else if (aType.getElementType().isF32()) {
291       ptxTypeA = NVVM::MMATypes::tf32;
292       ptxTypeB = NVVM::MMATypes::tf32;
293     } else {
294       return op->emitError("could not deduce operand PTX types");
295     }
296 
297     SmallVector<Value> matA =
298         unpackOperandVector(rewriter, loc, adaptor.matrixA(), ptxTypeA);
299     SmallVector<Value> matB =
300         unpackOperandVector(rewriter, loc, adaptor.matrixB(), ptxTypeB);
301     SmallVector<Value> matC =
302         unpackOperandVector(rewriter, loc, adaptor.matrixC(), *ptxTypeC);
303 
304     Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
305     Type intrinsicResTy = inferIntrinsicResultType(
306         typeConverter->convertType(op->getResultTypes()[0]));
307     Value intrinsicResult = rewriter.create<NVVM::MmaOp>(
308         op.getLoc(), intrinsicResTy, matA, matB, matC,
309         /*shape=*/gemmShape,
310         /*b1Op=*/llvm::None,
311         /*intOverflow=*/overflow,
312         /*multiplicandPtxTypes=*/
313         std::array<NVVM::MMATypes, 2>{ptxTypeA, ptxTypeB},
314         /*multiplicandLayouts=*/
315         std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
316                                        NVVM::MMALayout::col});
317     rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
318                                                   desiredRetTy, intrinsicResult,
319                                                   rewriter));
320     return success();
321   }
322 };
323 
324 struct ConvertNVGPUToNVVMPass
325     : public ConvertNVGPUToNVVMBase<ConvertNVGPUToNVVMPass> {
326   ConvertNVGPUToNVVMPass() = default;
327 
328   void runOnOperation() override {
329     RewritePatternSet patterns(&getContext());
330     LLVMTypeConverter converter(&getContext());
331     /// device-side async tokens cannot be materialized in nvvm. We just convert
332     /// them to a dummy i32 type in order to easily drop them during conversion.
333     converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
334       return converter.convertType(IntegerType::get(type.getContext(), 32));
335     });
336     populateNVGPUToNVVMConversionPatterns(converter, patterns);
337     LLVMConversionTarget target(getContext());
338     target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
339     target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
340     if (failed(applyPartialConversion(getOperation(), target,
341                                       std::move(patterns))))
342       signalPassFailure();
343   }
344 };
345 
346 struct NVGPUAsyncCopyLowering
347     : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
348   using ConvertOpToLLVMPattern<
349       nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;
350 
351   LogicalResult
352   matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
353                   ConversionPatternRewriter &rewriter) const override {
354     Location loc = op->getLoc();
355     auto dstMemrefType = op.dst().getType().cast<MemRefType>();
356     Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.dst(),
357                                         adaptor.dstIndices(), rewriter);
358     auto i8Ty = IntegerType::get(op.getContext(), 8);
359     auto dstPointerType =
360         LLVM::LLVMPointerType::get(i8Ty, dstMemrefType.getMemorySpaceAsInt());
361     dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);
362 
363     auto srcMemrefType = op.src().getType().cast<MemRefType>();
364 
365     Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.src(),
366                                         adaptor.srcIndices(), rewriter);
367     auto srcPointerType =
368         LLVM::LLVMPointerType::get(i8Ty, srcMemrefType.getMemorySpaceAsInt());
369     scrPtr = rewriter.create<LLVM::BitcastOp>(loc, srcPointerType, scrPtr);
370     // Intrinsics takes a global pointer so we need an address space cast.
371     auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
372         i8Ty, NVVM::NVVMMemorySpace::kGlobalMemorySpace);
373     scrPtr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, srcPointerGlobalType,
374                                                     scrPtr);
375     int64_t numElements = adaptor.numElements().getZExtValue();
376     int64_t sizeInBytes =
377         (dstMemrefType.getElementTypeBitWidth() / 8) * numElements;
378     // bypass L1 is only supported for byte sizes of 16, we drop the hint
379     // otherwise.
380     UnitAttr bypassL1 = sizeInBytes == 16 ? adaptor.bypassL1Attr() : UnitAttr();
381     rewriter.create<NVVM::CpAsyncOp>(
382         loc, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes), bypassL1);
383 
384     // Drop the result token.
385     Value zero = rewriter.create<LLVM::ConstantOp>(
386         op->getLoc(), IntegerType::get(op.getContext(), 32),
387         rewriter.getI32IntegerAttr(0));
388     rewriter.replaceOp(op, zero);
389     return success();
390   }
391 };
392 
393 struct NVGPUAsyncCreateGroupLowering
394     : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
395   using ConvertOpToLLVMPattern<
396       nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;
397 
398   LogicalResult
399   matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
400                   ConversionPatternRewriter &rewriter) const override {
401     rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
402     // Drop the result token.
403     Value zero = rewriter.create<LLVM::ConstantOp>(
404         op->getLoc(), IntegerType::get(op.getContext(), 32),
405         rewriter.getI32IntegerAttr(0));
406     rewriter.replaceOp(op, zero);
407     return success();
408   }
409 };
410 
411 struct NVGPUAsyncWaitLowering
412     : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
413   using ConvertOpToLLVMPattern<
414       nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;
415 
416   LogicalResult
417   matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
418                   ConversionPatternRewriter &rewriter) const override {
419     // If numGroup is not present pick 0 as a conservative correct value.
420     int32_t numGroups = adaptor.numGroups() ? *adaptor.numGroups() : 0;
421     rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
422     rewriter.eraseOp(op);
423     return success();
424   }
425 };
426 
427 } // namespace
428 
429 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
430                                                  RewritePatternSet &patterns) {
431   patterns.add<MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
432                NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering>(
433       converter);
434 }
435 
436 std::unique_ptr<Pass> mlir::createConvertNVGPUToNVVMPass() {
437   return std::make_unique<ConvertNVGPUToNVVMPass>();
438 }
439