//===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/MathToLibm/MathToLibm.h" #include "../PassDetail.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/PatternMatch.h" using namespace mlir; namespace { // Pattern to convert vector operations to scalar operations. This is needed as // libm calls require scalars. template struct VecOpToScalarOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; }; // Pattern to promote an op of a smaller floating point type to F32. template struct PromoteOpToF32 : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; }; // Pattern to convert scalar math operations to calls to libm functions. // Additionally the libm function signatures are declared. template struct ScalarOpToLibmCall : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc, StringRef doubleFunc, PatternBenefit benefit) : OpRewritePattern(context, benefit), floatFunc(floatFunc), doubleFunc(doubleFunc){}; LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; private: std::string floatFunc, doubleFunc; }; } // namespace template LogicalResult VecOpToScalarOp::matchAndRewrite(Op op, PatternRewriter &rewriter) const { auto opType = op.getType(); auto loc = op.getLoc(); auto vecType = opType.template dyn_cast(); if (!vecType) return failure(); if (!vecType.hasRank()) return failure(); auto shape = vecType.getShape(); int64_t numElements = vecType.getNumElements(); Value result = rewriter.create( loc, DenseElementsAttr::get( vecType, FloatAttr::get(vecType.getElementType(), 0.0))); SmallVector ones(shape.size(), 1); SmallVector strides = computeStrides(shape, ones); for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) { SmallVector positions = delinearize(strides, linearIndex); SmallVector operands; for (auto input : op->getOperands()) operands.push_back( rewriter.create(loc, input, positions)); Value scalarOp = rewriter.create(loc, vecType.getElementType(), operands); result = rewriter.create(loc, scalarOp, result, positions); } rewriter.replaceOp(op, {result}); return success(); } template LogicalResult PromoteOpToF32::matchAndRewrite(Op op, PatternRewriter &rewriter) const { auto opType = op.getType(); if (!opType.template isa()) return failure(); auto loc = op.getLoc(); auto f32 = rewriter.getF32Type(); auto extendedOperands = llvm::to_vector( llvm::map_range(op->getOperands(), [&](Value operand) -> Value { return rewriter.create(loc, f32, operand); })); auto newOp = rewriter.create(loc, f32, extendedOperands); rewriter.replaceOpWithNewOp(op, opType, newOp); return success(); } template LogicalResult ScalarOpToLibmCall::matchAndRewrite(Op op, PatternRewriter &rewriter) const { auto module = SymbolTable::getNearestSymbolTable(op); auto type = op.getType(); if (!type.template isa()) return failure(); auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc; auto opFunc = dyn_cast_or_null( SymbolTable::lookupSymbolIn(module, name)); // Forward declare function if it hasn't already been if (!opFunc) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&module->getRegion(0).front()); auto opFunctionTy = FunctionType::get( rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); opFunc = rewriter.create(rewriter.getUnknownLoc(), name, opFunctionTy); opFunc.setPrivate(); } assert(isa(SymbolTable::lookupSymbolIn(module, name))); rewriter.replaceOpWithNewOp(op, name, op.getType(), op->getOperands()); return success(); } void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp>(patterns.getContext(), benefit); patterns.add, PromoteOpToF32, PromoteOpToF32, PromoteOpToF32, PromoteOpToF32, PromoteOpToF32, PromoteOpToF32, PromoteOpToF32, PromoteOpToF32>(patterns.getContext(), benefit); patterns.add>(patterns.getContext(), "atanf", "atan", benefit); patterns.add>(patterns.getContext(), "atan2f", "atan2", benefit); patterns.add>(patterns.getContext(), "erff", "erf", benefit); patterns.add>(patterns.getContext(), "expm1f", "expm1", benefit); patterns.add>(patterns.getContext(), "tanf", "tan", benefit); patterns.add>(patterns.getContext(), "tanhf", "tanh", benefit); patterns.add>(patterns.getContext(), "roundf", "round", benefit); patterns.add>(patterns.getContext(), "cosf", "cos", benefit); patterns.add>(patterns.getContext(), "sinf", "sin", benefit); } namespace { struct ConvertMathToLibmPass : public ConvertMathToLibmBase { void runOnOperation() override; }; } // namespace void ConvertMathToLibmPass::runOnOperation() { auto module = getOperation(); RewritePatternSet patterns(&getContext()); populateMathToLibmConversionPatterns(patterns, /*benefit=*/1); ConversionTarget target(getContext()); target.addLegalDialect(); target.addIllegalDialect(); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } std::unique_ptr> mlir::createConvertMathToLibmPass() { return std::make_unique(); }