134810e1bSTres Popp //===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===// 234810e1bSTres Popp // 334810e1bSTres Popp // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 434810e1bSTres Popp // See https://llvm.org/LICENSE.txt for license information. 534810e1bSTres Popp // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 634810e1bSTres Popp // 734810e1bSTres Popp //===----------------------------------------------------------------------===// 834810e1bSTres Popp 934810e1bSTres Popp #include "mlir/Conversion/MathToLibm/MathToLibm.h" 1034810e1bSTres Popp 1134810e1bSTres Popp #include "../PassDetail.h" 12a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 1334810e1bSTres Popp #include "mlir/Dialect/Math/IR/Math.h" 1434810e1bSTres Popp #include "mlir/Dialect/StandardOps/IR/Ops.h" 1534810e1bSTres Popp #include "mlir/Dialect/Vector/VectorOps.h" 16921d91f3SAdrian Kuegel #include "mlir/Dialect/Vector/VectorUtils.h" 1734810e1bSTres Popp #include "mlir/IR/BuiltinDialect.h" 1834810e1bSTres Popp #include "mlir/IR/PatternMatch.h" 1934810e1bSTres Popp 2034810e1bSTres Popp using namespace mlir; 2134810e1bSTres Popp 2234810e1bSTres Popp namespace { 2334810e1bSTres Popp // Pattern to convert vector operations to scalar operations. This is needed as 2434810e1bSTres Popp // libm calls require scalars. 2534810e1bSTres Popp template <typename Op> 2634810e1bSTres Popp struct VecOpToScalarOp : public OpRewritePattern<Op> { 2734810e1bSTres Popp public: 2834810e1bSTres Popp using OpRewritePattern<Op>::OpRewritePattern; 2934810e1bSTres Popp 3034810e1bSTres Popp LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 3134810e1bSTres Popp }; 3234810e1bSTres Popp // Pattern to convert scalar math operations to calls to libm functions. 3334810e1bSTres Popp // Additionally the libm function signatures are declared. 3434810e1bSTres Popp template <typename Op> 3534810e1bSTres Popp struct ScalarOpToLibmCall : public OpRewritePattern<Op> { 3634810e1bSTres Popp public: 3734810e1bSTres Popp using OpRewritePattern<Op>::OpRewritePattern; 3834810e1bSTres Popp ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc, 3934810e1bSTres Popp StringRef doubleFunc, PatternBenefit benefit) 4034810e1bSTres Popp : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc), 4134810e1bSTres Popp doubleFunc(doubleFunc){}; 4234810e1bSTres Popp 4334810e1bSTres Popp LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 4434810e1bSTres Popp 4534810e1bSTres Popp private: 4634810e1bSTres Popp std::string floatFunc, doubleFunc; 4734810e1bSTres Popp }; 4834810e1bSTres Popp } // namespace 4934810e1bSTres Popp 5034810e1bSTres Popp template <typename Op> 5134810e1bSTres Popp LogicalResult 5234810e1bSTres Popp VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const { 5334810e1bSTres Popp auto opType = op.getType(); 5434810e1bSTres Popp auto loc = op.getLoc(); 5534810e1bSTres Popp auto vecType = opType.template dyn_cast<VectorType>(); 5634810e1bSTres Popp 5734810e1bSTres Popp if (!vecType) 5834810e1bSTres Popp return failure(); 5934810e1bSTres Popp if (!vecType.hasRank()) 6034810e1bSTres Popp return failure(); 6134810e1bSTres Popp auto shape = vecType.getShape(); 62921d91f3SAdrian Kuegel int64_t numElements = vecType.getNumElements(); 6334810e1bSTres Popp 64a54f4eaeSMogball Value result = rewriter.create<arith::ConstantOp>( 6534810e1bSTres Popp loc, DenseElementsAttr::get( 6634810e1bSTres Popp vecType, FloatAttr::get(vecType.getElementType(), 0.0))); 67921d91f3SAdrian Kuegel SmallVector<int64_t> ones(shape.size(), 1); 68921d91f3SAdrian Kuegel SmallVector<int64_t> strides = computeStrides(shape, ones); 69921d91f3SAdrian Kuegel for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) { 70921d91f3SAdrian Kuegel SmallVector<int64_t> positions = delinearize(strides, linearIndex); 7134810e1bSTres Popp SmallVector<Value> operands; 7234810e1bSTres Popp for (auto input : op->getOperands()) 7334810e1bSTres Popp operands.push_back( 74921d91f3SAdrian Kuegel rewriter.create<vector::ExtractOp>(loc, input, positions)); 7534810e1bSTres Popp Value scalarOp = 7634810e1bSTres Popp rewriter.create<Op>(loc, vecType.getElementType(), operands); 77921d91f3SAdrian Kuegel result = 78921d91f3SAdrian Kuegel rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions); 7934810e1bSTres Popp } 8034810e1bSTres Popp rewriter.replaceOp(op, {result}); 8134810e1bSTres Popp return success(); 8234810e1bSTres Popp } 8334810e1bSTres Popp 8434810e1bSTres Popp template <typename Op> 8534810e1bSTres Popp LogicalResult 8634810e1bSTres Popp ScalarOpToLibmCall<Op>::matchAndRewrite(Op op, 8734810e1bSTres Popp PatternRewriter &rewriter) const { 881ebf7ce9STres Popp auto module = SymbolTable::getNearestSymbolTable(op); 8934810e1bSTres Popp auto type = op.getType(); 9034810e1bSTres Popp // TODO: Support Float16 by upcasting to Float32 9134810e1bSTres Popp if (!type.template isa<Float32Type, Float64Type>()) 9234810e1bSTres Popp return failure(); 9334810e1bSTres Popp 9434810e1bSTres Popp auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc; 951ebf7ce9STres Popp auto opFunc = dyn_cast_or_null<SymbolOpInterface>( 961ebf7ce9STres Popp SymbolTable::lookupSymbolIn(module, name)); 9734810e1bSTres Popp // Forward declare function if it hasn't already been 9834810e1bSTres Popp if (!opFunc) { 9934810e1bSTres Popp OpBuilder::InsertionGuard guard(rewriter); 1001ebf7ce9STres Popp rewriter.setInsertionPointToStart(&module->getRegion(0).front()); 10134810e1bSTres Popp auto opFunctionTy = FunctionType::get( 10234810e1bSTres Popp rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); 10334810e1bSTres Popp opFunc = 10434810e1bSTres Popp rewriter.create<FuncOp>(rewriter.getUnknownLoc(), name, opFunctionTy); 10534810e1bSTres Popp opFunc.setPrivate(); 10634810e1bSTres Popp } 107*7ceffae1SRiver Riddle assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name))); 10834810e1bSTres Popp 1091ebf7ce9STres Popp rewriter.replaceOpWithNewOp<CallOp>(op, name, op.getType(), 1101ebf7ce9STres Popp op->getOperands()); 11134810e1bSTres Popp 11234810e1bSTres Popp return success(); 11334810e1bSTres Popp } 11434810e1bSTres Popp 11534810e1bSTres Popp void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns, 11634810e1bSTres Popp PatternBenefit benefit) { 11734810e1bSTres Popp patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>, 11834810e1bSTres Popp VecOpToScalarOp<math::TanhOp>>(patterns.getContext(), benefit); 11934810e1bSTres Popp patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(), 12034810e1bSTres Popp "atan2f", "atan2", benefit); 121f1b92218SBoian Petkantchin patterns.add<ScalarOpToLibmCall<math::ErfOp>>(patterns.getContext(), "erff", 122f1b92218SBoian Petkantchin "erf", benefit); 12334810e1bSTres Popp patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(), 12434810e1bSTres Popp "expm1f", "expm1", benefit); 12534810e1bSTres Popp patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf", 12634810e1bSTres Popp "tanh", benefit); 12734810e1bSTres Popp } 12834810e1bSTres Popp 12934810e1bSTres Popp namespace { 13034810e1bSTres Popp struct ConvertMathToLibmPass 13134810e1bSTres Popp : public ConvertMathToLibmBase<ConvertMathToLibmPass> { 13234810e1bSTres Popp void runOnOperation() override; 13334810e1bSTres Popp }; 13434810e1bSTres Popp } // namespace 13534810e1bSTres Popp 13634810e1bSTres Popp void ConvertMathToLibmPass::runOnOperation() { 13734810e1bSTres Popp auto module = getOperation(); 13834810e1bSTres Popp 13934810e1bSTres Popp RewritePatternSet patterns(&getContext()); 14034810e1bSTres Popp populateMathToLibmConversionPatterns(patterns, /*benefit=*/1); 14134810e1bSTres Popp 14234810e1bSTres Popp ConversionTarget target(getContext()); 143a54f4eaeSMogball target.addLegalDialect<arith::ArithmeticDialect, BuiltinDialect, 144a54f4eaeSMogball StandardOpsDialect, vector::VectorDialect>(); 14534810e1bSTres Popp target.addIllegalDialect<math::MathDialect>(); 14634810e1bSTres Popp if (failed(applyPartialConversion(module, target, std::move(patterns)))) 14734810e1bSTres Popp signalPassFailure(); 14834810e1bSTres Popp } 14934810e1bSTres Popp 15034810e1bSTres Popp std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() { 15134810e1bSTres Popp return std::make_unique<ConvertMathToLibmPass>(); 15234810e1bSTres Popp } 153