1 //===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
9 #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
10 
11 #include "mlir/Conversion/LLVMCommon/Pattern.h"
12 #include "mlir/Dialect/GPU/GPUDialect.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/Builders.h"
16 
17 namespace mlir {
18 
19 /// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func`
20 /// depending on the element type that Op operates upon. The function
21 /// declaration is added in case it was not added before.
22 ///
23 /// If the input values are of f16 type, the value is first casted to f32, the
24 /// function called and then the result casted back.
25 ///
26 /// Example with NVVM:
27 ///   %exp_f32 = std.exp %arg_f32 : f32
28 ///
29 /// will be transformed into
30 ///   llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
31 template <typename SourceOp>
32 struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
33 public:
34   explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func,
35                                 StringRef f64Func)
36       : ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
37         f64Func(f64Func) {}
38 
39   LogicalResult
40   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
41                   ConversionPatternRewriter &rewriter) const override {
42     using LLVM::LLVMFuncOp;
43 
44     static_assert(
45         std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
46         "expected single result op");
47 
48     static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
49                                   SourceOp>::value,
50                   "expected op with same operand and result types");
51 
52     SmallVector<Value, 1> castedOperands;
53     for (Value operand : adaptor.getOperands())
54       castedOperands.push_back(maybeCast(operand, rewriter));
55 
56     Type resultType = castedOperands.front().getType();
57     Type funcType = getFunctionType(resultType, castedOperands);
58     StringRef funcName = getFunctionName(
59         funcType.cast<LLVM::LLVMFunctionType>().getReturnType());
60     if (funcName.empty())
61       return failure();
62 
63     LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
64     auto callOp = rewriter.create<LLVM::CallOp>(
65         op->getLoc(), resultType, SymbolRefAttr::get(funcOp), castedOperands);
66 
67     if (resultType == adaptor.getOperands().front().getType()) {
68       rewriter.replaceOp(op, {callOp.getResult(0)});
69       return success();
70     }
71 
72     Value truncated = rewriter.create<LLVM::FPTruncOp>(
73         op->getLoc(), adaptor.getOperands().front().getType(),
74         callOp.getResult(0));
75     rewriter.replaceOp(op, {truncated});
76     return success();
77   }
78 
79 private:
80   Value maybeCast(Value operand, PatternRewriter &rewriter) const {
81     Type type = operand.getType();
82     if (!type.isa<Float16Type>())
83       return operand;
84 
85     return rewriter.create<LLVM::FPExtOp>(
86         operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
87   }
88 
89   Type getFunctionType(Type resultType, ValueRange operands) const {
90     SmallVector<Type> operandTypes(operands.getTypes());
91     return LLVM::LLVMFunctionType::get(resultType, operandTypes);
92   }
93 
94   StringRef getFunctionName(Type type) const {
95     if (type.isa<Float32Type>())
96       return f32Func;
97     if (type.isa<Float64Type>())
98       return f64Func;
99     return "";
100   }
101 
102   LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType,
103                                      Operation *op) const {
104     using LLVM::LLVMFuncOp;
105 
106     auto funcAttr = StringAttr::get(op->getContext(), funcName);
107     Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
108     if (funcOp)
109       return cast<LLVMFuncOp>(*funcOp);
110 
111     mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>());
112     return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
113   }
114 
115   const std::string f32Func;
116   const std::string f64Func;
117 };
118 
119 } // namespace mlir
120 
121 #endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
122