134810e1bSTres Popp //===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===//
234810e1bSTres Popp //
334810e1bSTres Popp // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
434810e1bSTres Popp // See https://llvm.org/LICENSE.txt for license information.
534810e1bSTres Popp // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
634810e1bSTres Popp //
734810e1bSTres Popp //===----------------------------------------------------------------------===//
834810e1bSTres Popp 
934810e1bSTres Popp #include "mlir/Conversion/MathToLibm/MathToLibm.h"
1034810e1bSTres Popp 
1134810e1bSTres Popp #include "../PassDetail.h"
12a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1323aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
1434810e1bSTres Popp #include "mlir/Dialect/Math/IR/Math.h"
1599ef9eebSMatthias Springer #include "mlir/Dialect/Utils/IndexingUtils.h"
1699ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
1799ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
1834810e1bSTres Popp #include "mlir/IR/BuiltinDialect.h"
1934810e1bSTres Popp #include "mlir/IR/PatternMatch.h"
2034810e1bSTres Popp 
2134810e1bSTres Popp using namespace mlir;
2234810e1bSTres Popp 
2334810e1bSTres Popp namespace {
2434810e1bSTres Popp // Pattern to convert vector operations to scalar operations. This is needed as
2534810e1bSTres Popp // libm calls require scalars.
2634810e1bSTres Popp template <typename Op>
2734810e1bSTres Popp struct VecOpToScalarOp : public OpRewritePattern<Op> {
2834810e1bSTres Popp public:
2934810e1bSTres Popp   using OpRewritePattern<Op>::OpRewritePattern;
3034810e1bSTres Popp 
3134810e1bSTres Popp   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
3234810e1bSTres Popp };
33a48adc56SBenjamin Kramer // Pattern to promote an op of a smaller floating point type to F32.
34a48adc56SBenjamin Kramer template <typename Op>
35a48adc56SBenjamin Kramer struct PromoteOpToF32 : public OpRewritePattern<Op> {
36a48adc56SBenjamin Kramer public:
37a48adc56SBenjamin Kramer   using OpRewritePattern<Op>::OpRewritePattern;
38a48adc56SBenjamin Kramer 
39a48adc56SBenjamin Kramer   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
40a48adc56SBenjamin Kramer };
4134810e1bSTres Popp // Pattern to convert scalar math operations to calls to libm functions.
4234810e1bSTres Popp // Additionally the libm function signatures are declared.
4334810e1bSTres Popp template <typename Op>
4434810e1bSTres Popp struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
4534810e1bSTres Popp public:
4634810e1bSTres Popp   using OpRewritePattern<Op>::OpRewritePattern;
ScalarOpToLibmCall__anon05e0e9df0111::ScalarOpToLibmCall4734810e1bSTres Popp   ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc,
4834810e1bSTres Popp                          StringRef doubleFunc, PatternBenefit benefit)
4934810e1bSTres Popp       : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
5034810e1bSTres Popp         doubleFunc(doubleFunc){};
5134810e1bSTres Popp 
5234810e1bSTres Popp   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
5334810e1bSTres Popp 
5434810e1bSTres Popp private:
5534810e1bSTres Popp   std::string floatFunc, doubleFunc;
5634810e1bSTres Popp };
5734810e1bSTres Popp } // namespace
5834810e1bSTres Popp 
5934810e1bSTres Popp template <typename Op>
6034810e1bSTres Popp LogicalResult
matchAndRewrite(Op op,PatternRewriter & rewriter) const6134810e1bSTres Popp VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
6234810e1bSTres Popp   auto opType = op.getType();
6334810e1bSTres Popp   auto loc = op.getLoc();
6434810e1bSTres Popp   auto vecType = opType.template dyn_cast<VectorType>();
6534810e1bSTres Popp 
6634810e1bSTres Popp   if (!vecType)
6734810e1bSTres Popp     return failure();
6834810e1bSTres Popp   if (!vecType.hasRank())
6934810e1bSTres Popp     return failure();
7034810e1bSTres Popp   auto shape = vecType.getShape();
71921d91f3SAdrian Kuegel   int64_t numElements = vecType.getNumElements();
7234810e1bSTres Popp 
73a54f4eaeSMogball   Value result = rewriter.create<arith::ConstantOp>(
7434810e1bSTres Popp       loc, DenseElementsAttr::get(
7534810e1bSTres Popp                vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
76921d91f3SAdrian Kuegel   SmallVector<int64_t> ones(shape.size(), 1);
77921d91f3SAdrian Kuegel   SmallVector<int64_t> strides = computeStrides(shape, ones);
78921d91f3SAdrian Kuegel   for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
79921d91f3SAdrian Kuegel     SmallVector<int64_t> positions = delinearize(strides, linearIndex);
8034810e1bSTres Popp     SmallVector<Value> operands;
8134810e1bSTres Popp     for (auto input : op->getOperands())
8234810e1bSTres Popp       operands.push_back(
83921d91f3SAdrian Kuegel           rewriter.create<vector::ExtractOp>(loc, input, positions));
8434810e1bSTres Popp     Value scalarOp =
8534810e1bSTres Popp         rewriter.create<Op>(loc, vecType.getElementType(), operands);
86921d91f3SAdrian Kuegel     result =
87921d91f3SAdrian Kuegel         rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
8834810e1bSTres Popp   }
8934810e1bSTres Popp   rewriter.replaceOp(op, {result});
9034810e1bSTres Popp   return success();
9134810e1bSTres Popp }
9234810e1bSTres Popp 
9334810e1bSTres Popp template <typename Op>
9434810e1bSTres Popp LogicalResult
matchAndRewrite(Op op,PatternRewriter & rewriter) const95a48adc56SBenjamin Kramer PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
96a48adc56SBenjamin Kramer   auto opType = op.getType();
97a48adc56SBenjamin Kramer   if (!opType.template isa<Float16Type, BFloat16Type>())
98a48adc56SBenjamin Kramer     return failure();
99a48adc56SBenjamin Kramer 
100a48adc56SBenjamin Kramer   auto loc = op.getLoc();
101a48adc56SBenjamin Kramer   auto f32 = rewriter.getF32Type();
102a48adc56SBenjamin Kramer   auto extendedOperands = llvm::to_vector(
103a48adc56SBenjamin Kramer       llvm::map_range(op->getOperands(), [&](Value operand) -> Value {
104a48adc56SBenjamin Kramer         return rewriter.create<arith::ExtFOp>(loc, f32, operand);
105a48adc56SBenjamin Kramer       }));
106a48adc56SBenjamin Kramer   auto newOp = rewriter.create<Op>(loc, f32, extendedOperands);
107a48adc56SBenjamin Kramer   rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, opType, newOp);
108a48adc56SBenjamin Kramer   return success();
109a48adc56SBenjamin Kramer }
110a48adc56SBenjamin Kramer 
111a48adc56SBenjamin Kramer template <typename Op>
112a48adc56SBenjamin Kramer LogicalResult
matchAndRewrite(Op op,PatternRewriter & rewriter) const11334810e1bSTres Popp ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
11434810e1bSTres Popp                                         PatternRewriter &rewriter) const {
1151ebf7ce9STres Popp   auto module = SymbolTable::getNearestSymbolTable(op);
11634810e1bSTres Popp   auto type = op.getType();
11734810e1bSTres Popp   if (!type.template isa<Float32Type, Float64Type>())
11834810e1bSTres Popp     return failure();
11934810e1bSTres Popp 
12034810e1bSTres Popp   auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
1211ebf7ce9STres Popp   auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
1221ebf7ce9STres Popp       SymbolTable::lookupSymbolIn(module, name));
12334810e1bSTres Popp   // Forward declare function if it hasn't already been
12434810e1bSTres Popp   if (!opFunc) {
12534810e1bSTres Popp     OpBuilder::InsertionGuard guard(rewriter);
1261ebf7ce9STres Popp     rewriter.setInsertionPointToStart(&module->getRegion(0).front());
12734810e1bSTres Popp     auto opFunctionTy = FunctionType::get(
12834810e1bSTres Popp         rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
12958ceae95SRiver Riddle     opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name,
13058ceae95SRiver Riddle                                            opFunctionTy);
13134810e1bSTres Popp     opFunc.setPrivate();
13234810e1bSTres Popp   }
1337ceffae1SRiver Riddle   assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
13434810e1bSTres Popp 
13523aa5a74SRiver Riddle   rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
1361ebf7ce9STres Popp                                             op->getOperands());
13734810e1bSTres Popp 
13834810e1bSTres Popp   return success();
13934810e1bSTres Popp }
14034810e1bSTres Popp 
populateMathToLibmConversionPatterns(RewritePatternSet & patterns,PatternBenefit benefit)14134810e1bSTres Popp void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
14234810e1bSTres Popp                                                 PatternBenefit benefit) {
14334810e1bSTres Popp   patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
1440bae40efSlewuathe                VecOpToScalarOp<math::TanhOp>, VecOpToScalarOp<math::CosOp>,
145ce07b956Slewuathe                VecOpToScalarOp<math::SinOp>, VecOpToScalarOp<math::ErfOp>,
146*451e5e2bSSlava Zakharin                VecOpToScalarOp<math::RoundOp>, VecOpToScalarOp<math::AtanOp>,
147*451e5e2bSSlava Zakharin                VecOpToScalarOp<math::TanOp>>(patterns.getContext(), benefit);
148a48adc56SBenjamin Kramer   patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::ExpM1Op>,
1490bae40efSlewuathe                PromoteOpToF32<math::TanhOp>, PromoteOpToF32<math::CosOp>,
150ce07b956Slewuathe                PromoteOpToF32<math::SinOp>, PromoteOpToF32<math::ErfOp>,
151*451e5e2bSSlava Zakharin                PromoteOpToF32<math::RoundOp>, PromoteOpToF32<math::AtanOp>,
152*451e5e2bSSlava Zakharin                PromoteOpToF32<math::TanOp>>(patterns.getContext(), benefit);
153b163ac33SSlava Zakharin   patterns.add<ScalarOpToLibmCall<math::AtanOp>>(patterns.getContext(), "atanf",
154b163ac33SSlava Zakharin                                                  "atan", benefit);
15534810e1bSTres Popp   patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
15634810e1bSTres Popp                                                   "atan2f", "atan2", benefit);
157f1b92218SBoian Petkantchin   patterns.add<ScalarOpToLibmCall<math::ErfOp>>(patterns.getContext(), "erff",
158f1b92218SBoian Petkantchin                                                 "erf", benefit);
15934810e1bSTres Popp   patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(),
16034810e1bSTres Popp                                                   "expm1f", "expm1", benefit);
161*451e5e2bSSlava Zakharin   patterns.add<ScalarOpToLibmCall<math::TanOp>>(patterns.getContext(), "tanf",
162*451e5e2bSSlava Zakharin                                                 "tan", benefit);
16334810e1bSTres Popp   patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf",
16434810e1bSTres Popp                                                  "tanh", benefit);
165a0fc94abSlorenzo chelini   patterns.add<ScalarOpToLibmCall<math::RoundOp>>(patterns.getContext(),
166a0fc94abSlorenzo chelini                                                   "roundf", "round", benefit);
1670bae40efSlewuathe   patterns.add<ScalarOpToLibmCall<math::CosOp>>(patterns.getContext(), "cosf",
1680bae40efSlewuathe                                                 "cos", benefit);
1690bae40efSlewuathe   patterns.add<ScalarOpToLibmCall<math::SinOp>>(patterns.getContext(), "sinf",
1700bae40efSlewuathe                                                 "sin", benefit);
17134810e1bSTres Popp }
17234810e1bSTres Popp 
17334810e1bSTres Popp namespace {
17434810e1bSTres Popp struct ConvertMathToLibmPass
17534810e1bSTres Popp     : public ConvertMathToLibmBase<ConvertMathToLibmPass> {
17634810e1bSTres Popp   void runOnOperation() override;
17734810e1bSTres Popp };
17834810e1bSTres Popp } // namespace
17934810e1bSTres Popp 
runOnOperation()18034810e1bSTres Popp void ConvertMathToLibmPass::runOnOperation() {
18134810e1bSTres Popp   auto module = getOperation();
18234810e1bSTres Popp 
18334810e1bSTres Popp   RewritePatternSet patterns(&getContext());
18434810e1bSTres Popp   populateMathToLibmConversionPatterns(patterns, /*benefit=*/1);
18534810e1bSTres Popp 
18634810e1bSTres Popp   ConversionTarget target(getContext());
187a54f4eaeSMogball   target.addLegalDialect<arith::ArithmeticDialect, BuiltinDialect,
18823aa5a74SRiver Riddle                          func::FuncDialect, vector::VectorDialect>();
18934810e1bSTres Popp   target.addIllegalDialect<math::MathDialect>();
19034810e1bSTres Popp   if (failed(applyPartialConversion(module, target, std::move(patterns))))
19134810e1bSTres Popp     signalPassFailure();
19234810e1bSTres Popp }
19334810e1bSTres Popp 
createConvertMathToLibmPass()19434810e1bSTres Popp std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() {
19534810e1bSTres Popp   return std::make_unique<ConvertMathToLibmPass>();
19634810e1bSTres Popp }
197