1 //===- ROCDLToLLVMIRTranslation.cpp - Translate ROCDL to LLVM IR ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR ROCDL dialect and
10 // LLVM IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
15 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
16 #include "mlir/IR/Operation.h"
17 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
18 
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/IntrinsicsAMDGPU.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 using namespace mlir;
24 using namespace mlir::LLVM;
25 using mlir::LLVM::detail::createIntrinsicCall;
26 
27 // Create a call to ROCm-Device-Library function
28 // Currently this routine will work only for calling ROCDL functions that
29 // take a single int32 argument. It is likely that the interface of this
30 // function will change to make it more generic.
createDeviceFunctionCall(llvm::IRBuilderBase & builder,StringRef fnName,int parameter)31 static llvm::Value *createDeviceFunctionCall(llvm::IRBuilderBase &builder,
32                                              StringRef fnName, int parameter) {
33   llvm::Module *module = builder.GetInsertBlock()->getModule();
34   llvm::FunctionType *functionType = llvm::FunctionType::get(
35       llvm::Type::getInt64Ty(module->getContext()), // return type.
36       llvm::Type::getInt32Ty(module->getContext()), // parameter type.
37       false);                                       // no variadic arguments.
38   llvm::Function *fn = dyn_cast<llvm::Function>(
39       module->getOrInsertFunction(fnName, functionType).getCallee());
40   llvm::Value *fnOp0 = llvm::ConstantInt::get(
41       llvm::Type::getInt32Ty(module->getContext()), parameter);
42   return builder.CreateCall(fn, ArrayRef<llvm::Value *>(fnOp0));
43 }
44 
45 namespace {
46 /// Implementation of the dialect interface that converts operations belonging
47 /// to the ROCDL dialect to LLVM IR.
48 class ROCDLDialectLLVMIRTranslationInterface
49     : public LLVMTranslationDialectInterface {
50 public:
51   using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
52 
53   /// Translates the given operation to LLVM IR using the provided IR builder
54   /// and saving the state in `moduleTranslation`.
55   LogicalResult
convertOperation(Operation * op,llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation) const56   convertOperation(Operation *op, llvm::IRBuilderBase &builder,
57                    LLVM::ModuleTranslation &moduleTranslation) const final {
58     Operation &opInst = *op;
59 #include "mlir/Dialect/LLVMIR/ROCDLConversions.inc"
60 
61     return failure();
62   }
63 
64   /// Attaches module-level metadata for functions marked as kernels.
65   LogicalResult
amendOperation(Operation * op,NamedAttribute attribute,LLVM::ModuleTranslation & moduleTranslation) const66   amendOperation(Operation *op, NamedAttribute attribute,
67                  LLVM::ModuleTranslation &moduleTranslation) const final {
68     if (attribute.getName() == ROCDL::ROCDLDialect::getKernelFuncAttrName()) {
69       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
70       if (!func)
71         return failure();
72 
73       // For GPU kernels,
74       // 1. Insert AMDGPU_KERNEL calling convention.
75       // 2. Insert amdgpu-flat-work-group-size(1, 256) attribute unless the user
76       // has overriden this value - 256 is the default in clang
77       // 3. Insert amdgpu-implicitarg-num-bytes=56 (which must be set on OpenCL
78       // and HIP kernels per Clang)
79       llvm::Function *llvmFunc =
80           moduleTranslation.lookupFunction(func.getName());
81       llvmFunc->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
82       if (!llvmFunc->hasFnAttribute("amdgpu-flat-work-group-size")) {
83         llvmFunc->addFnAttr("amdgpu-flat-work-group-size", "1, 256");
84       }
85       llvmFunc->addFnAttr("amdgpu-implicitarg-num-bytes", "56");
86     }
87     // Override flat-work-group-size
88     if ("rocdl.max_flat_work_group_size" == attribute.getName()) {
89       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
90       if (!func)
91         return failure();
92       auto value = attribute.getValue().dyn_cast<IntegerAttr>();
93       if (!value)
94         return failure();
95 
96       llvm::Function *llvmFunc =
97           moduleTranslation.lookupFunction(func.getName());
98       llvm::SmallString<8> llvmAttrValue;
99       llvm::raw_svector_ostream attrValueStream(llvmAttrValue);
100       attrValueStream << "1, " << value.getInt();
101       llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue);
102     }
103     return success();
104   }
105 };
106 } // namespace
107 
registerROCDLDialectTranslation(DialectRegistry & registry)108 void mlir::registerROCDLDialectTranslation(DialectRegistry &registry) {
109   registry.insert<ROCDL::ROCDLDialect>();
110   registry.addExtension(+[](MLIRContext *ctx, ROCDL::ROCDLDialect *dialect) {
111     dialect->addInterfaces<ROCDLDialectLLVMIRTranslationInterface>();
112   });
113 }
114 
registerROCDLDialectTranslation(MLIRContext & context)115 void mlir::registerROCDLDialectTranslation(MLIRContext &context) {
116   DialectRegistry registry;
117   registerROCDLDialectTranslation(registry);
118   context.appendDialectRegistry(registry);
119 }
120