1 //===- NVVMToLLVMIRTranslation.cpp - Translate NVVM to LLVM IR ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR NVVM dialect and
10 // LLVM IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
15 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
16 #include "mlir/IR/Operation.h"
17 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
18 
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/IntrinsicsNVPTX.h"
21 
22 using namespace mlir;
23 using namespace mlir::LLVM;
24 using mlir::LLVM::detail::createIntrinsicCall;
25 
26 static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
27                                               NVVM::ShflKind kind,
28                                               bool withPredicate) {
29 
30   if (withPredicate) {
31     resultType = cast<llvm::StructType>(resultType)->getElementType(0);
32     switch (kind) {
33     case NVVM::ShflKind::bfly:
34       return resultType->isFloatTy()
35                  ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
36                  : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
37     case NVVM::ShflKind::up:
38       return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
39                                      : llvm::Intrinsic::nvvm_shfl_sync_up_i32p;
40     case NVVM::ShflKind::down:
41       return resultType->isFloatTy()
42                  ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
43                  : llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
44     case NVVM::ShflKind::idx:
45       return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
46                                      : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p;
47     }
48   } else {
49     switch (kind) {
50     case NVVM::ShflKind::bfly:
51       return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
52                                      : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
53     case NVVM::ShflKind::up:
54       return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
55                                      : llvm::Intrinsic::nvvm_shfl_sync_up_i32;
56     case NVVM::ShflKind::down:
57       return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
58                                      : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
59     case NVVM::ShflKind::idx:
60       return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
61                                      : llvm::Intrinsic::nvvm_shfl_sync_idx_i32;
62     }
63   }
64   llvm_unreachable("unknown shuffle kind");
65 }
66 
67 namespace {
68 /// Implementation of the dialect interface that converts operations belonging
69 /// to the NVVM dialect to LLVM IR.
70 class NVVMDialectLLVMIRTranslationInterface
71     : public LLVMTranslationDialectInterface {
72 public:
73   using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
74 
75   /// Translates the given operation to LLVM IR using the provided IR builder
76   /// and saving the state in `moduleTranslation`.
77   LogicalResult
78   convertOperation(Operation *op, llvm::IRBuilderBase &builder,
79                    LLVM::ModuleTranslation &moduleTranslation) const final {
80     Operation &opInst = *op;
81 #include "mlir/Dialect/LLVMIR/NVVMConversions.inc"
82 
83     return failure();
84   }
85 
86   /// Attaches module-level metadata for functions marked as kernels.
87   LogicalResult
88   amendOperation(Operation *op, NamedAttribute attribute,
89                  LLVM::ModuleTranslation &moduleTranslation) const final {
90     if (attribute.getName() == NVVM::NVVMDialect::getKernelFuncAttrName()) {
91       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
92       if (!func)
93         return failure();
94 
95       llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
96       llvm::Function *llvmFunc =
97           moduleTranslation.lookupFunction(func.getName());
98       llvm::Metadata *llvmMetadata[] = {
99           llvm::ValueAsMetadata::get(llvmFunc),
100           llvm::MDString::get(llvmContext, "kernel"),
101           llvm::ValueAsMetadata::get(
102               llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 1))};
103       llvm::MDNode *llvmMetadataNode =
104           llvm::MDNode::get(llvmContext, llvmMetadata);
105       moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations")
106           ->addOperand(llvmMetadataNode);
107     }
108     return success();
109   }
110 };
111 } // end namespace
112 
113 void mlir::registerNVVMDialectTranslation(DialectRegistry &registry) {
114   registry.insert<NVVM::NVVMDialect>();
115   registry.addDialectInterface<NVVM::NVVMDialect,
116                                NVVMDialectLLVMIRTranslationInterface>();
117 }
118 
119 void mlir::registerNVVMDialectTranslation(MLIRContext &context) {
120   DialectRegistry registry;
121   registerNVVMDialectTranslation(registry);
122   context.appendDialectRegistry(registry);
123 }
124