//===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
#include "../PassDetail.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"

using namespace mlir;

/// Returns the type for the intrinsic given the vectorResultType of the
/// `nvgpu.mma.sync` operation.
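///
/// For instance (illustrative only, derived from the mapping below): a
/// converted accumulator of type `!llvm.array<2 x vector<2xf16>>` maps to
/// `!llvm.struct<(vector<2xf16>, vector<2xf16>)>`, while
/// `!llvm.array<2 x vector<2xf32>>` maps to
/// `!llvm.struct<(f32, f32, f32, f32)>`.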
static Type inferIntrinsicResultType(Type vectorResultType) {
  MLIRContext *ctx = vectorResultType.getContext();
  auto a = vectorResultType.cast<LLVM::LLVMArrayType>();
  auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
  auto i32Ty = IntegerType::get(ctx, 32);
  auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
  Type f64Ty = Float64Type::get(ctx);
  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
  Type f32Ty = Float32Type::get(ctx);
  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
  if (a.getElementType() == f16x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
  }
  if (a.getElementType() == i32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx,
        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
  }
  if (a.getElementType() == f64x2Ty) {
    return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
  }
  if (a.getElementType() == f32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx,
        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
  }
  if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
  }
  return vectorResultType;
}

/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
/// always an LLVM struct) into a fragment that is compatible with the vector
/// type of this operation. This involves extracting elements from the struct
/// and inserting them into an LLVM array. These extra data-movement
/// operations should be canonicalized away by the LLVM backend.
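///
/// For instance (illustrative only): a struct result
/// `!llvm.struct<(f32, f32, f32, f32)>` is repacked into
/// `!llvm.array<2 x vector<2xf32>>` by extracting the four scalars and
/// inserting them pairwise into two `vector<2xf32>` rows.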
static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
                                    Type resultType, Value intrinsicResult,
                                    RewriterBase &rewriter) {
  MLIRContext *ctx = rewriter.getContext();
  auto structType = intrinsicResultType.dyn_cast<LLVM::LLVMStructType>();
  auto arrayType = resultType.dyn_cast<LLVM::LLVMArrayType>();
  Type i32Ty = rewriter.getI32Type();
  Type f32Ty = rewriter.getF32Type();
  Type f64Ty = rewriter.getF64Type();
  Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
  Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);

  auto makeConst = [&](int32_t index) -> Value {
    return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
                                             rewriter.getI32IntegerAttr(index));
  };

  if (arrayType) {
    SmallVector<Value, 4> elements;

    // The intrinsic returns 32-bit wide elements in a form which can be
    // directly bitcasted and inserted into the result vector.
    if (arrayType.getElementType() == f16x2Ty ||
        arrayType.getElementType() == f32x1Ty) {
      for (unsigned i = 0; i < structType.getBody().size(); i++) {
        Value el = rewriter.create<LLVM::ExtractValueOp>(
            loc, structType.getBody()[i], intrinsicResult,
            rewriter.getI64ArrayAttr(i));
        el = rewriter.createOrFold<LLVM::BitcastOp>(
            loc, arrayType.getElementType(), el);
        elements.push_back(el);
      }
    }

    // The intrinsic returns i32, f64, and f32 values as individual scalars,
    // even when the result is notionally a 64-bit wide element (e.g. f32x2). We
    // need to extract them from the struct and pack them into the 64-bit wide
    // rows of the vector result.
    if (arrayType.getElementType() == i32x2Ty ||
        arrayType.getElementType() == f64x2Ty ||
        arrayType.getElementType() == f32x2Ty) {

      for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
        Value vec =
            rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
        Value x1 = rewriter.create<LLVM::ExtractValueOp>(
            loc, structType.getBody()[i * 2], intrinsicResult,
            rewriter.getI64ArrayAttr(i * 2));
        Value x2 = rewriter.create<LLVM::ExtractValueOp>(
            loc, structType.getBody()[i * 2 + 1], intrinsicResult,
            rewriter.getI64ArrayAttr(i * 2 + 1));
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x1, makeConst(0));
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x2, makeConst(1));
        elements.push_back(vec);
      }
    }

    // Create the final vectorized result.
    Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
    for (const auto &el : llvm::enumerate(elements)) {
      result = rewriter.create<LLVM::InsertValueOp>(
          loc, arrayType, result, el.value(),
          rewriter.getI64ArrayAttr(el.index()));
    }
    return result;
  }

  return intrinsicResult;
}

/// The `nvgpu.mma.sync` converter below expects matrix fragment operands to be
/// given as 2D `vector`s where the rows are 32b or 64b wide. The
/// `nvvm.mma.sync` op expects these arguments to be given as a long list of
/// scalars of certain types. This function helps unpack the `vector` arguments
/// and cast them to the types expected by `nvvm.mma.sync`.
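///
/// For instance (illustrative only): an operand of type
/// `!llvm.array<2 x vector<2xf64>>` is flattened into four `f64` scalars,
/// and an operand of type `!llvm.array<1 x vector<4xi8>>` is bitcast to a
/// single `i32`.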
static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
                                              Location loc, Value operand,
                                              NVVM::MMATypes operandPtxType) {
  SmallVector<Value> result;
  Type i32Ty = rewriter.getI32Type();
  Type f64Ty = rewriter.getF64Type();
  Type f32Ty = rewriter.getF32Type();
  Type i8Ty = rewriter.getI8Type();
  Type i4Ty = rewriter.getIntegerType(4);
  Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
  Type i4x8Ty = LLVM::getFixedVectorType(i4Ty, 8);
  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
  auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>();

  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
    Value toUse = rewriter.create<LLVM::ExtractValueOp>(
        loc, arrayTy.getElementType(), operand, rewriter.getI64ArrayAttr(i));

    // For 4xi8 and 8xi4 vectors, as well as single f32 values in tf32 mode,
    // the intrinsic expects the operands to be provided as i32 scalars.
    if (arrayTy.getElementType() == i8x4Ty ||
        arrayTy.getElementType() == i4x8Ty ||
        (arrayTy.getElementType() == f32x1Ty &&
         operandPtxType == NVVM::MMATypes::tf32)) {
      result.push_back(
          rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI32Type(), toUse));
      continue;
    }

    // For some element types (i32, f32, f64), we need to unpack the inner
    // vector/array type as well because the intrinsic expects individual
    // scalars to be provided.
    VectorType innerArrayTy = arrayTy.getElementType().dyn_cast<VectorType>();
    if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
                         innerArrayTy.getElementType() == f64Ty ||
                         innerArrayTy.getElementType() == f32Ty)) {
      for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
           idx < innerSize; idx++) {
        result.push_back(rewriter.create<LLVM::ExtractElementOp>(
            loc, toUse,
            rewriter.create<LLVM::ConstantOp>(
                loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx))));
      }
      continue;
    }
    result.push_back(toUse);
  }
  return result;
}

namespace {

struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
  using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = getContext();
    Location loc = op->getLoc();

    // The result type of ldmatrix will always be a struct of 32-bit integer
    // registers if more than one 32-bit value is returned. Otherwise, the
    // result is a single i32. The result type of the nvgpu op is always a
    // vector of shape (NumRegisters, VectorRegister), where each
    // VectorRegister row is 32 bits wide. We bitcast the result of
    // NVVM::LdMatrixOp to this vector type.
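    //
    // For example (illustrative only): an nvgpu.ldmatrix result of type
    // vector<4x2xf16> needs four 32-bit registers, so the NVVM intrinsic
    // returns !llvm.struct<(i32, i32, i32, i32)>; each i32 is then bitcast
    // back to vector<2xf16> and packed into the final result below.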
    auto vectorResultType = op->getResultTypes()[0].dyn_cast<VectorType>();
    if (!vectorResultType) {
      return failure();
    }
    Type innerVectorType = LLVM::getFixedVectorType(
        vectorResultType.getElementType(), vectorResultType.getDimSize(1));

    int64_t num32BitRegs = vectorResultType.getDimSize(0);

    Type ldMatrixResultType;
    if (num32BitRegs > 1) {
      ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
          ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
    } else {
      ldMatrixResultType = rewriter.getI32Type();
    }

    auto srcMemrefType = op.getSrcMemref().getType().cast<MemRefType>();
    Value srcPtr =
        getStridedElementPtr(loc, srcMemrefType, adaptor.getSrcMemref(),
                             adaptor.getIndices(), rewriter);
    Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>(
        loc, ldMatrixResultType, srcPtr,
        /*num=*/op.getNumTiles(),
        /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
                                     : NVVM::MMALayout::row);

    // The ldmatrix operation returns either a single i32 value or a struct of
    // i32 values. Here we unpack those values and cast them back to their
    // actual vector type (still of width 32b) and repack them into the final
    // result array.
    Type finalResultType = typeConverter->convertType(vectorResultType);
    Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType);
    for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
      Value i32Register = num32BitRegs > 1
                              ? rewriter.create<LLVM::ExtractValueOp>(
                                    loc, rewriter.getI32Type(), ldMatrixResult,
                                    rewriter.getI64ArrayAttr(i))
                              : ldMatrixResult;
      Value casted =
          rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register);
      result = rewriter.create<LLVM::InsertValueOp>(
          loc, finalResultType, result, casted, rewriter.getI64ArrayAttr(i));
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Get the shapes of the MMAMatrix type being used. The shapes determine
    // which intrinsic this op will be lowered to.
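    //
    // For example (illustrative only): an nvgpu.mma.sync with
    // mmaShape = [16, 8, 16] and f16 operands lowers to the m16n8k16 f16
    // variant of nvvm.mma.sync.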
    auto aType = op.getMatrixA().getType().cast<VectorType>();
    auto cType = op.getMatrixC().getType().cast<VectorType>();

    int64_t m = op.getMmaShape()[0].cast<IntegerAttr>().getInt();
    int64_t n = op.getMmaShape()[1].cast<IntegerAttr>().getInt();
    int64_t k = op.getMmaShape()[2].cast<IntegerAttr>().getInt();
    std::array<int64_t, 3> gemmShape{m, n, k};

    NVVM::MMATypes ptxTypeA;
    NVVM::MMATypes ptxTypeB;
    Optional<NVVM::MMATypes> ptxTypeC = NVVM::MmaOp::inferOperandMMAType(
        cType.getElementType(), /*isAccumulator=*/true);
    if (!ptxTypeC) {
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");
    }

    Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
    if (aType.getElementType().isInteger(8)) {
      ptxTypeA = NVVM::MMATypes::s8;
      ptxTypeB = NVVM::MMATypes::s8;
      overflow = NVVM::MMAIntOverflow::satfinite;
    } else if (aType.getElementType().isInteger(4)) {
      ptxTypeA = NVVM::MMATypes::s4;
      ptxTypeB = NVVM::MMATypes::s4;
      overflow = NVVM::MMAIntOverflow::satfinite;
    } else if (aType.getElementType().isF16()) {
      ptxTypeA = NVVM::MMATypes::f16;
      ptxTypeB = NVVM::MMATypes::f16;
    } else if (aType.getElementType().isF64()) {
      ptxTypeA = NVVM::MMATypes::f64;
      ptxTypeB = NVVM::MMATypes::f64;
    } else if (aType.getElementType().isF32()) {
      ptxTypeA = NVVM::MMATypes::tf32;
      ptxTypeB = NVVM::MMATypes::tf32;
    } else {
      return op->emitError("could not deduce operand PTX types");
    }

    SmallVector<Value> matA =
        unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), ptxTypeA);
    SmallVector<Value> matB =
        unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), ptxTypeB);
    SmallVector<Value> matC =
        unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC);

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));
    Value intrinsicResult = rewriter.create<NVVM::MmaOp>(
        op.getLoc(), intrinsicResTy, matA, matB, matC,
        /*shape=*/gemmShape,
        /*b1Op=*/llvm::None,
        /*intOverflow=*/overflow,
        /*multiplicandPtxTypes=*/
        std::array<NVVM::MMATypes, 2>{ptxTypeA, ptxTypeB},
        /*multiplicandLayouts=*/
        std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
                                       NVVM::MMALayout::col});
    rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
                                                  desiredRetTy, intrinsicResult,
                                                  rewriter));
    return success();
  }
};

struct ConvertNVGPUToNVVMPass
    : public ConvertNVGPUToNVVMBase<ConvertNVGPUToNVVMPass> {
  ConvertNVGPUToNVVMPass() = default;

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    LLVMTypeConverter converter(&getContext());
    // Device-side async tokens cannot be materialized in NVVM. We just convert
    // them to a dummy i32 type in order to easily drop them during conversion.
    converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
      return converter.convertType(IntegerType::get(type.getContext(), 32));
    });
    populateNVGPUToNVVMConversionPatterns(converter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

struct NVGPUAsyncCopyLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto dstMemrefType = op.getDst().getType().cast<MemRefType>();
    Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.getDst(),
                                        adaptor.getDstIndices(), rewriter);
    auto i8Ty = IntegerType::get(op.getContext(), 8);
    auto dstPointerType =
        LLVM::LLVMPointerType::get(i8Ty, dstMemrefType.getMemorySpaceAsInt());
    dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);

    auto srcMemrefType = op.getSrc().getType().cast<MemRefType>();

    Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
                                        adaptor.getSrcIndices(), rewriter);
    auto srcPointerType =
        LLVM::LLVMPointerType::get(i8Ty, srcMemrefType.getMemorySpaceAsInt());
    srcPtr = rewriter.create<LLVM::BitcastOp>(loc, srcPointerType, srcPtr);
    // The intrinsic takes a global pointer, so we need an address space cast.
    auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
        i8Ty, NVVM::NVVMMemorySpace::kGlobalMemorySpace);
    srcPtr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, srcPointerGlobalType,
                                                    srcPtr);
    int64_t numElements = adaptor.getNumElements().getZExtValue();
    int64_t sizeInBytes =
        (dstMemrefType.getElementTypeBitWidth() * numElements) / 8;
    // Bypassing L1 is only supported for 16-byte copies; drop the hint
    // otherwise.
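    // For instance (illustrative only): copying 8 f16 elements is 16 bytes,
    // so the bypass-L1 hint is preserved, whereas copying 4 f16 elements
    // (8 bytes) drops it.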
    UnitAttr bypassL1 =
        sizeInBytes == 16 ? adaptor.getBypassL1Attr() : UnitAttr();
    rewriter.create<NVVM::CpAsyncOp>(
        loc, dstPtr, srcPtr, rewriter.getI32IntegerAttr(sizeInBytes), bypassL1);

    // Drop the result token.
    Value zero = rewriter.create<LLVM::ConstantOp>(
        op->getLoc(), IntegerType::get(op.getContext(), 32),
        rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};

struct NVGPUAsyncCreateGroupLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
    // Drop the result token.
    Value zero = rewriter.create<LLVM::ConstantOp>(
        op->getLoc(), IntegerType::get(op.getContext(), 32),
        rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};

struct NVGPUAsyncWaitLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // If numGroups is not present, pick 0 as a conservative but correct value.
    int32_t numGroups = adaptor.getNumGroups().value_or(0);
    rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                 RewritePatternSet &patterns) {
  patterns.add<MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
               NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering>(
      converter);
}

std::unique_ptr<Pass> mlir::createConvertNVGPUToNVVMPass() {
  return std::make_unique<ConvertNVGPUToNVVMPass>();
}