1 //===- NVVMToLLVMIRTranslation.cpp - Translate NVVM to LLVM IR ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR NVVM dialect and
10 // LLVM IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
15 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
16 #include "mlir/IR/Operation.h"
17 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
18 
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/IntrinsicsNVPTX.h"
21 
22 using namespace mlir;
23 using namespace mlir::LLVM;
24 using mlir::LLVM::detail::createIntrinsicCall;
25 
getShflIntrinsicId(llvm::Type * resultType,NVVM::ShflKind kind,bool withPredicate)26 static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
27                                               NVVM::ShflKind kind,
28                                               bool withPredicate) {
29 
30   if (withPredicate) {
31     resultType = cast<llvm::StructType>(resultType)->getElementType(0);
32     switch (kind) {
33     case NVVM::ShflKind::bfly:
34       return resultType->isFloatTy()
35                  ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
36                  : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
37     case NVVM::ShflKind::up:
38       return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
39                                      : llvm::Intrinsic::nvvm_shfl_sync_up_i32p;
40     case NVVM::ShflKind::down:
41       return resultType->isFloatTy()
42                  ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
43                  : llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
44     case NVVM::ShflKind::idx:
45       return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
46                                      : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p;
47     }
48   } else {
49     switch (kind) {
50     case NVVM::ShflKind::bfly:
51       return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
52                                      : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
53     case NVVM::ShflKind::up:
54       return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
55                                      : llvm::Intrinsic::nvvm_shfl_sync_up_i32;
56     case NVVM::ShflKind::down:
57       return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
58                                      : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
59     case NVVM::ShflKind::idx:
60       return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
61                                      : llvm::Intrinsic::nvvm_shfl_sync_idx_i32;
62     }
63   }
64   llvm_unreachable("unknown shuffle kind");
65 }
66 
67 /// Return the intrinsic ID associated with ldmatrix for the given paramters.
getLdMatrixIntrinsicId(NVVM::MMALayout layout,int32_t num)68 static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
69                                                   int32_t num) {
70   if (layout == NVVM::MMALayout::row) {
71     switch (num) {
72     case 1:
73       return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
74     case 2:
75       return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
76     case 4:
77       return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
78     default:
79       llvm_unreachable("unsupported number of matrix");
80     }
81 
82   } else {
83     switch (num) {
84     case 1:
85       return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
86     case 2:
87       return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
88     case 4:
89       return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
90     default:
91       llvm_unreachable("unsupported number of matrix");
92     }
93   }
94 }
95 
96 namespace {
97 /// Implementation of the dialect interface that converts operations belonging
98 /// to the NVVM dialect to LLVM IR.
99 class NVVMDialectLLVMIRTranslationInterface
100     : public LLVMTranslationDialectInterface {
101 public:
102   using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
103 
104   /// Translates the given operation to LLVM IR using the provided IR builder
105   /// and saving the state in `moduleTranslation`.
106   LogicalResult
convertOperation(Operation * op,llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation) const107   convertOperation(Operation *op, llvm::IRBuilderBase &builder,
108                    LLVM::ModuleTranslation &moduleTranslation) const final {
109     Operation &opInst = *op;
110 #include "mlir/Dialect/LLVMIR/NVVMConversions.inc"
111 
112     return failure();
113   }
114 
115   /// Attaches module-level metadata for functions marked as kernels.
116   LogicalResult
amendOperation(Operation * op,NamedAttribute attribute,LLVM::ModuleTranslation & moduleTranslation) const117   amendOperation(Operation *op, NamedAttribute attribute,
118                  LLVM::ModuleTranslation &moduleTranslation) const final {
119     if (attribute.getName() == NVVM::NVVMDialect::getKernelFuncAttrName()) {
120       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
121       if (!func)
122         return failure();
123 
124       llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
125       llvm::Function *llvmFunc =
126           moduleTranslation.lookupFunction(func.getName());
127       llvm::Metadata *llvmMetadata[] = {
128           llvm::ValueAsMetadata::get(llvmFunc),
129           llvm::MDString::get(llvmContext, "kernel"),
130           llvm::ValueAsMetadata::get(
131               llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 1))};
132       llvm::MDNode *llvmMetadataNode =
133           llvm::MDNode::get(llvmContext, llvmMetadata);
134       moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations")
135           ->addOperand(llvmMetadataNode);
136     }
137     return success();
138   }
139 };
140 } // namespace
141 
registerNVVMDialectTranslation(DialectRegistry & registry)142 void mlir::registerNVVMDialectTranslation(DialectRegistry &registry) {
143   registry.insert<NVVM::NVVMDialect>();
144   registry.addExtension(+[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) {
145     dialect->addInterfaces<NVVMDialectLLVMIRTranslationInterface>();
146   });
147 }
148 
registerNVVMDialectTranslation(MLIRContext & context)149 void mlir::registerNVVMDialectTranslation(MLIRContext &context) {
150   DialectRegistry registry;
151   registerNVVMDialectTranslation(registry);
152   context.appendDialectRegistry(registry);
153 }
154