134810e1bSTres Popp //===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===// 234810e1bSTres Popp // 334810e1bSTres Popp // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 434810e1bSTres Popp // See https://llvm.org/LICENSE.txt for license information. 534810e1bSTres Popp // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 634810e1bSTres Popp // 734810e1bSTres Popp //===----------------------------------------------------------------------===// 834810e1bSTres Popp 934810e1bSTres Popp #include "mlir/Conversion/MathToLibm/MathToLibm.h" 1034810e1bSTres Popp 1134810e1bSTres Popp #include "../PassDetail.h" 12a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 1323aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 1434810e1bSTres Popp #include "mlir/Dialect/Math/IR/Math.h" 1599ef9eebSMatthias Springer #include "mlir/Dialect/Utils/IndexingUtils.h" 1699ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h" 1799ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 1834810e1bSTres Popp #include "mlir/IR/BuiltinDialect.h" 1934810e1bSTres Popp #include "mlir/IR/PatternMatch.h" 2034810e1bSTres Popp 2134810e1bSTres Popp using namespace mlir; 2234810e1bSTres Popp 2334810e1bSTres Popp namespace { 2434810e1bSTres Popp // Pattern to convert vector operations to scalar operations. This is needed as 2534810e1bSTres Popp // libm calls require scalars. 2634810e1bSTres Popp template <typename Op> 2734810e1bSTres Popp struct VecOpToScalarOp : public OpRewritePattern<Op> { 2834810e1bSTres Popp public: 2934810e1bSTres Popp using OpRewritePattern<Op>::OpRewritePattern; 3034810e1bSTres Popp 3134810e1bSTres Popp LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 3234810e1bSTres Popp }; 33a48adc56SBenjamin Kramer // Pattern to promote an op of a smaller floating point type to F32. 34a48adc56SBenjamin Kramer template <typename Op> 35a48adc56SBenjamin Kramer struct PromoteOpToF32 : public OpRewritePattern<Op> { 36a48adc56SBenjamin Kramer public: 37a48adc56SBenjamin Kramer using OpRewritePattern<Op>::OpRewritePattern; 38a48adc56SBenjamin Kramer 39a48adc56SBenjamin Kramer LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 40a48adc56SBenjamin Kramer }; 4134810e1bSTres Popp // Pattern to convert scalar math operations to calls to libm functions. 4234810e1bSTres Popp // Additionally the libm function signatures are declared. 4334810e1bSTres Popp template <typename Op> 4434810e1bSTres Popp struct ScalarOpToLibmCall : public OpRewritePattern<Op> { 4534810e1bSTres Popp public: 4634810e1bSTres Popp using OpRewritePattern<Op>::OpRewritePattern; 4734810e1bSTres Popp ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc, 4834810e1bSTres Popp StringRef doubleFunc, PatternBenefit benefit) 4934810e1bSTres Popp : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc), 5034810e1bSTres Popp doubleFunc(doubleFunc){}; 5134810e1bSTres Popp 5234810e1bSTres Popp LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 5334810e1bSTres Popp 5434810e1bSTres Popp private: 5534810e1bSTres Popp std::string floatFunc, doubleFunc; 5634810e1bSTres Popp }; 5734810e1bSTres Popp } // namespace 5834810e1bSTres Popp 5934810e1bSTres Popp template <typename Op> 6034810e1bSTres Popp LogicalResult 6134810e1bSTres Popp VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const { 6234810e1bSTres Popp auto opType = op.getType(); 6334810e1bSTres Popp auto loc = op.getLoc(); 6434810e1bSTres Popp auto vecType = opType.template dyn_cast<VectorType>(); 6534810e1bSTres Popp 6634810e1bSTres Popp if (!vecType) 6734810e1bSTres Popp return failure(); 6834810e1bSTres Popp if (!vecType.hasRank()) 6934810e1bSTres Popp return failure(); 7034810e1bSTres Popp auto shape = vecType.getShape(); 71921d91f3SAdrian Kuegel int64_t numElements = vecType.getNumElements(); 7234810e1bSTres Popp 73a54f4eaeSMogball Value result = rewriter.create<arith::ConstantOp>( 7434810e1bSTres Popp loc, DenseElementsAttr::get( 7534810e1bSTres Popp vecType, FloatAttr::get(vecType.getElementType(), 0.0))); 76921d91f3SAdrian Kuegel SmallVector<int64_t> ones(shape.size(), 1); 77921d91f3SAdrian Kuegel SmallVector<int64_t> strides = computeStrides(shape, ones); 78921d91f3SAdrian Kuegel for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) { 79921d91f3SAdrian Kuegel SmallVector<int64_t> positions = delinearize(strides, linearIndex); 8034810e1bSTres Popp SmallVector<Value> operands; 8134810e1bSTres Popp for (auto input : op->getOperands()) 8234810e1bSTres Popp operands.push_back( 83921d91f3SAdrian Kuegel rewriter.create<vector::ExtractOp>(loc, input, positions)); 8434810e1bSTres Popp Value scalarOp = 8534810e1bSTres Popp rewriter.create<Op>(loc, vecType.getElementType(), operands); 86921d91f3SAdrian Kuegel result = 87921d91f3SAdrian Kuegel rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions); 8834810e1bSTres Popp } 8934810e1bSTres Popp rewriter.replaceOp(op, {result}); 9034810e1bSTres Popp return success(); 9134810e1bSTres Popp } 9234810e1bSTres Popp 9334810e1bSTres Popp template <typename Op> 9434810e1bSTres Popp LogicalResult 95a48adc56SBenjamin Kramer PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const { 96a48adc56SBenjamin Kramer auto opType = op.getType(); 97a48adc56SBenjamin Kramer if (!opType.template isa<Float16Type, BFloat16Type>()) 98a48adc56SBenjamin Kramer return failure(); 99a48adc56SBenjamin Kramer 100a48adc56SBenjamin Kramer auto loc = op.getLoc(); 101a48adc56SBenjamin Kramer auto f32 = rewriter.getF32Type(); 102a48adc56SBenjamin Kramer auto extendedOperands = llvm::to_vector( 103a48adc56SBenjamin Kramer llvm::map_range(op->getOperands(), [&](Value operand) -> Value { 104a48adc56SBenjamin Kramer return rewriter.create<arith::ExtFOp>(loc, f32, operand); 105a48adc56SBenjamin Kramer })); 106a48adc56SBenjamin Kramer auto newOp = rewriter.create<Op>(loc, f32, extendedOperands); 107a48adc56SBenjamin Kramer rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, opType, newOp); 108a48adc56SBenjamin Kramer return success(); 109a48adc56SBenjamin Kramer } 110a48adc56SBenjamin Kramer 111a48adc56SBenjamin Kramer template <typename Op> 112a48adc56SBenjamin Kramer LogicalResult 11334810e1bSTres Popp ScalarOpToLibmCall<Op>::matchAndRewrite(Op op, 11434810e1bSTres Popp PatternRewriter &rewriter) const { 1151ebf7ce9STres Popp auto module = SymbolTable::getNearestSymbolTable(op); 11634810e1bSTres Popp auto type = op.getType(); 11734810e1bSTres Popp if (!type.template isa<Float32Type, Float64Type>()) 11834810e1bSTres Popp return failure(); 11934810e1bSTres Popp 12034810e1bSTres Popp auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc; 1211ebf7ce9STres Popp auto opFunc = dyn_cast_or_null<SymbolOpInterface>( 1221ebf7ce9STres Popp SymbolTable::lookupSymbolIn(module, name)); 12334810e1bSTres Popp // Forward declare function if it hasn't already been 12434810e1bSTres Popp if (!opFunc) { 12534810e1bSTres Popp OpBuilder::InsertionGuard guard(rewriter); 1261ebf7ce9STres Popp rewriter.setInsertionPointToStart(&module->getRegion(0).front()); 12734810e1bSTres Popp auto opFunctionTy = FunctionType::get( 12834810e1bSTres Popp rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); 12958ceae95SRiver Riddle opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name, 13058ceae95SRiver Riddle opFunctionTy); 13134810e1bSTres Popp opFunc.setPrivate(); 13234810e1bSTres Popp } 1337ceffae1SRiver Riddle assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name))); 13434810e1bSTres Popp 13523aa5a74SRiver Riddle rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(), 1361ebf7ce9STres Popp op->getOperands()); 13734810e1bSTres Popp 13834810e1bSTres Popp return success(); 13934810e1bSTres Popp } 14034810e1bSTres Popp 14134810e1bSTres Popp void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns, 14234810e1bSTres Popp PatternBenefit benefit) { 14334810e1bSTres Popp patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>, 144*0bae40efSlewuathe VecOpToScalarOp<math::TanhOp>, VecOpToScalarOp<math::CosOp>, 145*0bae40efSlewuathe VecOpToScalarOp<math::SinOp>>(patterns.getContext(), benefit); 146a48adc56SBenjamin Kramer patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::ExpM1Op>, 147*0bae40efSlewuathe PromoteOpToF32<math::TanhOp>, PromoteOpToF32<math::CosOp>, 148*0bae40efSlewuathe PromoteOpToF32<math::SinOp>>(patterns.getContext(), benefit); 14934810e1bSTres Popp patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(), 15034810e1bSTres Popp "atan2f", "atan2", benefit); 151f1b92218SBoian Petkantchin patterns.add<ScalarOpToLibmCall<math::ErfOp>>(patterns.getContext(), "erff", 152f1b92218SBoian Petkantchin "erf", benefit); 15334810e1bSTres Popp patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(), 15434810e1bSTres Popp "expm1f", "expm1", benefit); 15534810e1bSTres Popp patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf", 15634810e1bSTres Popp "tanh", benefit); 157a0fc94abSlorenzo chelini patterns.add<ScalarOpToLibmCall<math::RoundOp>>(patterns.getContext(), 158a0fc94abSlorenzo chelini "roundf", "round", benefit); 159*0bae40efSlewuathe patterns.add<ScalarOpToLibmCall<math::CosOp>>(patterns.getContext(), "cosf", 160*0bae40efSlewuathe "cos", benefit); 161*0bae40efSlewuathe patterns.add<ScalarOpToLibmCall<math::SinOp>>(patterns.getContext(), "sinf", 162*0bae40efSlewuathe "sin", benefit); 16334810e1bSTres Popp } 16434810e1bSTres Popp 16534810e1bSTres Popp namespace { 16634810e1bSTres Popp struct ConvertMathToLibmPass 16734810e1bSTres Popp : public ConvertMathToLibmBase<ConvertMathToLibmPass> { 16834810e1bSTres Popp void runOnOperation() override; 16934810e1bSTres Popp }; 17034810e1bSTres Popp } // namespace 17134810e1bSTres Popp 17234810e1bSTres Popp void ConvertMathToLibmPass::runOnOperation() { 17334810e1bSTres Popp auto module = getOperation(); 17434810e1bSTres Popp 17534810e1bSTres Popp RewritePatternSet patterns(&getContext()); 17634810e1bSTres Popp populateMathToLibmConversionPatterns(patterns, /*benefit=*/1); 17734810e1bSTres Popp 17834810e1bSTres Popp ConversionTarget target(getContext()); 179a54f4eaeSMogball target.addLegalDialect<arith::ArithmeticDialect, BuiltinDialect, 18023aa5a74SRiver Riddle func::FuncDialect, vector::VectorDialect>(); 18134810e1bSTres Popp target.addIllegalDialect<math::MathDialect>(); 18234810e1bSTres Popp if (failed(applyPartialConversion(module, target, std::move(patterns)))) 18334810e1bSTres Popp signalPassFailure(); 18434810e1bSTres Popp } 18534810e1bSTres Popp 18634810e1bSTres Popp std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() { 18734810e1bSTres Popp return std::make_unique<ConvertMathToLibmPass>(); 18834810e1bSTres Popp } 189