1 //===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h" 10 #include "../PassDetail.h" 11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 12 #include "mlir/Conversion/LLVMCommon/Pattern.h" 13 #include "mlir/Dialect/GPU/GPUDialect.h" 14 #include "mlir/Dialect/LLVMIR/NVVMDialect.h" 15 #include "mlir/Dialect/NVGPU/NVGPUDialect.h" 16 17 using namespace mlir; 18 19 /// Returns the type for the intrinsic given the vectorResultType of the 20 /// `gpu.mma.sync` operation. 21 static Type inferIntrinsicResultType(Type vectorResultType) { 22 MLIRContext *ctx = vectorResultType.getContext(); 23 auto a = vectorResultType.cast<LLVM::LLVMArrayType>(); 24 auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2); 25 auto i32Ty = IntegerType::get(ctx, 32); 26 auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2); 27 Type f64Ty = Float64Type::get(ctx); 28 Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2); 29 Type f32Ty = Float32Type::get(ctx); 30 Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2); 31 if (a.getElementType() == f16x2Ty) { 32 return LLVM::LLVMStructType::getLiteral( 33 ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty)); 34 } 35 if (a.getElementType() == i32x2Ty) { 36 return LLVM::LLVMStructType::getLiteral( 37 ctx, 38 SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty)); 39 } 40 if (a.getElementType() == f64x2Ty) { 41 return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty}); 42 } 43 if (a.getElementType() == f32x2Ty) { 44 return LLVM::LLVMStructType::getLiteral( 45 ctx, 46 SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty)); 47 } 48 if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) { 49 return LLVM::LLVMStructType::getLiteral( 50 ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty)); 51 } 52 return vectorResultType; 53 } 54 55 /// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is 56 /// always an LLVM struct) into a fragment that is compatible with the vector 57 /// type of this operation. This involves extracting elements from the struct 58 /// and inserting them into an LLVM array. These extra data-movement 59 /// operations should be canonicalized away by the LLVM backend. 60 static Value convertIntrinsicResult(Location loc, Type intrinsicResultType, 61 Type resultType, Value intrinsicResult, 62 RewriterBase &rewriter) { 63 MLIRContext *ctx = rewriter.getContext(); 64 auto structType = intrinsicResultType.dyn_cast<LLVM::LLVMStructType>(); 65 auto arrayType = resultType.dyn_cast<LLVM::LLVMArrayType>(); 66 Type i32Ty = rewriter.getI32Type(); 67 Type f32Ty = rewriter.getF32Type(); 68 Type f64Ty = rewriter.getF64Type(); 69 Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2); 70 Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2); 71 Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2); 72 Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2); 73 Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1); 74 75 auto makeConst = [&](int32_t index) -> Value { 76 return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32), 77 rewriter.getI32IntegerAttr(index)); 78 }; 79 80 if (arrayType) { 81 SmallVector<Value, 4> elements; 82 83 // The intrinsic returns 32-bit wide elements in a form which can be 84 // directly bitcasted and inserted into the result vector. 85 if (arrayType.getElementType() == f16x2Ty || 86 arrayType.getElementType() == f32x1Ty) { 87 for (unsigned i = 0; i < structType.getBody().size(); i++) { 88 Value el = rewriter.create<LLVM::ExtractValueOp>( 89 loc, structType.getBody()[i], intrinsicResult, 90 rewriter.getI64ArrayAttr(i)); 91 el = rewriter.createOrFold<LLVM::BitcastOp>( 92 loc, arrayType.getElementType(), el); 93 elements.push_back(el); 94 } 95 } 96 97 // The intrinsic returns i32, f64, and f32 values as individual scalars, 98 // even when the result is notionally a 64-bit wide element (e.g. f32x2). We 99 // need to extract them from the struct and pack them into the 64-bit wide 100 // rows of the vector result. 101 if (arrayType.getElementType() == i32x2Ty || 102 arrayType.getElementType() == f64x2Ty || 103 arrayType.getElementType() == f32x2Ty) { 104 105 for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) { 106 Value vec = 107 rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType()); 108 Value x1 = rewriter.create<LLVM::ExtractValueOp>( 109 loc, structType.getBody()[i * 2], intrinsicResult, 110 rewriter.getI64ArrayAttr(i * 2)); 111 Value x2 = rewriter.create<LLVM::ExtractValueOp>( 112 loc, structType.getBody()[i * 2 + 1], intrinsicResult, 113 rewriter.getI64ArrayAttr(i * 2 + 1)); 114 vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec, 115 x1, makeConst(0)); 116 vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec, 117 x2, makeConst(1)); 118 elements.push_back(vec); 119 } 120 } 121 122 // Create the final vectorized result. 123 Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType); 124 for (const auto &el : llvm::enumerate(elements)) { 125 result = rewriter.create<LLVM::InsertValueOp>( 126 loc, arrayType, result, el.value(), 127 rewriter.getI64ArrayAttr(el.index())); 128 } 129 return result; 130 } 131 132 return intrinsicResult; 133 } 134 135 /// The `gpu.mma.sync` converter below expects matrix fragment operands to be 136 /// given as 2D `vectors` where the rows are 32b or 64b wide. The 137 /// `nvvm.mma.sync` op expects these argments to be a given in a long list of 138 /// scalars of certain types. This function helps unpack the `vector` arguments 139 /// and cast them to the types expected by `nvvm.mma.sync`. 140 static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter, 141 Location loc, Value operand, 142 NVVM::MMATypes operandPtxType) { 143 SmallVector<Value> result; 144 Type i32Ty = rewriter.getI32Type(); 145 Type f64Ty = rewriter.getF64Type(); 146 Type f32Ty = rewriter.getF32Type(); 147 Type i8Ty = rewriter.getI8Type(); 148 Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4); 149 Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1); 150 auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>(); 151 152 for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) { 153 Value toUse = rewriter.create<LLVM::ExtractValueOp>( 154 loc, arrayTy.getElementType(), operand, rewriter.getI64ArrayAttr(i)); 155 156 // For 4xi8 vectors, the intrinsic expects these to be provided as i32 157 // scalar types. 158 if (arrayTy.getElementType() == i8x4Ty || 159 (arrayTy.getElementType() == f32x1Ty && 160 operandPtxType == NVVM::MMATypes::tf32)) { 161 result.push_back( 162 rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI32Type(), toUse)); 163 continue; 164 } 165 166 // For some element types (i32, f32, f64), we need to unpack the inner 167 // vector/array type as well because the intrinsic expects individual 168 // scalars to be provided. 169 VectorType innerArrayTy = arrayTy.getElementType().dyn_cast<VectorType>(); 170 if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty || 171 innerArrayTy.getElementType() == f64Ty || 172 innerArrayTy.getElementType() == f32Ty)) { 173 for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements(); 174 idx < innerSize; idx++) { 175 result.push_back(rewriter.create<LLVM::ExtractElementOp>( 176 loc, toUse, 177 rewriter.create<LLVM::ConstantOp>( 178 loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx)))); 179 } 180 continue; 181 } 182 result.push_back(toUse); 183 } 184 return result; 185 } 186 187 namespace { 188 189 struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> { 190 using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern; 191 192 LogicalResult 193 matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor, 194 ConversionPatternRewriter &rewriter) const override { 195 MLIRContext *ctx = getContext(); 196 Location loc = op->getLoc(); 197 198 // The result type of ldmatrix will always be a struct of 32bit integer 199 // registers if more than one 32bit value is returned. Otherwise, the result 200 // is a single i32. The result type of the GPU operation is always a vector 201 // of shape (NumRegisters, VectorRegister) where VectorRegister is the 202 // vector type of the result and always 32 bits long. We bitcast the result 203 // of the NVVM::LdMatrix to this vector type. 204 auto vectorResultType = op->getResultTypes()[0].dyn_cast<VectorType>(); 205 if (!vectorResultType) { 206 return failure(); 207 } 208 Type innerVectorType = LLVM::getFixedVectorType( 209 vectorResultType.getElementType(), vectorResultType.getDimSize(1)); 210 211 int64_t num32BitRegs = vectorResultType.getDimSize(0); 212 213 Type ldMatrixResultType; 214 if (num32BitRegs > 1) { 215 ldMatrixResultType = LLVM::LLVMStructType::getLiteral( 216 ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type())); 217 } else { 218 ldMatrixResultType = rewriter.getI32Type(); 219 } 220 221 auto srcMemrefType = op.srcMemref().getType().cast<MemRefType>(); 222 Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.srcMemref(), 223 adaptor.indices(), rewriter); 224 Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>( 225 loc, ldMatrixResultType, srcPtr, 226 /*num=*/op.numTiles(), 227 /*layout=*/op.transpose() ? NVVM::MMALayout::col 228 : NVVM::MMALayout::row); 229 230 // The ldmatrix operation returns either a single i32 value or a struct of 231 // i32 values. Here we unpack those values and cast them back to their 232 // actual vector type (still of width 32b) and repack them into a result 233 // struct. 234 Type finalResultType = typeConverter->convertType(vectorResultType); 235 Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType); 236 for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) { 237 Value i32Register = num32BitRegs > 1 238 ? rewriter.create<LLVM::ExtractValueOp>( 239 loc, rewriter.getI32Type(), ldMatrixResult, 240 rewriter.getI64ArrayAttr(i)) 241 : ldMatrixResult; 242 Value casted = 243 rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register); 244 result = rewriter.create<LLVM::InsertValueOp>( 245 loc, finalResultType, result, casted, rewriter.getI64ArrayAttr(i)); 246 } 247 248 rewriter.replaceOp(op, result); 249 return success(); 250 } 251 }; 252 253 struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> { 254 using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern; 255 256 LogicalResult 257 matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor, 258 ConversionPatternRewriter &rewriter) const override { 259 Location loc = op->getLoc(); 260 // Get the shapes of the MMAMatrix type being used. The shapes will 261 // choose which intrinsic this op will be lowered to. 262 auto aType = op.matrixA().getType().cast<VectorType>(); 263 auto cType = op.matrixC().getType().cast<VectorType>(); 264 265 int64_t m = op.mmaShape()[0].cast<IntegerAttr>().getInt(); 266 int64_t n = op.mmaShape()[1].cast<IntegerAttr>().getInt(); 267 int64_t k = op.mmaShape()[2].cast<IntegerAttr>().getInt(); 268 std::array<int64_t, 3> gemmShape{m, n, k}; 269 270 NVVM::MMATypes ptxTypeA; 271 NVVM::MMATypes ptxTypeB; 272 Optional<NVVM::MMATypes> ptxTypeC = NVVM::MmaOp::inferOperandMMAType( 273 cType.getElementType(), /*isAccumulator=*/true); 274 if (!ptxTypeC) { 275 return op->emitError( 276 "could not infer the PTX type for the accumulator/result"); 277 } 278 279 Optional<NVVM::MMAIntOverflow> overflow(llvm::None); 280 if (aType.getElementType().isInteger(8)) { 281 ptxTypeA = NVVM::MMATypes::s8; 282 ptxTypeB = NVVM::MMATypes::s8; 283 overflow = NVVM::MMAIntOverflow::satfinite; 284 } else if (aType.getElementType().isF16()) { 285 ptxTypeA = NVVM::MMATypes::f16; 286 ptxTypeB = NVVM::MMATypes::f16; 287 } else if (aType.getElementType().isF64()) { 288 ptxTypeA = NVVM::MMATypes::f64; 289 ptxTypeB = NVVM::MMATypes::f64; 290 } else if (aType.getElementType().isF32()) { 291 ptxTypeA = NVVM::MMATypes::tf32; 292 ptxTypeB = NVVM::MMATypes::tf32; 293 } else { 294 return op->emitError("could not deduce operand PTX types"); 295 } 296 297 SmallVector<Value> matA = 298 unpackOperandVector(rewriter, loc, adaptor.matrixA(), ptxTypeA); 299 SmallVector<Value> matB = 300 unpackOperandVector(rewriter, loc, adaptor.matrixB(), ptxTypeB); 301 SmallVector<Value> matC = 302 unpackOperandVector(rewriter, loc, adaptor.matrixC(), *ptxTypeC); 303 304 Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]); 305 Type intrinsicResTy = inferIntrinsicResultType( 306 typeConverter->convertType(op->getResultTypes()[0])); 307 Value intrinsicResult = rewriter.create<NVVM::MmaOp>( 308 op.getLoc(), intrinsicResTy, matA, matB, matC, 309 /*shape=*/gemmShape, 310 /*b1Op=*/llvm::None, 311 /*intOverflow=*/overflow, 312 /*multiplicandPtxTypes=*/ 313 std::array<NVVM::MMATypes, 2>{ptxTypeA, ptxTypeB}, 314 /*multiplicandLayouts=*/ 315 std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row, 316 NVVM::MMALayout::col}); 317 rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, 318 desiredRetTy, intrinsicResult, 319 rewriter)); 320 return success(); 321 } 322 }; 323 324 struct ConvertNVGPUToNVVMPass 325 : public ConvertNVGPUToNVVMBase<ConvertNVGPUToNVVMPass> { 326 ConvertNVGPUToNVVMPass() = default; 327 328 void runOnOperation() override { 329 RewritePatternSet patterns(&getContext()); 330 LLVMTypeConverter converter(&getContext()); 331 /// device-side async tokens cannot be materialized in nvvm. We just convert 332 /// them to a dummy i32 type in order to easily drop them during conversion. 333 converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type { 334 return converter.convertType(IntegerType::get(type.getContext(), 32)); 335 }); 336 populateNVGPUToNVVMConversionPatterns(converter, patterns); 337 LLVMConversionTarget target(getContext()); 338 target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); 339 target.addLegalDialect<::mlir::NVVM::NVVMDialect>(); 340 if (failed(applyPartialConversion(getOperation(), target, 341 std::move(patterns)))) 342 signalPassFailure(); 343 } 344 }; 345 346 struct NVGPUAsyncCopyLowering 347 : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> { 348 using ConvertOpToLLVMPattern< 349 nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern; 350 351 LogicalResult 352 matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor, 353 ConversionPatternRewriter &rewriter) const override { 354 Location loc = op->getLoc(); 355 auto dstMemrefType = op.dst().getType().cast<MemRefType>(); 356 Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.dst(), 357 adaptor.dstIndices(), rewriter); 358 auto i8Ty = IntegerType::get(op.getContext(), 8); 359 auto dstPointerType = 360 LLVM::LLVMPointerType::get(i8Ty, dstMemrefType.getMemorySpaceAsInt()); 361 dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr); 362 363 auto srcMemrefType = op.src().getType().cast<MemRefType>(); 364 365 Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.src(), 366 adaptor.srcIndices(), rewriter); 367 auto srcPointerType = 368 LLVM::LLVMPointerType::get(i8Ty, srcMemrefType.getMemorySpaceAsInt()); 369 scrPtr = rewriter.create<LLVM::BitcastOp>(loc, srcPointerType, scrPtr); 370 // Intrinsics takes a global pointer so we need an address space cast. 371 auto srcPointerGlobalType = LLVM::LLVMPointerType::get( 372 i8Ty, NVVM::NVVMMemorySpace::kGlobalMemorySpace); 373 scrPtr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, srcPointerGlobalType, 374 scrPtr); 375 int64_t numElements = adaptor.numElements().getZExtValue(); 376 int64_t sizeInBytes = 377 (dstMemrefType.getElementTypeBitWidth() / 8) * numElements; 378 // bypass L1 is only supported for byte sizes of 16, we drop the hint 379 // otherwise. 380 UnitAttr bypassL1 = sizeInBytes == 16 ? adaptor.bypassL1Attr() : UnitAttr(); 381 rewriter.create<NVVM::CpAsyncOp>( 382 loc, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes), bypassL1); 383 384 // Drop the result token. 385 Value zero = rewriter.create<LLVM::ConstantOp>( 386 op->getLoc(), IntegerType::get(op.getContext(), 32), 387 rewriter.getI32IntegerAttr(0)); 388 rewriter.replaceOp(op, zero); 389 return success(); 390 } 391 }; 392 393 struct NVGPUAsyncCreateGroupLowering 394 : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> { 395 using ConvertOpToLLVMPattern< 396 nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern; 397 398 LogicalResult 399 matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor, 400 ConversionPatternRewriter &rewriter) const override { 401 rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc()); 402 // Drop the result token. 403 Value zero = rewriter.create<LLVM::ConstantOp>( 404 op->getLoc(), IntegerType::get(op.getContext(), 32), 405 rewriter.getI32IntegerAttr(0)); 406 rewriter.replaceOp(op, zero); 407 return success(); 408 } 409 }; 410 411 struct NVGPUAsyncWaitLowering 412 : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> { 413 using ConvertOpToLLVMPattern< 414 nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern; 415 416 LogicalResult 417 matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor, 418 ConversionPatternRewriter &rewriter) const override { 419 // If numGroup is not present pick 0 as a conservative correct value. 420 int32_t numGroups = adaptor.numGroups() ? *adaptor.numGroups() : 0; 421 rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups); 422 rewriter.eraseOp(op); 423 return success(); 424 } 425 }; 426 427 } // namespace 428 429 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter, 430 RewritePatternSet &patterns) { 431 patterns.add<MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering, 432 NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering>( 433 converter); 434 } 435 436 std::unique_ptr<Pass> mlir::createConvertNVGPUToNVVMPass() { 437 return std::make_unique<ConvertNVGPUToNVVMPass>(); 438 } 439