1 //===- ROCDLToLLVMIRTranslation.cpp - Translate ROCDL to LLVM IR ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a translation between the MLIR ROCDL dialect and 10 // LLVM IR. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" 15 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" 16 #include "mlir/IR/Operation.h" 17 #include "mlir/Target/LLVMIR/ModuleTranslation.h" 18 19 #include "llvm/IR/IRBuilder.h" 20 #include "llvm/IR/IntrinsicsAMDGPU.h" 21 #include "llvm/Support/raw_ostream.h" 22 23 using namespace mlir; 24 using namespace mlir::LLVM; 25 using mlir::LLVM::detail::createIntrinsicCall; 26 27 // Create a call to ROCm-Device-Library function 28 // Currently this routine will work only for calling ROCDL functions that 29 // take a single int32 argument. It is likely that the interface of this 30 // function will change to make it more generic. 31 static llvm::Value *createDeviceFunctionCall(llvm::IRBuilderBase &builder, 32 StringRef fnName, int parameter) { 33 llvm::Module *module = builder.GetInsertBlock()->getModule(); 34 llvm::FunctionType *functionType = llvm::FunctionType::get( 35 llvm::Type::getInt64Ty(module->getContext()), // return type. 36 llvm::Type::getInt32Ty(module->getContext()), // parameter type. 37 false); // no variadic arguments. 38 llvm::Function *fn = dyn_cast<llvm::Function>( 39 module->getOrInsertFunction(fnName, functionType).getCallee()); 40 llvm::Value *fnOp0 = llvm::ConstantInt::get( 41 llvm::Type::getInt32Ty(module->getContext()), parameter); 42 return builder.CreateCall(fn, ArrayRef<llvm::Value *>(fnOp0)); 43 } 44 45 namespace { 46 /// Implementation of the dialect interface that converts operations belonging 47 /// to the ROCDL dialect to LLVM IR. 48 class ROCDLDialectLLVMIRTranslationInterface 49 : public LLVMTranslationDialectInterface { 50 public: 51 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; 52 53 /// Translates the given operation to LLVM IR using the provided IR builder 54 /// and saving the state in `moduleTranslation`. 55 LogicalResult 56 convertOperation(Operation *op, llvm::IRBuilderBase &builder, 57 LLVM::ModuleTranslation &moduleTranslation) const final { 58 Operation &opInst = *op; 59 #include "mlir/Dialect/LLVMIR/ROCDLConversions.inc" 60 61 return failure(); 62 } 63 64 /// Attaches module-level metadata for functions marked as kernels. 65 LogicalResult 66 amendOperation(Operation *op, NamedAttribute attribute, 67 LLVM::ModuleTranslation &moduleTranslation) const final { 68 if (attribute.getName() == ROCDL::ROCDLDialect::getKernelFuncAttrName()) { 69 auto func = dyn_cast<LLVM::LLVMFuncOp>(op); 70 if (!func) 71 return failure(); 72 73 // For GPU kernels, 74 // 1. Insert AMDGPU_KERNEL calling convention. 75 // 2. Insert amdgpu-flat-work-group-size(1, 256) attribute unless the user 76 // has overriden this value - 256 is the default in clang 77 // 3. Insert amdgpu-implicitarg-num-bytes=56 (which must be set on OpenCL 78 // and HIP kernels per Clang) 79 llvm::Function *llvmFunc = 80 moduleTranslation.lookupFunction(func.getName()); 81 llvmFunc->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL); 82 if (!llvmFunc->hasFnAttribute("amdgpu-flat-work-group-size")) { 83 llvmFunc->addFnAttr("amdgpu-flat-work-group-size", "1, 256"); 84 } 85 llvmFunc->addFnAttr("amdgpu-implicitarg-num-bytes", "56"); 86 } 87 // Override flat-work-group-size 88 if ("rocdl.max_flat_work_group_size" == attribute.getName()) { 89 auto func = dyn_cast<LLVM::LLVMFuncOp>(op); 90 if (!func) 91 return failure(); 92 auto value = attribute.getValue().dyn_cast<IntegerAttr>(); 93 if (!value) 94 return failure(); 95 96 llvm::Function *llvmFunc = 97 moduleTranslation.lookupFunction(func.getName()); 98 llvm::SmallString<8> llvmAttrValue; 99 llvm::raw_svector_ostream attrValueStream(llvmAttrValue); 100 attrValueStream << "1, " << value.getInt(); 101 llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue); 102 } 103 return success(); 104 } 105 }; 106 } // namespace 107 108 void mlir::registerROCDLDialectTranslation(DialectRegistry ®istry) { 109 registry.insert<ROCDL::ROCDLDialect>(); 110 registry.addDialectInterface<ROCDL::ROCDLDialect, 111 ROCDLDialectLLVMIRTranslationInterface>(); 112 } 113 114 void mlir::registerROCDLDialectTranslation(MLIRContext &context) { 115 DialectRegistry registry; 116 registerROCDLDialectTranslation(registry); 117 context.appendDialectRegistry(registry); 118 } 119