//===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
#include "../PassDetail.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"

using namespace mlir;

/// Returns the type for the intrinsic given the vectorResultType of the
/// `nvgpu.mma.sync` operation.
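///
/// For example (illustrative mapping; element counts follow the rules
/// implemented below):
///   !llvm.array<2 x vector<2xf16>>
///     -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
///   !llvm.array<2 x vector<2xi32>> -> !llvm.struct<(i32, i32, i32, i32)>
///   !llvm.array<1 x vector<2xf64>> -> !llvm.struct<(f64, f64)>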
static Type inferIntrinsicResultType(Type vectorResultType) {
  MLIRContext *ctx = vectorResultType.getContext();
  auto a = vectorResultType.cast<LLVM::LLVMArrayType>();
  auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
  auto i32Ty = IntegerType::get(ctx, 32);
  auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
  Type f64Ty = Float64Type::get(ctx);
  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
  Type f32Ty = Float32Type::get(ctx);
  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
  if (a.getElementType() == f16x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
  }
  if (a.getElementType() == i32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx,
        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
  }
  if (a.getElementType() == f64x2Ty) {
    return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
  }
  if (a.getElementType() == f32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx,
        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
  }
  if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
  }
  return vectorResultType;
}

/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
/// always an LLVM struct) into a fragment that is compatible with the vector
/// type of this operation. This involves extracting elements from the struct
/// and inserting them into an LLVM array. These extra data-movement
/// operations should be canonicalized away by the LLVM backend.
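///
/// For example, a `!llvm.struct<(i32, i32, i32, i32)>` intrinsic result is
/// repacked into a `!llvm.array<2 x vector<2xi32>>` fragment: adjacent pairs
/// of scalars are extracted from the struct, inserted into two-element
/// vectors, and those vectors are inserted into the array (illustrative; the
/// exact shapes depend on the operation's result type).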
static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
                                    Type resultType, Value intrinsicResult,
                                    RewriterBase &rewriter) {
  MLIRContext *ctx = rewriter.getContext();
  auto structType = intrinsicResultType.dyn_cast<LLVM::LLVMStructType>();
  auto arrayType = resultType.dyn_cast<LLVM::LLVMArrayType>();
  Type i32Ty = rewriter.getI32Type();
  Type f32Ty = rewriter.getF32Type();
  Type f64Ty = rewriter.getF64Type();
  Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
  Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);

  auto makeConst = [&](int32_t index) -> Value {
    return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
                                             rewriter.getI32IntegerAttr(index));
  };

  if (arrayType) {
    SmallVector<Value, 4> elements;

    // The intrinsic returns 32-bit wide elements in a form which can be
    // directly bitcasted and inserted into the result vector.
    if (arrayType.getElementType() == f16x2Ty ||
        arrayType.getElementType() == f32x1Ty) {
      for (unsigned i = 0; i < structType.getBody().size(); i++) {
        Value el = rewriter.create<LLVM::ExtractValueOp>(
            loc, structType.getBody()[i], intrinsicResult,
            rewriter.getI64ArrayAttr(i));
        el = rewriter.createOrFold<LLVM::BitcastOp>(
            loc, arrayType.getElementType(), el);
        elements.push_back(el);
      }
    }

    // The intrinsic returns i32, f64, and f32 values as individual scalars,
    // even when the result is notionally a 64-bit wide element (e.g. f32x2).
    // We need to extract them from the struct and pack them into the 64-bit
    // wide rows of the vector result.
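    // For example, two adjacent i32 members (x1, x2) of the struct become one
    // vector<2xi32> row of the result:
    //   insertelement(insertelement(undef : vector<2xi32>, x1, 0), x2, 1)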
    if (arrayType.getElementType() == i32x2Ty ||
        arrayType.getElementType() == f64x2Ty ||
        arrayType.getElementType() == f32x2Ty) {

      for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
        Value vec =
            rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
        Value x1 = rewriter.create<LLVM::ExtractValueOp>(
            loc, structType.getBody()[i * 2], intrinsicResult,
            rewriter.getI64ArrayAttr(i * 2));
        Value x2 = rewriter.create<LLVM::ExtractValueOp>(
            loc, structType.getBody()[i * 2 + 1], intrinsicResult,
            rewriter.getI64ArrayAttr(i * 2 + 1));
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x1, makeConst(0));
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x2, makeConst(1));
        elements.push_back(vec);
      }
    }

    // Create the final vectorized result.
    Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
    for (const auto &el : llvm::enumerate(elements)) {
      result = rewriter.create<LLVM::InsertValueOp>(
          loc, arrayType, result, el.value(),
          rewriter.getI64ArrayAttr(el.index()));
    }
    return result;
  }

  return intrinsicResult;
}

/// The `nvgpu.mma.sync` converter below expects matrix fragment operands to
/// be given as 2D `vector`s where the rows are 32b or 64b wide. The
/// `nvvm.mma.sync` op expects these arguments to be given as a flat list of
/// scalars of certain types. This function helps unpack the `vector`
/// arguments and cast them to the types expected by `nvvm.mma.sync`.
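///
/// For example (illustrative), an operand of type
/// `!llvm.array<2 x vector<2xf16>>` is returned as two `vector<2xf16>`
/// values, while an `!llvm.array<4 x vector<1xf32>>` operand with PTX type
/// `tf32` is returned as four `i32` scalars, each obtained by bitcasting a
/// single-element vector.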
static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
                                              Location loc, Value operand,
                                              NVVM::MMATypes operandPtxType) {
  SmallVector<Value> result;
  Type i32Ty = rewriter.getI32Type();
  Type f64Ty = rewriter.getF64Type();
  Type f32Ty = rewriter.getF32Type();
  Type i8Ty = rewriter.getI8Type();
  Type i4Ty = rewriter.getIntegerType(4);
  Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
  Type i4x8Ty = LLVM::getFixedVectorType(i4Ty, 8);
  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
  auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>();

  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
    Value toUse = rewriter.create<LLVM::ExtractValueOp>(
        loc, arrayTy.getElementType(), operand, rewriter.getI64ArrayAttr(i));

    // For 4xi8 and 8xi4 vectors (and 1xf32 vectors when the operand type is
    // tf32), the intrinsic expects the values to be provided as i32 scalars,
    // so bitcast them.
    if (arrayTy.getElementType() == i8x4Ty ||
        arrayTy.getElementType() == i4x8Ty ||
        (arrayTy.getElementType() == f32x1Ty &&
         operandPtxType == NVVM::MMATypes::tf32)) {
      result.push_back(
          rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI32Type(), toUse));
      continue;
    }

    // For some element types (i32, f32, f64), we need to unpack the inner
    // vector/array type as well because the intrinsic expects individual
    // scalars to be provided.
    VectorType innerArrayTy = arrayTy.getElementType().dyn_cast<VectorType>();
    if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
                         innerArrayTy.getElementType() == f64Ty ||
                         innerArrayTy.getElementType() == f32Ty)) {
      for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
           idx < innerSize; idx++) {
        result.push_back(rewriter.create<LLVM::ExtractElementOp>(
            loc, toUse,
            rewriter.create<LLVM::ConstantOp>(
                loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx))));
      }
      continue;
    }
    result.push_back(toUse);
  }
  return result;
}

namespace {

struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
  using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = getContext();
    Location loc = op->getLoc();

    // The result type of ldmatrix will always be a struct of 32-bit integer
    // registers if more than one 32-bit value is returned. Otherwise, the
    // result is a single i32. The result type of the nvgpu op is always a
    // vector of shape (NumRegisters, VectorRegister) where VectorRegister is
    // the vector type of the result and always 32 bits long. We bitcast the
    // result of the NVVM::LdMatrix to this vector type.
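    //
    // For example (illustrative), a `vector<4x2xf16>` result corresponds to
    // four 32-bit registers: the intrinsic returns
    // `!llvm.struct<(i32, i32, i32, i32)>`, and each i32 member is bitcast
    // back to a `vector<2xf16>` before being inserted into the result.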
    auto vectorResultType = op->getResultTypes()[0].dyn_cast<VectorType>();
    if (!vectorResultType) {
      return failure();
    }
    Type innerVectorType = LLVM::getFixedVectorType(
        vectorResultType.getElementType(), vectorResultType.getDimSize(1));

    int64_t num32BitRegs = vectorResultType.getDimSize(0);

    Type ldMatrixResultType;
    if (num32BitRegs > 1) {
      ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
          ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
    } else {
      ldMatrixResultType = rewriter.getI32Type();
    }

    auto srcMemrefType = op.getSrcMemref().getType().cast<MemRefType>();
    Value srcPtr =
        getStridedElementPtr(loc, srcMemrefType, adaptor.getSrcMemref(),
                             adaptor.getIndices(), rewriter);
    Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>(
        loc, ldMatrixResultType, srcPtr,
        /*num=*/op.getNumTiles(),
        /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
                                     : NVVM::MMALayout::row);

    // The ldmatrix operation returns either a single i32 value or a struct of
    // i32 values. Here we unpack those values and cast them back to their
    // actual vector type (still of width 32b) and repack them into a result
    // struct.
    Type finalResultType = typeConverter->convertType(vectorResultType);
    Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType);
    for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
      Value i32Register = num32BitRegs > 1
                              ? rewriter.create<LLVM::ExtractValueOp>(
                                    loc, rewriter.getI32Type(), ldMatrixResult,
                                    rewriter.getI64ArrayAttr(i))
                              : ldMatrixResult;
      Value casted =
          rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register);
      result = rewriter.create<LLVM::InsertValueOp>(
          loc, finalResultType, result, casted, rewriter.getI64ArrayAttr(i));
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Get the shapes of the MMAMatrix type being used. The shapes will
    // choose which intrinsic this op will be lowered to.
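    //
    // For example (illustrative), an `nvgpu.mma.sync` op with
    // mmaShape = [16, 8, 16] and f16 operands lowers to the m16n8k16 f16
    // variant of `nvvm.mma.sync`.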
    auto aType = op.getMatrixA().getType().cast<VectorType>();
    auto cType = op.getMatrixC().getType().cast<VectorType>();

    int64_t m = op.getMmaShape()[0].cast<IntegerAttr>().getInt();
    int64_t n = op.getMmaShape()[1].cast<IntegerAttr>().getInt();
    int64_t k = op.getMmaShape()[2].cast<IntegerAttr>().getInt();
    std::array<int64_t, 3> gemmShape{m, n, k};

    NVVM::MMATypes ptxTypeA;
    NVVM::MMATypes ptxTypeB;
    Optional<NVVM::MMATypes> ptxTypeC = NVVM::MmaOp::inferOperandMMAType(
        cType.getElementType(), /*isAccumulator=*/true);
    if (!ptxTypeC) {
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");
    }

    Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
    if (aType.getElementType().isInteger(8)) {
      ptxTypeA = NVVM::MMATypes::s8;
      ptxTypeB = NVVM::MMATypes::s8;
      overflow = NVVM::MMAIntOverflow::satfinite;
    } else if (aType.getElementType().isInteger(4)) {
      ptxTypeA = NVVM::MMATypes::s4;
      ptxTypeB = NVVM::MMATypes::s4;
      overflow = NVVM::MMAIntOverflow::satfinite;
    } else if (aType.getElementType().isF16()) {
      ptxTypeA = NVVM::MMATypes::f16;
      ptxTypeB = NVVM::MMATypes::f16;
    } else if (aType.getElementType().isF64()) {
      ptxTypeA = NVVM::MMATypes::f64;
      ptxTypeB = NVVM::MMATypes::f64;
    } else if (aType.getElementType().isF32()) {
      ptxTypeA = NVVM::MMATypes::tf32;
      ptxTypeB = NVVM::MMATypes::tf32;
    } else {
      return op->emitError("could not deduce operand PTX types");
    }

    SmallVector<Value> matA =
        unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), ptxTypeA);
    SmallVector<Value> matB =
        unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), ptxTypeB);
    SmallVector<Value> matC =
        unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC);

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));
    Value intrinsicResult = rewriter.create<NVVM::MmaOp>(
        op.getLoc(), intrinsicResTy, matA, matB, matC,
        /*shape=*/gemmShape,
        /*b1Op=*/llvm::None,
        /*intOverflow=*/overflow,
        /*multiplicandPtxTypes=*/
        std::array<NVVM::MMATypes, 2>{ptxTypeA, ptxTypeB},
        /*multiplicandLayouts=*/
        std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
                                       NVVM::MMALayout::col});
    rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
                                                  desiredRetTy, intrinsicResult,
                                                  rewriter));
    return success();
  }
};

struct ConvertNVGPUToNVVMPass
    : public ConvertNVGPUToNVVMBase<ConvertNVGPUToNVVMPass> {
  ConvertNVGPUToNVVMPass() = default;

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    LLVMTypeConverter converter(&getContext());
    /// Device-side async tokens cannot be materialized in NVVM. We just
    /// convert them to a dummy i32 type in order to easily drop them during
    /// conversion.
    converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
      return converter.convertType(IntegerType::get(type.getContext(), 32));
    });
    populateNVGPUToNVVMConversionPatterns(converter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

struct NVGPUAsyncCopyLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto dstMemrefType = op.getDst().getType().cast<MemRefType>();
    Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.getDst(),
                                        adaptor.getDstIndices(), rewriter);
    auto i8Ty = IntegerType::get(op.getContext(), 8);
    auto dstPointerType =
        LLVM::LLVMPointerType::get(i8Ty, dstMemrefType.getMemorySpaceAsInt());
    dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);

    auto srcMemrefType = op.getSrc().getType().cast<MemRefType>();

    Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
                                        adaptor.getSrcIndices(), rewriter);
    auto srcPointerType =
        LLVM::LLVMPointerType::get(i8Ty, srcMemrefType.getMemorySpaceAsInt());
    srcPtr = rewriter.create<LLVM::BitcastOp>(loc, srcPointerType, srcPtr);
    // The intrinsic takes a global pointer, so we need an address space cast.
    auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
        i8Ty, NVVM::NVVMMemorySpace::kGlobalMemorySpace);
    srcPtr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, srcPointerGlobalType,
                                                    srcPtr);
    int64_t numElements = adaptor.getNumElements().getZExtValue();
    int64_t sizeInBytes =
        (dstMemrefType.getElementTypeBitWidth() * numElements) / 8;
    // Bypassing L1 is only supported for 16-byte copies; drop the hint
    // otherwise.
    UnitAttr bypassL1 =
        sizeInBytes == 16 ? adaptor.getBypassL1Attr() : UnitAttr();
    rewriter.create<NVVM::CpAsyncOp>(
        loc, dstPtr, srcPtr, rewriter.getI32IntegerAttr(sizeInBytes), bypassL1);

    // Drop the result token.
    Value zero = rewriter.create<LLVM::ConstantOp>(
        op->getLoc(), IntegerType::get(op.getContext(), 32),
        rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};

struct NVGPUAsyncCreateGroupLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
    // Drop the result token.
    Value zero = rewriter.create<LLVM::ConstantOp>(
        op->getLoc(), IntegerType::get(op.getContext(), 32),
        rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};

struct NVGPUAsyncWaitLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // If numGroups is not present, pick 0 as a conservatively correct value.
    int32_t numGroups = adaptor.getNumGroups().value_or(0);
    rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                 RewritePatternSet &patterns) {
  patterns.add<MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
               NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering>(
      converter);
}

std::unique_ptr<Pass> mlir::createConvertNVGPUToNVVMPass() {
  return std::make_unique<ConvertNVGPUToNVVMPass>();
}