//===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
834810e1bSTres Popp 
934810e1bSTres Popp #include "mlir/Conversion/MathToLibm/MathToLibm.h"
1034810e1bSTres Popp 
1134810e1bSTres Popp #include "../PassDetail.h"
12*a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1334810e1bSTres Popp #include "mlir/Dialect/Math/IR/Math.h"
1434810e1bSTres Popp #include "mlir/Dialect/StandardOps/IR/Ops.h"
1534810e1bSTres Popp #include "mlir/Dialect/Vector/VectorOps.h"
1634810e1bSTres Popp #include "mlir/IR/BuiltinDialect.h"
1734810e1bSTres Popp #include "mlir/IR/PatternMatch.h"
1834810e1bSTres Popp 
1934810e1bSTres Popp using namespace mlir;
2034810e1bSTres Popp 
2134810e1bSTres Popp namespace {
2234810e1bSTres Popp // Pattern to convert vector operations to scalar operations. This is needed as
2334810e1bSTres Popp // libm calls require scalars.
2434810e1bSTres Popp template <typename Op>
2534810e1bSTres Popp struct VecOpToScalarOp : public OpRewritePattern<Op> {
2634810e1bSTres Popp public:
2734810e1bSTres Popp   using OpRewritePattern<Op>::OpRewritePattern;
2834810e1bSTres Popp 
2934810e1bSTres Popp   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
3034810e1bSTres Popp };
3134810e1bSTres Popp // Pattern to convert scalar math operations to calls to libm functions.
3234810e1bSTres Popp // Additionally the libm function signatures are declared.
3334810e1bSTres Popp template <typename Op>
3434810e1bSTres Popp struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
3534810e1bSTres Popp public:
3634810e1bSTres Popp   using OpRewritePattern<Op>::OpRewritePattern;
3734810e1bSTres Popp   ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc,
3834810e1bSTres Popp                          StringRef doubleFunc, PatternBenefit benefit)
3934810e1bSTres Popp       : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
4034810e1bSTres Popp         doubleFunc(doubleFunc){};
4134810e1bSTres Popp 
4234810e1bSTres Popp   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
4334810e1bSTres Popp 
4434810e1bSTres Popp private:
4534810e1bSTres Popp   std::string floatFunc, doubleFunc;
4634810e1bSTres Popp };
4734810e1bSTres Popp } // namespace
4834810e1bSTres Popp 
4934810e1bSTres Popp template <typename Op>
5034810e1bSTres Popp LogicalResult
5134810e1bSTres Popp VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
5234810e1bSTres Popp   auto opType = op.getType();
5334810e1bSTres Popp   auto loc = op.getLoc();
5434810e1bSTres Popp   auto vecType = opType.template dyn_cast<VectorType>();
5534810e1bSTres Popp 
5634810e1bSTres Popp   if (!vecType)
5734810e1bSTres Popp     return failure();
5834810e1bSTres Popp   if (!vecType.hasRank())
5934810e1bSTres Popp     return failure();
6034810e1bSTres Popp   auto shape = vecType.getShape();
6134810e1bSTres Popp   // TODO: support multidimensional vectors
6234810e1bSTres Popp   if (shape.size() != 1)
6334810e1bSTres Popp     return failure();
6434810e1bSTres Popp 
65*a54f4eaeSMogball   Value result = rewriter.create<arith::ConstantOp>(
6634810e1bSTres Popp       loc, DenseElementsAttr::get(
6734810e1bSTres Popp                vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
6834810e1bSTres Popp   for (auto i = 0; i < shape.front(); ++i) {
6934810e1bSTres Popp     SmallVector<Value> operands;
7034810e1bSTres Popp     for (auto input : op->getOperands())
7134810e1bSTres Popp       operands.push_back(
7234810e1bSTres Popp           rewriter.create<vector::ExtractElementOp>(loc, input, i));
7334810e1bSTres Popp     Value scalarOp =
7434810e1bSTres Popp         rewriter.create<Op>(loc, vecType.getElementType(), operands);
7534810e1bSTres Popp     result = rewriter.create<vector::InsertElementOp>(loc, scalarOp, result, i);
7634810e1bSTres Popp   }
7734810e1bSTres Popp   rewriter.replaceOp(op, {result});
7834810e1bSTres Popp   return success();
7934810e1bSTres Popp }
8034810e1bSTres Popp 
8134810e1bSTres Popp template <typename Op>
8234810e1bSTres Popp LogicalResult
8334810e1bSTres Popp ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
8434810e1bSTres Popp                                         PatternRewriter &rewriter) const {
851ebf7ce9STres Popp   auto module = SymbolTable::getNearestSymbolTable(op);
8634810e1bSTres Popp   auto type = op.getType();
8734810e1bSTres Popp   // TODO: Support Float16 by upcasting to Float32
8834810e1bSTres Popp   if (!type.template isa<Float32Type, Float64Type>())
8934810e1bSTres Popp     return failure();
9034810e1bSTres Popp 
9134810e1bSTres Popp   auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
921ebf7ce9STres Popp   auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
931ebf7ce9STres Popp       SymbolTable::lookupSymbolIn(module, name));
9434810e1bSTres Popp   // Forward declare function if it hasn't already been
9534810e1bSTres Popp   if (!opFunc) {
9634810e1bSTres Popp     OpBuilder::InsertionGuard guard(rewriter);
971ebf7ce9STres Popp     rewriter.setInsertionPointToStart(&module->getRegion(0).front());
9834810e1bSTres Popp     auto opFunctionTy = FunctionType::get(
9934810e1bSTres Popp         rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
10034810e1bSTres Popp     opFunc =
10134810e1bSTres Popp         rewriter.create<FuncOp>(rewriter.getUnknownLoc(), name, opFunctionTy);
10234810e1bSTres Popp     opFunc.setPrivate();
10334810e1bSTres Popp   }
1041ebf7ce9STres Popp   assert(SymbolTable::lookupSymbolIn(module, name)
1051ebf7ce9STres Popp              ->template hasTrait<mlir::OpTrait::FunctionLike>());
10634810e1bSTres Popp 
1071ebf7ce9STres Popp   rewriter.replaceOpWithNewOp<CallOp>(op, name, op.getType(),
1081ebf7ce9STres Popp                                       op->getOperands());
10934810e1bSTres Popp 
11034810e1bSTres Popp   return success();
11134810e1bSTres Popp }
11234810e1bSTres Popp 
11334810e1bSTres Popp void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
11434810e1bSTres Popp                                                 PatternBenefit benefit) {
11534810e1bSTres Popp   patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
11634810e1bSTres Popp                VecOpToScalarOp<math::TanhOp>>(patterns.getContext(), benefit);
11734810e1bSTres Popp   patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
11834810e1bSTres Popp                                                   "atan2f", "atan2", benefit);
11934810e1bSTres Popp   patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(),
12034810e1bSTres Popp                                                   "expm1f", "expm1", benefit);
12134810e1bSTres Popp   patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf",
12234810e1bSTres Popp                                                  "tanh", benefit);
12334810e1bSTres Popp }
12434810e1bSTres Popp 
12534810e1bSTres Popp namespace {
12634810e1bSTres Popp struct ConvertMathToLibmPass
12734810e1bSTres Popp     : public ConvertMathToLibmBase<ConvertMathToLibmPass> {
12834810e1bSTres Popp   void runOnOperation() override;
12934810e1bSTres Popp };
13034810e1bSTres Popp } // namespace
13134810e1bSTres Popp 
13234810e1bSTres Popp void ConvertMathToLibmPass::runOnOperation() {
13334810e1bSTres Popp   auto module = getOperation();
13434810e1bSTres Popp 
13534810e1bSTres Popp   RewritePatternSet patterns(&getContext());
13634810e1bSTres Popp   populateMathToLibmConversionPatterns(patterns, /*benefit=*/1);
13734810e1bSTres Popp 
13834810e1bSTres Popp   ConversionTarget target(getContext());
139*a54f4eaeSMogball   target.addLegalDialect<arith::ArithmeticDialect, BuiltinDialect,
140*a54f4eaeSMogball                          StandardOpsDialect, vector::VectorDialect>();
14134810e1bSTres Popp   target.addIllegalDialect<math::MathDialect>();
14234810e1bSTres Popp   if (failed(applyPartialConversion(module, target, std::move(patterns))))
14334810e1bSTres Popp     signalPassFailure();
14434810e1bSTres Popp }
14534810e1bSTres Popp 
14634810e1bSTres Popp std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() {
14734810e1bSTres Popp   return std::make_unique<ConvertMathToLibmPass>();
14834810e1bSTres Popp }
149