1 //===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/Math/IR/Math.h"
13 #include "mlir/Dialect/StandardOps/IR/Ops.h"
14 #include "mlir/Dialect/Vector/VectorOps.h"
15 #include "mlir/IR/BuiltinDialect.h"
16 #include "mlir/IR/PatternMatch.h"
17 
18 using namespace mlir;
19 
20 namespace {
21 // Pattern to convert vector operations to scalar operations. This is needed as
22 // libm calls require scalars.
23 template <typename Op>
24 struct VecOpToScalarOp : public OpRewritePattern<Op> {
25 public:
26   using OpRewritePattern<Op>::OpRewritePattern;
27 
28   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
29 };
30 // Pattern to convert scalar math operations to calls to libm functions.
31 // Additionally the libm function signatures are declared.
32 template <typename Op>
33 struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
34 public:
35   using OpRewritePattern<Op>::OpRewritePattern;
36   ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc,
37                          StringRef doubleFunc, PatternBenefit benefit)
38       : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
39         doubleFunc(doubleFunc){};
40 
41   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
42 
43 private:
44   std::string floatFunc, doubleFunc;
45 };
46 } // namespace
47 
48 template <typename Op>
49 LogicalResult
50 VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
51   auto opType = op.getType();
52   auto loc = op.getLoc();
53   auto vecType = opType.template dyn_cast<VectorType>();
54 
55   if (!vecType)
56     return failure();
57   if (!vecType.hasRank())
58     return failure();
59   auto shape = vecType.getShape();
60   // TODO: support multidimensional vectors
61   if (shape.size() != 1)
62     return failure();
63 
64   Value result = rewriter.create<ConstantOp>(
65       loc, DenseElementsAttr::get(
66                vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
67   for (auto i = 0; i < shape.front(); ++i) {
68     SmallVector<Value> operands;
69     for (auto input : op->getOperands())
70       operands.push_back(
71           rewriter.create<vector::ExtractElementOp>(loc, input, i));
72     Value scalarOp =
73         rewriter.create<Op>(loc, vecType.getElementType(), operands);
74     result = rewriter.create<vector::InsertElementOp>(loc, scalarOp, result, i);
75   }
76   rewriter.replaceOp(op, {result});
77   return success();
78 }
79 
80 template <typename Op>
81 LogicalResult
82 ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
83                                         PatternRewriter &rewriter) const {
84   auto module = op->template getParentOfType<ModuleOp>();
85   auto type = op.getType();
86   // TODO: Support Float16 by upcasting to Float32
87   if (!type.template isa<Float32Type, Float64Type>())
88     return failure();
89 
90   auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
91   auto opFunc = module.template lookupSymbol<FuncOp>(name);
92   // Forward declare function if it hasn't already been
93   if (!opFunc) {
94     OpBuilder::InsertionGuard guard(rewriter);
95     rewriter.setInsertionPointToStart(module.getBody());
96     auto opFunctionTy = FunctionType::get(
97         rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
98     opFunc =
99         rewriter.create<FuncOp>(rewriter.getUnknownLoc(), name, opFunctionTy);
100     opFunc.setPrivate();
101   }
102   assert(opFunc.getType().template cast<FunctionType>().getResults() ==
103          op->getResultTypes());
104   assert(opFunc.getType().template cast<FunctionType>().getInputs() ==
105          op->getOperandTypes());
106 
107   rewriter.replaceOpWithNewOp<CallOp>(op, opFunc, op->getOperands());
108 
109   return success();
110 }
111 
112 void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
113                                                 PatternBenefit benefit) {
114   patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
115                VecOpToScalarOp<math::TanhOp>>(patterns.getContext(), benefit);
116   patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
117                                                   "atan2f", "atan2", benefit);
118   patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(),
119                                                   "expm1f", "expm1", benefit);
120   patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf",
121                                                  "tanh", benefit);
122 }
123 
124 namespace {
125 struct ConvertMathToLibmPass
126     : public ConvertMathToLibmBase<ConvertMathToLibmPass> {
127   void runOnOperation() override;
128 };
129 } // namespace
130 
131 void ConvertMathToLibmPass::runOnOperation() {
132   auto module = getOperation();
133 
134   RewritePatternSet patterns(&getContext());
135   populateMathToLibmConversionPatterns(patterns, /*benefit=*/1);
136 
137   ConversionTarget target(getContext());
138   target.addLegalDialect<BuiltinDialect, StandardOpsDialect,
139                          vector::VectorDialect>();
140   target.addIllegalDialect<math::MathDialect>();
141   if (failed(applyPartialConversion(module, target, std::move(patterns))))
142     signalPassFailure();
143 }
144 
145 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() {
146   return std::make_unique<ConvertMathToLibmPass>();
147 }
148