1 //===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/Math/IR/Math.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Dialect/Vector/VectorOps.h"
16 #include "mlir/IR/BuiltinDialect.h"
17 #include "mlir/IR/PatternMatch.h"
18 
19 using namespace mlir;
20 
21 namespace {
22 // Pattern to convert vector operations to scalar operations. This is needed as
23 // libm calls require scalars.
24 template <typename Op>
25 struct VecOpToScalarOp : public OpRewritePattern<Op> {
26 public:
27   using OpRewritePattern<Op>::OpRewritePattern;
28 
29   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
30 };
31 // Pattern to convert scalar math operations to calls to libm functions.
32 // Additionally the libm function signatures are declared.
33 template <typename Op>
34 struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
35 public:
36   using OpRewritePattern<Op>::OpRewritePattern;
37   ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc,
38                          StringRef doubleFunc, PatternBenefit benefit)
39       : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
40         doubleFunc(doubleFunc){};
41 
42   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
43 
44 private:
45   std::string floatFunc, doubleFunc;
46 };
47 } // namespace
48 
49 template <typename Op>
50 LogicalResult
51 VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
52   auto opType = op.getType();
53   auto loc = op.getLoc();
54   auto vecType = opType.template dyn_cast<VectorType>();
55 
56   if (!vecType)
57     return failure();
58   if (!vecType.hasRank())
59     return failure();
60   auto shape = vecType.getShape();
61   // TODO: support multidimensional vectors
62   if (shape.size() != 1)
63     return failure();
64 
65   Value result = rewriter.create<arith::ConstantOp>(
66       loc, DenseElementsAttr::get(
67                vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
68   for (auto i = 0; i < shape.front(); ++i) {
69     SmallVector<Value> operands;
70     for (auto input : op->getOperands())
71       operands.push_back(
72           rewriter.create<vector::ExtractElementOp>(loc, input, i));
73     Value scalarOp =
74         rewriter.create<Op>(loc, vecType.getElementType(), operands);
75     result = rewriter.create<vector::InsertElementOp>(loc, scalarOp, result, i);
76   }
77   rewriter.replaceOp(op, {result});
78   return success();
79 }
80 
81 template <typename Op>
82 LogicalResult
83 ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
84                                         PatternRewriter &rewriter) const {
85   auto module = SymbolTable::getNearestSymbolTable(op);
86   auto type = op.getType();
87   // TODO: Support Float16 by upcasting to Float32
88   if (!type.template isa<Float32Type, Float64Type>())
89     return failure();
90 
91   auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
92   auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
93       SymbolTable::lookupSymbolIn(module, name));
94   // Forward declare function if it hasn't already been
95   if (!opFunc) {
96     OpBuilder::InsertionGuard guard(rewriter);
97     rewriter.setInsertionPointToStart(&module->getRegion(0).front());
98     auto opFunctionTy = FunctionType::get(
99         rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
100     opFunc =
101         rewriter.create<FuncOp>(rewriter.getUnknownLoc(), name, opFunctionTy);
102     opFunc.setPrivate();
103   }
104   assert(SymbolTable::lookupSymbolIn(module, name)
105              ->template hasTrait<mlir::OpTrait::FunctionLike>());
106 
107   rewriter.replaceOpWithNewOp<CallOp>(op, name, op.getType(),
108                                       op->getOperands());
109 
110   return success();
111 }
112 
113 void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
114                                                 PatternBenefit benefit) {
115   patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
116                VecOpToScalarOp<math::TanhOp>>(patterns.getContext(), benefit);
117   patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
118                                                   "atan2f", "atan2", benefit);
119   patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(),
120                                                   "expm1f", "expm1", benefit);
121   patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf",
122                                                  "tanh", benefit);
123 }
124 
125 namespace {
126 struct ConvertMathToLibmPass
127     : public ConvertMathToLibmBase<ConvertMathToLibmPass> {
128   void runOnOperation() override;
129 };
130 } // namespace
131 
132 void ConvertMathToLibmPass::runOnOperation() {
133   auto module = getOperation();
134 
135   RewritePatternSet patterns(&getContext());
136   populateMathToLibmConversionPatterns(patterns, /*benefit=*/1);
137 
138   ConversionTarget target(getContext());
139   target.addLegalDialect<arith::ArithmeticDialect, BuiltinDialect,
140                          StandardOpsDialect, vector::VectorDialect>();
141   target.addIllegalDialect<math::MathDialect>();
142   if (failed(applyPartialConversion(module, target, std::move(patterns))))
143     signalPassFailure();
144 }
145 
146 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() {
147   return std::make_unique<ConvertMathToLibmPass>();
148 }
149