1 //===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/Math/IR/Math.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Dialect/Utils/IndexingUtils.h"
16 #include "mlir/Dialect/Vector/IR/VectorOps.h"
17 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
18 #include "mlir/IR/BuiltinDialect.h"
19 #include "mlir/IR/PatternMatch.h"
20 
21 using namespace mlir;
22 
23 namespace {
24 // Pattern to convert vector operations to scalar operations. This is needed as
25 // libm calls require scalars.
26 template <typename Op>
27 struct VecOpToScalarOp : public OpRewritePattern<Op> {
28 public:
29   using OpRewritePattern<Op>::OpRewritePattern;
30 
31   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
32 };
33 // Pattern to convert scalar math operations to calls to libm functions.
34 // Additionally the libm function signatures are declared.
35 template <typename Op>
36 struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
37 public:
38   using OpRewritePattern<Op>::OpRewritePattern;
39   ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc,
40                          StringRef doubleFunc, PatternBenefit benefit)
41       : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
42         doubleFunc(doubleFunc){};
43 
44   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
45 
46 private:
47   std::string floatFunc, doubleFunc;
48 };
49 } // namespace
50 
51 template <typename Op>
52 LogicalResult
53 VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
54   auto opType = op.getType();
55   auto loc = op.getLoc();
56   auto vecType = opType.template dyn_cast<VectorType>();
57 
58   if (!vecType)
59     return failure();
60   if (!vecType.hasRank())
61     return failure();
62   auto shape = vecType.getShape();
63   int64_t numElements = vecType.getNumElements();
64 
65   Value result = rewriter.create<arith::ConstantOp>(
66       loc, DenseElementsAttr::get(
67                vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
68   SmallVector<int64_t> ones(shape.size(), 1);
69   SmallVector<int64_t> strides = computeStrides(shape, ones);
70   for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
71     SmallVector<int64_t> positions = delinearize(strides, linearIndex);
72     SmallVector<Value> operands;
73     for (auto input : op->getOperands())
74       operands.push_back(
75           rewriter.create<vector::ExtractOp>(loc, input, positions));
76     Value scalarOp =
77         rewriter.create<Op>(loc, vecType.getElementType(), operands);
78     result =
79         rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
80   }
81   rewriter.replaceOp(op, {result});
82   return success();
83 }
84 
85 template <typename Op>
86 LogicalResult
87 ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
88                                         PatternRewriter &rewriter) const {
89   auto module = SymbolTable::getNearestSymbolTable(op);
90   auto type = op.getType();
91   // TODO: Support Float16 by upcasting to Float32
92   if (!type.template isa<Float32Type, Float64Type>())
93     return failure();
94 
95   auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
96   auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
97       SymbolTable::lookupSymbolIn(module, name));
98   // Forward declare function if it hasn't already been
99   if (!opFunc) {
100     OpBuilder::InsertionGuard guard(rewriter);
101     rewriter.setInsertionPointToStart(&module->getRegion(0).front());
102     auto opFunctionTy = FunctionType::get(
103         rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
104     opFunc =
105         rewriter.create<FuncOp>(rewriter.getUnknownLoc(), name, opFunctionTy);
106     opFunc.setPrivate();
107   }
108   assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
109 
110   rewriter.replaceOpWithNewOp<CallOp>(op, name, op.getType(),
111                                       op->getOperands());
112 
113   return success();
114 }
115 
116 void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
117                                                 PatternBenefit benefit) {
118   patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
119                VecOpToScalarOp<math::TanhOp>>(patterns.getContext(), benefit);
120   patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
121                                                   "atan2f", "atan2", benefit);
122   patterns.add<ScalarOpToLibmCall<math::ErfOp>>(patterns.getContext(), "erff",
123                                                 "erf", benefit);
124   patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(),
125                                                   "expm1f", "expm1", benefit);
126   patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf",
127                                                  "tanh", benefit);
128 }
129 
130 namespace {
131 struct ConvertMathToLibmPass
132     : public ConvertMathToLibmBase<ConvertMathToLibmPass> {
133   void runOnOperation() override;
134 };
135 } // namespace
136 
137 void ConvertMathToLibmPass::runOnOperation() {
138   auto module = getOperation();
139 
140   RewritePatternSet patterns(&getContext());
141   populateMathToLibmConversionPatterns(patterns, /*benefit=*/1);
142 
143   ConversionTarget target(getContext());
144   target.addLegalDialect<arith::ArithmeticDialect, BuiltinDialect,
145                          StandardOpsDialect, vector::VectorDialect>();
146   target.addIllegalDialect<math::MathDialect>();
147   if (failed(applyPartialConversion(module, target, std::move(patterns))))
148     signalPassFailure();
149 }
150 
151 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() {
152   return std::make_unique<ConvertMathToLibmPass>();
153 }
154