1 //===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
9 #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
10 
11 #include "mlir/Conversion/LLVMCommon/Pattern.h"
12 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/IR/Builders.h"
15 
16 namespace mlir {
17 
/// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func`
/// depending on the element type that Op operates upon. The function
/// declaration is added in case it was not added before.
21 ///
/// If the input values are of f16 type, the value is first cast to f32, the
/// function is called, and then the result is cast back.
24 ///
25 /// Example with NVVM:
26 ///   %exp_f32 = math.exp %arg_f32 : f32
27 ///
28 /// will be transformed into
29 ///   llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
30 template <typename SourceOp>
31 struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
32 public:
OpToFuncCallLoweringOpToFuncCallLowering33   explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func,
34                                 StringRef f64Func)
35       : ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
36         f64Func(f64Func) {}
37 
38   LogicalResult
matchAndRewriteOpToFuncCallLowering39   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
40                   ConversionPatternRewriter &rewriter) const override {
41     using LLVM::LLVMFuncOp;
42 
43     static_assert(
44         std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
45         "expected single result op");
46 
47     static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
48                                   SourceOp>::value,
49                   "expected op with same operand and result types");
50 
51     SmallVector<Value, 1> castedOperands;
52     for (Value operand : adaptor.getOperands())
53       castedOperands.push_back(maybeCast(operand, rewriter));
54 
55     Type resultType = castedOperands.front().getType();
56     Type funcType = getFunctionType(resultType, castedOperands);
57     StringRef funcName = getFunctionName(
58         funcType.cast<LLVM::LLVMFunctionType>().getReturnType());
59     if (funcName.empty())
60       return failure();
61 
62     LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
63     auto callOp = rewriter.create<LLVM::CallOp>(
64         op->getLoc(), resultType, SymbolRefAttr::get(funcOp), castedOperands);
65 
66     if (resultType == adaptor.getOperands().front().getType()) {
67       rewriter.replaceOp(op, {callOp.getResult(0)});
68       return success();
69     }
70 
71     Value truncated = rewriter.create<LLVM::FPTruncOp>(
72         op->getLoc(), adaptor.getOperands().front().getType(),
73         callOp.getResult(0));
74     rewriter.replaceOp(op, {truncated});
75     return success();
76   }
77 
78 private:
maybeCastOpToFuncCallLowering79   Value maybeCast(Value operand, PatternRewriter &rewriter) const {
80     Type type = operand.getType();
81     if (!type.isa<Float16Type>())
82       return operand;
83 
84     return rewriter.create<LLVM::FPExtOp>(
85         operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
86   }
87 
getFunctionTypeOpToFuncCallLowering88   Type getFunctionType(Type resultType, ValueRange operands) const {
89     SmallVector<Type> operandTypes(operands.getTypes());
90     return LLVM::LLVMFunctionType::get(resultType, operandTypes);
91   }
92 
getFunctionNameOpToFuncCallLowering93   StringRef getFunctionName(Type type) const {
94     if (type.isa<Float32Type>())
95       return f32Func;
96     if (type.isa<Float64Type>())
97       return f64Func;
98     return "";
99   }
100 
appendOrGetFuncOpOpToFuncCallLowering101   LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType,
102                                      Operation *op) const {
103     using LLVM::LLVMFuncOp;
104 
105     auto funcAttr = StringAttr::get(op->getContext(), funcName);
106     Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
107     if (funcOp)
108       return cast<LLVMFuncOp>(*funcOp);
109 
110     mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>());
111     return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
112   }
113 
114   const std::string f32Func;
115   const std::string f64Func;
116 };
117 
118 } // namespace mlir
119 
120 #endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
121