134810e1bSTres Popp //===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===// 234810e1bSTres Popp // 334810e1bSTres Popp // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 434810e1bSTres Popp // See https://llvm.org/LICENSE.txt for license information. 534810e1bSTres Popp // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 634810e1bSTres Popp // 734810e1bSTres Popp //===----------------------------------------------------------------------===// 834810e1bSTres Popp 934810e1bSTres Popp #include "mlir/Conversion/MathToLibm/MathToLibm.h" 1034810e1bSTres Popp 1134810e1bSTres Popp #include "../PassDetail.h" 12*a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 1334810e1bSTres Popp #include "mlir/Dialect/Math/IR/Math.h" 1434810e1bSTres Popp #include "mlir/Dialect/StandardOps/IR/Ops.h" 1534810e1bSTres Popp #include "mlir/Dialect/Vector/VectorOps.h" 1634810e1bSTres Popp #include "mlir/IR/BuiltinDialect.h" 1734810e1bSTres Popp #include "mlir/IR/PatternMatch.h" 1834810e1bSTres Popp 1934810e1bSTres Popp using namespace mlir; 2034810e1bSTres Popp 2134810e1bSTres Popp namespace { 2234810e1bSTres Popp // Pattern to convert vector operations to scalar operations. This is needed as 2334810e1bSTres Popp // libm calls require scalars. 2434810e1bSTres Popp template <typename Op> 2534810e1bSTres Popp struct VecOpToScalarOp : public OpRewritePattern<Op> { 2634810e1bSTres Popp public: 2734810e1bSTres Popp using OpRewritePattern<Op>::OpRewritePattern; 2834810e1bSTres Popp 2934810e1bSTres Popp LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 3034810e1bSTres Popp }; 3134810e1bSTres Popp // Pattern to convert scalar math operations to calls to libm functions. 3234810e1bSTres Popp // Additionally the libm function signatures are declared. 3334810e1bSTres Popp template <typename Op> 3434810e1bSTres Popp struct ScalarOpToLibmCall : public OpRewritePattern<Op> { 3534810e1bSTres Popp public: 3634810e1bSTres Popp using OpRewritePattern<Op>::OpRewritePattern; 3734810e1bSTres Popp ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc, 3834810e1bSTres Popp StringRef doubleFunc, PatternBenefit benefit) 3934810e1bSTres Popp : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc), 4034810e1bSTres Popp doubleFunc(doubleFunc){}; 4134810e1bSTres Popp 4234810e1bSTres Popp LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 4334810e1bSTres Popp 4434810e1bSTres Popp private: 4534810e1bSTres Popp std::string floatFunc, doubleFunc; 4634810e1bSTres Popp }; 4734810e1bSTres Popp } // namespace 4834810e1bSTres Popp 4934810e1bSTres Popp template <typename Op> 5034810e1bSTres Popp LogicalResult 5134810e1bSTres Popp VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const { 5234810e1bSTres Popp auto opType = op.getType(); 5334810e1bSTres Popp auto loc = op.getLoc(); 5434810e1bSTres Popp auto vecType = opType.template dyn_cast<VectorType>(); 5534810e1bSTres Popp 5634810e1bSTres Popp if (!vecType) 5734810e1bSTres Popp return failure(); 5834810e1bSTres Popp if (!vecType.hasRank()) 5934810e1bSTres Popp return failure(); 6034810e1bSTres Popp auto shape = vecType.getShape(); 6134810e1bSTres Popp // TODO: support multidimensional vectors 6234810e1bSTres Popp if (shape.size() != 1) 6334810e1bSTres Popp return failure(); 6434810e1bSTres Popp 65*a54f4eaeSMogball Value result = rewriter.create<arith::ConstantOp>( 6634810e1bSTres Popp loc, DenseElementsAttr::get( 6734810e1bSTres Popp vecType, FloatAttr::get(vecType.getElementType(), 0.0))); 6834810e1bSTres Popp for (auto i = 0; i < shape.front(); ++i) { 6934810e1bSTres Popp SmallVector<Value> operands; 7034810e1bSTres Popp for (auto input : op->getOperands()) 7134810e1bSTres Popp operands.push_back( 7234810e1bSTres Popp rewriter.create<vector::ExtractElementOp>(loc, input, i)); 7334810e1bSTres Popp Value scalarOp = 7434810e1bSTres Popp rewriter.create<Op>(loc, vecType.getElementType(), operands); 7534810e1bSTres Popp result = rewriter.create<vector::InsertElementOp>(loc, scalarOp, result, i); 7634810e1bSTres Popp } 7734810e1bSTres Popp rewriter.replaceOp(op, {result}); 7834810e1bSTres Popp return success(); 7934810e1bSTres Popp } 8034810e1bSTres Popp 8134810e1bSTres Popp template <typename Op> 8234810e1bSTres Popp LogicalResult 8334810e1bSTres Popp ScalarOpToLibmCall<Op>::matchAndRewrite(Op op, 8434810e1bSTres Popp PatternRewriter &rewriter) const { 851ebf7ce9STres Popp auto module = SymbolTable::getNearestSymbolTable(op); 8634810e1bSTres Popp auto type = op.getType(); 8734810e1bSTres Popp // TODO: Support Float16 by upcasting to Float32 8834810e1bSTres Popp if (!type.template isa<Float32Type, Float64Type>()) 8934810e1bSTres Popp return failure(); 9034810e1bSTres Popp 9134810e1bSTres Popp auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc; 921ebf7ce9STres Popp auto opFunc = dyn_cast_or_null<SymbolOpInterface>( 931ebf7ce9STres Popp SymbolTable::lookupSymbolIn(module, name)); 9434810e1bSTres Popp // Forward declare function if it hasn't already been 9534810e1bSTres Popp if (!opFunc) { 9634810e1bSTres Popp OpBuilder::InsertionGuard guard(rewriter); 971ebf7ce9STres Popp rewriter.setInsertionPointToStart(&module->getRegion(0).front()); 9834810e1bSTres Popp auto opFunctionTy = FunctionType::get( 9934810e1bSTres Popp rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); 10034810e1bSTres Popp opFunc = 10134810e1bSTres Popp rewriter.create<FuncOp>(rewriter.getUnknownLoc(), name, opFunctionTy); 10234810e1bSTres Popp opFunc.setPrivate(); 10334810e1bSTres Popp } 1041ebf7ce9STres Popp assert(SymbolTable::lookupSymbolIn(module, name) 1051ebf7ce9STres Popp ->template hasTrait<mlir::OpTrait::FunctionLike>()); 10634810e1bSTres Popp 1071ebf7ce9STres Popp rewriter.replaceOpWithNewOp<CallOp>(op, name, op.getType(), 1081ebf7ce9STres Popp op->getOperands()); 10934810e1bSTres Popp 11034810e1bSTres Popp return success(); 11134810e1bSTres Popp } 11234810e1bSTres Popp 11334810e1bSTres Popp void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns, 11434810e1bSTres Popp PatternBenefit benefit) { 11534810e1bSTres Popp patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>, 11634810e1bSTres Popp VecOpToScalarOp<math::TanhOp>>(patterns.getContext(), benefit); 11734810e1bSTres Popp patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(), 11834810e1bSTres Popp "atan2f", "atan2", benefit); 11934810e1bSTres Popp patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(), 12034810e1bSTres Popp "expm1f", "expm1", benefit); 12134810e1bSTres Popp patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf", 12234810e1bSTres Popp "tanh", benefit); 12334810e1bSTres Popp } 12434810e1bSTres Popp 12534810e1bSTres Popp namespace { 12634810e1bSTres Popp struct ConvertMathToLibmPass 12734810e1bSTres Popp : public ConvertMathToLibmBase<ConvertMathToLibmPass> { 12834810e1bSTres Popp void runOnOperation() override; 12934810e1bSTres Popp }; 13034810e1bSTres Popp } // namespace 13134810e1bSTres Popp 13234810e1bSTres Popp void ConvertMathToLibmPass::runOnOperation() { 13334810e1bSTres Popp auto module = getOperation(); 13434810e1bSTres Popp 13534810e1bSTres Popp RewritePatternSet patterns(&getContext()); 13634810e1bSTres Popp populateMathToLibmConversionPatterns(patterns, /*benefit=*/1); 13734810e1bSTres Popp 13834810e1bSTres Popp ConversionTarget target(getContext()); 139*a54f4eaeSMogball target.addLegalDialect<arith::ArithmeticDialect, BuiltinDialect, 140*a54f4eaeSMogball StandardOpsDialect, vector::VectorDialect>(); 14134810e1bSTres Popp target.addIllegalDialect<math::MathDialect>(); 14234810e1bSTres Popp if (failed(applyPartialConversion(module, target, std::move(patterns)))) 14334810e1bSTres Popp signalPassFailure(); 14434810e1bSTres Popp } 14534810e1bSTres Popp 14634810e1bSTres Popp std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() { 14734810e1bSTres Popp return std::make_unique<ConvertMathToLibmPass>(); 14834810e1bSTres Popp } 149