1 //===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/Math/IR/Math.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Dialect/Vector/VectorOps.h"
16 #include "mlir/Dialect/Vector/VectorUtils.h"
17 #include "mlir/IR/BuiltinDialect.h"
18 #include "mlir/IR/PatternMatch.h"
19 
20 using namespace mlir;
21 
22 namespace {
23 // Pattern to convert vector operations to scalar operations. This is needed as
24 // libm calls require scalars.
25 template <typename Op>
26 struct VecOpToScalarOp : public OpRewritePattern<Op> {
27 public:
28   using OpRewritePattern<Op>::OpRewritePattern;
29 
30   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
31 };
32 // Pattern to convert scalar math operations to calls to libm functions.
33 // Additionally the libm function signatures are declared.
34 template <typename Op>
35 struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
36 public:
37   using OpRewritePattern<Op>::OpRewritePattern;
38   ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc,
39                          StringRef doubleFunc, PatternBenefit benefit)
40       : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
41         doubleFunc(doubleFunc){};
42 
43   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
44 
45 private:
46   std::string floatFunc, doubleFunc;
47 };
48 } // namespace
49 
50 template <typename Op>
51 LogicalResult
52 VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
53   auto opType = op.getType();
54   auto loc = op.getLoc();
55   auto vecType = opType.template dyn_cast<VectorType>();
56 
57   if (!vecType)
58     return failure();
59   if (!vecType.hasRank())
60     return failure();
61   auto shape = vecType.getShape();
62   int64_t numElements = vecType.getNumElements();
63 
64   Value result = rewriter.create<arith::ConstantOp>(
65       loc, DenseElementsAttr::get(
66                vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
67   SmallVector<int64_t> ones(shape.size(), 1);
68   SmallVector<int64_t> strides = computeStrides(shape, ones);
69   for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
70     SmallVector<int64_t> positions = delinearize(strides, linearIndex);
71     SmallVector<Value> operands;
72     for (auto input : op->getOperands())
73       operands.push_back(
74           rewriter.create<vector::ExtractOp>(loc, input, positions));
75     Value scalarOp =
76         rewriter.create<Op>(loc, vecType.getElementType(), operands);
77     result =
78         rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
79   }
80   rewriter.replaceOp(op, {result});
81   return success();
82 }
83 
84 template <typename Op>
85 LogicalResult
86 ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
87                                         PatternRewriter &rewriter) const {
88   auto module = SymbolTable::getNearestSymbolTable(op);
89   auto type = op.getType();
90   // TODO: Support Float16 by upcasting to Float32
91   if (!type.template isa<Float32Type, Float64Type>())
92     return failure();
93 
94   auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
95   auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
96       SymbolTable::lookupSymbolIn(module, name));
97   // Forward declare function if it hasn't already been
98   if (!opFunc) {
99     OpBuilder::InsertionGuard guard(rewriter);
100     rewriter.setInsertionPointToStart(&module->getRegion(0).front());
101     auto opFunctionTy = FunctionType::get(
102         rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
103     opFunc =
104         rewriter.create<FuncOp>(rewriter.getUnknownLoc(), name, opFunctionTy);
105     opFunc.setPrivate();
106   }
107   assert(SymbolTable::lookupSymbolIn(module, name)
108              ->template hasTrait<mlir::OpTrait::FunctionLike>());
109 
110   rewriter.replaceOpWithNewOp<CallOp>(op, name, op.getType(),
111                                       op->getOperands());
112 
113   return success();
114 }
115 
116 void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
117                                                 PatternBenefit benefit) {
118   patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
119                VecOpToScalarOp<math::TanhOp>>(patterns.getContext(), benefit);
120   patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
121                                                   "atan2f", "atan2", benefit);
122   patterns.add<ScalarOpToLibmCall<math::ErfOp>>(patterns.getContext(), "erff",
123                                                 "erf", benefit);
124   patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(),
125                                                   "expm1f", "expm1", benefit);
126   patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf",
127                                                  "tanh", benefit);
128 }
129 
namespace {
// Pass that converts math dialect operations to libm calls by applying the
// patterns from populateMathToLibmConversionPatterns as a partial conversion.
// NOTE(review): ConvertMathToLibmBase is presumably the generated pass base
// class pulled in through PassDetail.h -- confirm.
struct ConvertMathToLibmPass
    : public ConvertMathToLibmBase<ConvertMathToLibmPass> {
  // Entry point invoked on the operation the pass is scheduled on.
  void runOnOperation() override;
};
} // namespace
136 
137 void ConvertMathToLibmPass::runOnOperation() {
138   auto module = getOperation();
139 
140   RewritePatternSet patterns(&getContext());
141   populateMathToLibmConversionPatterns(patterns, /*benefit=*/1);
142 
143   ConversionTarget target(getContext());
144   target.addLegalDialect<arith::ArithmeticDialect, BuiltinDialect,
145                          StandardOpsDialect, vector::VectorDialect>();
146   target.addIllegalDialect<math::MathDialect>();
147   if (failed(applyPartialConversion(module, target, std::move(patterns))))
148     signalPassFailure();
149 }
150 
// Creates a new ConvertMathToLibmPass instance, returned as a module-level
// operation pass.
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() {
  return std::make_unique<ConvertMathToLibmPass>();
}
154