//===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
#include "../PassDetail.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/NVGPU/NVGPUDialect.h"

using namespace mlir;

/// Returns the type for the intrinsic given the vectorResultType of the
/// `gpu.mma.sync` operation.
static Type inferIntrinsicResultType(Type vectorResultType) {
  MLIRContext *ctx = vectorResultType.getContext();
  auto a = vectorResultType.cast<LLVM::LLVMArrayType>();
  auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
  auto i32Ty = IntegerType::get(ctx, 32);
  auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
  Type f64Ty = Float64Type::get(ctx);
  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
  Type f32Ty = Float32Type::get(ctx);
  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
  if (a.getElementType() == f16x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
  }
  if (a.getElementType() == i32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx,
        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
  }
  if (a.getElementType() == f64x2Ty) {
    return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
  }
  if (a.getElementType() == f32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx,
        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
  }
  if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
  }
  return vectorResultType;
}

/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
/// always an LLVM struct) into a fragment that is compatible with the vector
/// type of this operation. This involves extracting elements from the struct
/// and inserting them into an LLVM array. These extra data-movement
/// operations should be canonicalized away by the LLVM backend.
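///
/// For example, if the result was converted to `!llvm.array<2 x vector<2xf16>>`,
/// the intrinsic returns `!llvm.struct<(vector<2xf16>, vector<2xf16>)>`; each
/// struct member is extracted (and bitcast where needed) and inserted into the
/// array.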
static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
                                    Type resultType, Value intrinsicResult,
                                    RewriterBase &rewriter) {
  MLIRContext *ctx = rewriter.getContext();
  auto structType = intrinsicResultType.dyn_cast<LLVM::LLVMStructType>();
  auto arrayType = resultType.dyn_cast<LLVM::LLVMArrayType>();
  Type i32Ty = rewriter.getI32Type();
  Type f32Ty = rewriter.getF32Type();
  Type f64Ty = rewriter.getF64Type();
  Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
  Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);

  auto makeConst = [&](int32_t index) -> Value {
    return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
                                             rewriter.getI32IntegerAttr(index));
  };

  if (arrayType) {
    SmallVector<Value, 4> elements;

    // The intrinsic returns 32-bit wide elements in a form which can be
    // directly bitcasted and inserted into the result vector.
    if (arrayType.getElementType() == f16x2Ty ||
        arrayType.getElementType() == f32x1Ty) {
      for (unsigned i = 0; i < structType.getBody().size(); i++) {
        Value el = rewriter.create<LLVM::ExtractValueOp>(
            loc, structType.getBody()[i], intrinsicResult,
            rewriter.getI64ArrayAttr(i));
        el = rewriter.createOrFold<LLVM::BitcastOp>(
            loc, arrayType.getElementType(), el);
        elements.push_back(el);
      }
    }

    // The intrinsic returns i32, f64, and f32 values as individual scalars,
    // even when the result is notionally a 64-bit wide element (e.g. f32x2).
    // We need to extract them from the struct and pack them into the 64-bit
    // wide rows of the vector result.
    if (arrayType.getElementType() == i32x2Ty ||
        arrayType.getElementType() == f64x2Ty ||
        arrayType.getElementType() == f32x2Ty) {

      for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
        Value vec =
            rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
        Value x1 = rewriter.create<LLVM::ExtractValueOp>(
            loc, structType.getBody()[i * 2], intrinsicResult,
            rewriter.getI64ArrayAttr(i * 2));
        Value x2 = rewriter.create<LLVM::ExtractValueOp>(
            loc, structType.getBody()[i * 2 + 1], intrinsicResult,
            rewriter.getI64ArrayAttr(i * 2 + 1));
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x1, makeConst(0));
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x2, makeConst(1));
        elements.push_back(vec);
      }
    }

    // Create the final vectorized result.
    Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
    for (const auto &el : llvm::enumerate(elements)) {
      result = rewriter.create<LLVM::InsertValueOp>(
          loc, arrayType, result, el.value(),
          rewriter.getI64ArrayAttr(el.index()));
    }
    return result;
  }

  return intrinsicResult;
}

/// The `gpu.mma.sync` converter below expects matrix fragment operands to be
/// given as 2D `vectors` where the rows are 32b or 64b wide. The
/// `nvvm.mma.sync` op expects these arguments to be given as a long list of
/// scalars of certain types. This function helps unpack the `vector` arguments
/// and cast them to the types expected by `nvvm.mma.sync`.
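///
/// For example, an f64 accumulator given as `!llvm.array<1 x vector<2xf64>>`
/// is unpacked into two scalar f64 values, while tf32 operands held in
/// `vector<1xf32>` elements are bitcast to i32 scalars.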
static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
                                              Location loc, Value operand,
                                              NVVM::MMATypes operandPtxType) {
  SmallVector<Value> result;
  Type i32Ty = rewriter.getI32Type();
  Type f64Ty = rewriter.getF64Type();
  Type f32Ty = rewriter.getF32Type();
  Type i8Ty = rewriter.getI8Type();
  Type i4Ty = rewriter.getIntegerType(4);
  Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
  Type i4x8Ty = LLVM::getFixedVectorType(i4Ty, 8);
  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
  auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>();

  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
    Value toUse = rewriter.create<LLVM::ExtractValueOp>(
        loc, arrayTy.getElementType(), operand, rewriter.getI64ArrayAttr(i));

    // For 4xi8 vectors, 8xi4 vectors, and tf32 values held in 1xf32 vectors,
    // the intrinsic expects these to be provided as i32 scalar types.
    if (arrayTy.getElementType() == i8x4Ty ||
        arrayTy.getElementType() == i4x8Ty ||
        (arrayTy.getElementType() == f32x1Ty &&
         operandPtxType == NVVM::MMATypes::tf32)) {
      result.push_back(
          rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI32Type(), toUse));
      continue;
    }

    // For some element types (i32, f32, f64), we need to unpack the inner
    // vector/array type as well because the intrinsic expects individual
    // scalars to be provided.
    VectorType innerArrayTy = arrayTy.getElementType().dyn_cast<VectorType>();
    if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
                         innerArrayTy.getElementType() == f64Ty ||
                         innerArrayTy.getElementType() == f32Ty)) {
      for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
           idx < innerSize; idx++) {
        result.push_back(rewriter.create<LLVM::ExtractElementOp>(
            loc, toUse,
            rewriter.create<LLVM::ConstantOp>(
                loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx))));
      }
      continue;
    }
    result.push_back(toUse);
  }
  return result;
}

namespace {

struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
  using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = getContext();
    Location loc = op->getLoc();

    // The result type of ldmatrix will always be a struct of 32-bit integer
    // registers if more than one 32-bit value is returned. Otherwise, the
    // result is a single i32. The result type of the GPU operation is always
    // a vector of shape (NumRegisters, VectorRegister) where VectorRegister
    // is the vector type of the result and always 32 bits long. We bitcast
    // the result of the NVVM::LdMatrixOp to this vector type.
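    // For example, an `nvgpu.ldmatrix` returning `vector<4x2xf16>` lowers to
    // an `nvvm.ldmatrix` that returns a struct of four i32 registers; each
    // register is then bitcast back to `vector<2xf16>`.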
    auto vectorResultType = op->getResultTypes()[0].dyn_cast<VectorType>();
    if (!vectorResultType) {
      return failure();
    }
    Type innerVectorType = LLVM::getFixedVectorType(
        vectorResultType.getElementType(), vectorResultType.getDimSize(1));

    int64_t num32BitRegs = vectorResultType.getDimSize(0);

    Type ldMatrixResultType;
    if (num32BitRegs > 1) {
      ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
          ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
    } else {
      ldMatrixResultType = rewriter.getI32Type();
    }

    auto srcMemrefType = op.srcMemref().getType().cast<MemRefType>();
    Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.srcMemref(),
                                        adaptor.indices(), rewriter);
    Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>(
        loc, ldMatrixResultType, srcPtr,
        /*num=*/op.numTiles(),
        /*layout=*/op.transpose() ? NVVM::MMALayout::col
                                  : NVVM::MMALayout::row);

    // The ldmatrix operation returns either a single i32 value or a struct of
    // i32 values. Here we unpack those values and cast them back to their
    // actual vector type (still of width 32b) and repack them into a result
    // struct.
    Type finalResultType = typeConverter->convertType(vectorResultType);
    Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType);
    for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
      Value i32Register = num32BitRegs > 1
                              ? rewriter.create<LLVM::ExtractValueOp>(
                                    loc, rewriter.getI32Type(), ldMatrixResult,
                                    rewriter.getI64ArrayAttr(i))
                              : ldMatrixResult;
      Value casted =
          rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register);
      result = rewriter.create<LLVM::InsertValueOp>(
          loc, finalResultType, result, casted, rewriter.getI64ArrayAttr(i));
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Get the shapes of the MMAMatrix type being used. The shapes will
    // choose which intrinsic this op will be lowered to.
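    // For example, an mmaShape of [16, 8, 16] with f16 operands produces an
    // `nvvm.mma.sync` with the m16n8k16 shape and f16 multiplicand types.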
    auto aType = op.matrixA().getType().cast<VectorType>();
    auto cType = op.matrixC().getType().cast<VectorType>();

    int64_t m = op.mmaShape()[0].cast<IntegerAttr>().getInt();
    int64_t n = op.mmaShape()[1].cast<IntegerAttr>().getInt();
    int64_t k = op.mmaShape()[2].cast<IntegerAttr>().getInt();
    std::array<int64_t, 3> gemmShape{m, n, k};

    NVVM::MMATypes ptxTypeA;
    NVVM::MMATypes ptxTypeB;
    Optional<NVVM::MMATypes> ptxTypeC = NVVM::MmaOp::inferOperandMMAType(
        cType.getElementType(), /*isAccumulator=*/true);
    if (!ptxTypeC) {
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");
    }

    Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
    if (aType.getElementType().isInteger(8)) {
      ptxTypeA = NVVM::MMATypes::s8;
      ptxTypeB = NVVM::MMATypes::s8;
      overflow = NVVM::MMAIntOverflow::satfinite;
    } else if (aType.getElementType().isInteger(4)) {
      ptxTypeA = NVVM::MMATypes::s4;
      ptxTypeB = NVVM::MMATypes::s4;
      overflow = NVVM::MMAIntOverflow::satfinite;
    } else if (aType.getElementType().isF16()) {
      ptxTypeA = NVVM::MMATypes::f16;
      ptxTypeB = NVVM::MMATypes::f16;
    } else if (aType.getElementType().isF64()) {
      ptxTypeA = NVVM::MMATypes::f64;
      ptxTypeB = NVVM::MMATypes::f64;
    } else if (aType.getElementType().isF32()) {
      ptxTypeA = NVVM::MMATypes::tf32;
      ptxTypeB = NVVM::MMATypes::tf32;
    } else {
      return op->emitError("could not deduce operand PTX types");
    }

    SmallVector<Value> matA =
        unpackOperandVector(rewriter, loc, adaptor.matrixA(), ptxTypeA);
    SmallVector<Value> matB =
        unpackOperandVector(rewriter, loc, adaptor.matrixB(), ptxTypeB);
    SmallVector<Value> matC =
        unpackOperandVector(rewriter, loc, adaptor.matrixC(), *ptxTypeC);

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));
    Value intrinsicResult = rewriter.create<NVVM::MmaOp>(
        op.getLoc(), intrinsicResTy, matA, matB, matC,
        /*shape=*/gemmShape,
        /*b1Op=*/llvm::None,
        /*intOverflow=*/overflow,
        /*multiplicandPtxTypes=*/
        std::array<NVVM::MMATypes, 2>{ptxTypeA, ptxTypeB},
        /*multiplicandLayouts=*/
        std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
                                       NVVM::MMALayout::col});
    rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
                                                  desiredRetTy, intrinsicResult,
                                                  rewriter));
    return success();
  }
};

struct ConvertNVGPUToNVVMPass
    : public ConvertNVGPUToNVVMBase<ConvertNVGPUToNVVMPass> {
  ConvertNVGPUToNVVMPass() = default;

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    LLVMTypeConverter converter(&getContext());
    // Device-side async tokens cannot be materialized in NVVM. We just
    // convert them to a dummy i32 type in order to easily drop them during
    // conversion.
    converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
      return converter.convertType(IntegerType::get(type.getContext(), 32));
    });
    populateNVGPUToNVVMConversionPatterns(converter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

struct NVGPUAsyncCopyLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto dstMemrefType = op.dst().getType().cast<MemRefType>();
    Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.dst(),
                                        adaptor.dstIndices(), rewriter);
    auto i8Ty = IntegerType::get(op.getContext(), 8);
    auto dstPointerType =
        LLVM::LLVMPointerType::get(i8Ty, dstMemrefType.getMemorySpaceAsInt());
    dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);

    auto srcMemrefType = op.src().getType().cast<MemRefType>();

    Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.src(),
                                        adaptor.srcIndices(), rewriter);
    auto srcPointerType =
        LLVM::LLVMPointerType::get(i8Ty, srcMemrefType.getMemorySpaceAsInt());
    srcPtr = rewriter.create<LLVM::BitcastOp>(loc, srcPointerType, srcPtr);
    // The intrinsic takes a global pointer, so we need an address space cast.
    auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
        i8Ty, NVVM::NVVMMemorySpace::kGlobalMemorySpace);
    srcPtr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, srcPointerGlobalType,
                                                    srcPtr);
    int64_t numElements = adaptor.numElements().getZExtValue();
    int64_t sizeInBytes =
        (dstMemrefType.getElementTypeBitWidth() * numElements) / 8;
    // Bypassing L1 is only supported for a byte size of 16; drop the hint
    // otherwise.
    UnitAttr bypassL1 = sizeInBytes == 16 ? adaptor.bypassL1Attr() : UnitAttr();
    rewriter.create<NVVM::CpAsyncOp>(
        loc, dstPtr, srcPtr, rewriter.getI32IntegerAttr(sizeInBytes), bypassL1);

    // Drop the result token.
    Value zero = rewriter.create<LLVM::ConstantOp>(
        op->getLoc(), IntegerType::get(op.getContext(), 32),
        rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};

struct NVGPUAsyncCreateGroupLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
    // Drop the result token.
    Value zero = rewriter.create<LLVM::ConstantOp>(
        op->getLoc(), IntegerType::get(op.getContext(), 32),
        rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};

struct NVGPUAsyncWaitLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // If numGroups is not present, pick 0 as a conservative, correct value.
    int32_t numGroups = adaptor.numGroups() ? *adaptor.numGroups() : 0;
    rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                 RewritePatternSet &patterns) {
  patterns.add<MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
               NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering>(
      converter);
}

std::unique_ptr<Pass> mlir::createConvertNVGPUToNVVMPass() {
  return std::make_unique<ConvertNVGPUToNVVMPass>();
}