1 //===- NVVMToLLVMIRTranslation.cpp - Translate NVVM to LLVM IR ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a translation between the MLIR NVVM dialect and 10 // LLVM IR. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" 15 #include "mlir/Dialect/LLVMIR/NVVMDialect.h" 16 #include "mlir/IR/Operation.h" 17 #include "mlir/Target/LLVMIR/ModuleTranslation.h" 18 19 #include "llvm/IR/IRBuilder.h" 20 #include "llvm/IR/IntrinsicsNVPTX.h" 21 22 using namespace mlir; 23 using namespace mlir::LLVM; 24 using mlir::LLVM::detail::createIntrinsicCall; 25 26 static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType, 27 NVVM::ShflKind kind, 28 bool withPredicate) { 29 30 if (withPredicate) { 31 resultType = cast<llvm::StructType>(resultType)->getElementType(0); 32 switch (kind) { 33 case NVVM::ShflKind::bfly: 34 return resultType->isFloatTy() 35 ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p 36 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p; 37 case NVVM::ShflKind::up: 38 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p 39 : llvm::Intrinsic::nvvm_shfl_sync_up_i32p; 40 case NVVM::ShflKind::down: 41 return resultType->isFloatTy() 42 ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p 43 : llvm::Intrinsic::nvvm_shfl_sync_down_i32p; 44 case NVVM::ShflKind::idx: 45 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p 46 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p; 47 } 48 } else { 49 switch (kind) { 50 case NVVM::ShflKind::bfly: 51 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32 52 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32; 53 case NVVM::ShflKind::up: 54 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32 55 : llvm::Intrinsic::nvvm_shfl_sync_up_i32; 56 case NVVM::ShflKind::down: 57 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32 58 : llvm::Intrinsic::nvvm_shfl_sync_down_i32; 59 case NVVM::ShflKind::idx: 60 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32 61 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32; 62 } 63 } 64 llvm_unreachable("unknown shuffle kind"); 65 } 66 67 /// Return the intrinsic ID associated with ldmatrix for the given paramters. 68 static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, 69 int32_t num) { 70 if (layout == NVVM::MMALayout::row) { 71 switch (num) { 72 case 1: 73 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16; 74 case 2: 75 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16; 76 case 4: 77 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16; 78 default: 79 llvm_unreachable("unsupported number of matrix"); 80 } 81 82 } else { 83 switch (num) { 84 case 1: 85 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16; 86 case 2: 87 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16; 88 case 4: 89 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16; 90 default: 91 llvm_unreachable("unsupported number of matrix"); 92 } 93 } 94 } 95 96 namespace { 97 /// Implementation of the dialect interface that converts operations belonging 98 /// to the NVVM dialect to LLVM IR. 99 class NVVMDialectLLVMIRTranslationInterface 100 : public LLVMTranslationDialectInterface { 101 public: 102 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; 103 104 /// Translates the given operation to LLVM IR using the provided IR builder 105 /// and saving the state in `moduleTranslation`. 106 LogicalResult 107 convertOperation(Operation *op, llvm::IRBuilderBase &builder, 108 LLVM::ModuleTranslation &moduleTranslation) const final { 109 Operation &opInst = *op; 110 #include "mlir/Dialect/LLVMIR/NVVMConversions.inc" 111 112 return failure(); 113 } 114 115 /// Attaches module-level metadata for functions marked as kernels. 116 LogicalResult 117 amendOperation(Operation *op, NamedAttribute attribute, 118 LLVM::ModuleTranslation &moduleTranslation) const final { 119 if (attribute.getName() == NVVM::NVVMDialect::getKernelFuncAttrName()) { 120 auto func = dyn_cast<LLVM::LLVMFuncOp>(op); 121 if (!func) 122 return failure(); 123 124 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext(); 125 llvm::Function *llvmFunc = 126 moduleTranslation.lookupFunction(func.getName()); 127 llvm::Metadata *llvmMetadata[] = { 128 llvm::ValueAsMetadata::get(llvmFunc), 129 llvm::MDString::get(llvmContext, "kernel"), 130 llvm::ValueAsMetadata::get( 131 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 1))}; 132 llvm::MDNode *llvmMetadataNode = 133 llvm::MDNode::get(llvmContext, llvmMetadata); 134 moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations") 135 ->addOperand(llvmMetadataNode); 136 } 137 return success(); 138 } 139 }; 140 } // namespace 141 142 void mlir::registerNVVMDialectTranslation(DialectRegistry ®istry) { 143 registry.insert<NVVM::NVVMDialect>(); 144 registry.addExtension(+[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) { 145 dialect->addInterfaces<NVVMDialectLLVMIRTranslationInterface>(); 146 }); 147 } 148 149 void mlir::registerNVVMDialectTranslation(MLIRContext &context) { 150 DialectRegistry registry; 151 registerNVVMDialectTranslation(registry); 152 context.appendDialectRegistry(registry); 153 } 154