1 //===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/Math/IR/Math.h"
15 #include "mlir/Dialect/Utils/IndexingUtils.h"
16 #include "mlir/Dialect/Vector/IR/VectorOps.h"
17 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
18 #include "mlir/IR/BuiltinDialect.h"
19 #include "mlir/IR/PatternMatch.h"
20 
21 using namespace mlir;
22 
23 namespace {
24 // Pattern to convert vector operations to scalar operations. This is needed as
25 // libm calls require scalars.
26 template <typename Op>
27 struct VecOpToScalarOp : public OpRewritePattern<Op> {
28 public:
29   using OpRewritePattern<Op>::OpRewritePattern;
30 
31   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
32 };
33 // Pattern to promote an op of a smaller floating point type to F32.
34 template <typename Op>
35 struct PromoteOpToF32 : public OpRewritePattern<Op> {
36 public:
37   using OpRewritePattern<Op>::OpRewritePattern;
38 
39   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
40 };
41 // Pattern to convert scalar math operations to calls to libm functions.
42 // Additionally the libm function signatures are declared.
43 template <typename Op>
44 struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
45 public:
46   using OpRewritePattern<Op>::OpRewritePattern;
47   ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc,
48                          StringRef doubleFunc, PatternBenefit benefit)
49       : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
50         doubleFunc(doubleFunc){};
51 
52   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
53 
54 private:
55   std::string floatFunc, doubleFunc;
56 };
57 } // namespace
58 
59 template <typename Op>
60 LogicalResult
61 VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
62   auto opType = op.getType();
63   auto loc = op.getLoc();
64   auto vecType = opType.template dyn_cast<VectorType>();
65 
66   if (!vecType)
67     return failure();
68   if (!vecType.hasRank())
69     return failure();
70   auto shape = vecType.getShape();
71   int64_t numElements = vecType.getNumElements();
72 
73   Value result = rewriter.create<arith::ConstantOp>(
74       loc, DenseElementsAttr::get(
75                vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
76   SmallVector<int64_t> ones(shape.size(), 1);
77   SmallVector<int64_t> strides = computeStrides(shape, ones);
78   for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
79     SmallVector<int64_t> positions = delinearize(strides, linearIndex);
80     SmallVector<Value> operands;
81     for (auto input : op->getOperands())
82       operands.push_back(
83           rewriter.create<vector::ExtractOp>(loc, input, positions));
84     Value scalarOp =
85         rewriter.create<Op>(loc, vecType.getElementType(), operands);
86     result =
87         rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
88   }
89   rewriter.replaceOp(op, {result});
90   return success();
91 }
92 
93 template <typename Op>
94 LogicalResult
95 PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
96   auto opType = op.getType();
97   if (!opType.template isa<Float16Type, BFloat16Type>())
98     return failure();
99 
100   auto loc = op.getLoc();
101   auto f32 = rewriter.getF32Type();
102   auto extendedOperands = llvm::to_vector(
103       llvm::map_range(op->getOperands(), [&](Value operand) -> Value {
104         return rewriter.create<arith::ExtFOp>(loc, f32, operand);
105       }));
106   auto newOp = rewriter.create<Op>(loc, f32, extendedOperands);
107   rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, opType, newOp);
108   return success();
109 }
110 
111 template <typename Op>
112 LogicalResult
113 ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
114                                         PatternRewriter &rewriter) const {
115   auto module = SymbolTable::getNearestSymbolTable(op);
116   auto type = op.getType();
117   if (!type.template isa<Float32Type, Float64Type>())
118     return failure();
119 
120   auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
121   auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
122       SymbolTable::lookupSymbolIn(module, name));
123   // Forward declare function if it hasn't already been
124   if (!opFunc) {
125     OpBuilder::InsertionGuard guard(rewriter);
126     rewriter.setInsertionPointToStart(&module->getRegion(0).front());
127     auto opFunctionTy = FunctionType::get(
128         rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
129     opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name,
130                                            opFunctionTy);
131     opFunc.setPrivate();
132   }
133   assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
134 
135   rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
136                                             op->getOperands());
137 
138   return success();
139 }
140 
141 void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
142                                                 PatternBenefit benefit) {
143   patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
144                VecOpToScalarOp<math::TanhOp>, VecOpToScalarOp<math::CosOp>,
145                VecOpToScalarOp<math::SinOp>, VecOpToScalarOp<math::ErfOp>,
146                VecOpToScalarOp<math::RoundOp>, VecOpToScalarOp<math::AtanOp>,
147                VecOpToScalarOp<math::TanOp>>(patterns.getContext(), benefit);
148   patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::ExpM1Op>,
149                PromoteOpToF32<math::TanhOp>, PromoteOpToF32<math::CosOp>,
150                PromoteOpToF32<math::SinOp>, PromoteOpToF32<math::ErfOp>,
151                PromoteOpToF32<math::RoundOp>, PromoteOpToF32<math::AtanOp>,
152                PromoteOpToF32<math::TanOp>>(patterns.getContext(), benefit);
153   patterns.add<ScalarOpToLibmCall<math::AtanOp>>(patterns.getContext(), "atanf",
154                                                  "atan", benefit);
155   patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
156                                                   "atan2f", "atan2", benefit);
157   patterns.add<ScalarOpToLibmCall<math::ErfOp>>(patterns.getContext(), "erff",
158                                                 "erf", benefit);
159   patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(),
160                                                   "expm1f", "expm1", benefit);
161   patterns.add<ScalarOpToLibmCall<math::TanOp>>(patterns.getContext(), "tanf",
162                                                 "tan", benefit);
163   patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf",
164                                                  "tanh", benefit);
165   patterns.add<ScalarOpToLibmCall<math::RoundOp>>(patterns.getContext(),
166                                                   "roundf", "round", benefit);
167   patterns.add<ScalarOpToLibmCall<math::CosOp>>(patterns.getContext(), "cosf",
168                                                 "cos", benefit);
169   patterns.add<ScalarOpToLibmCall<math::SinOp>>(patterns.getContext(), "sinf",
170                                                 "sin", benefit);
171 }
172 
173 namespace {
174 struct ConvertMathToLibmPass
175     : public ConvertMathToLibmBase<ConvertMathToLibmPass> {
176   void runOnOperation() override;
177 };
178 } // namespace
179 
180 void ConvertMathToLibmPass::runOnOperation() {
181   auto module = getOperation();
182 
183   RewritePatternSet patterns(&getContext());
184   populateMathToLibmConversionPatterns(patterns, /*benefit=*/1);
185 
186   ConversionTarget target(getContext());
187   target.addLegalDialect<arith::ArithmeticDialect, BuiltinDialect,
188                          func::FuncDialect, vector::VectorDialect>();
189   target.addIllegalDialect<math::MathDialect>();
190   if (failed(applyPartialConversion(module, target, std::move(patterns))))
191     signalPassFailure();
192 }
193 
194 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() {
195   return std::make_unique<ConvertMathToLibmPass>();
196 }
197