1 //===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/MathToLibm/MathToLibm.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 13 #include "mlir/Dialect/Math/IR/Math.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Dialect/Vector/VectorOps.h" 16 #include "mlir/IR/BuiltinDialect.h" 17 #include "mlir/IR/PatternMatch.h" 18 19 using namespace mlir; 20 21 namespace { 22 // Pattern to convert vector operations to scalar operations. This is needed as 23 // libm calls require scalars. 24 template <typename Op> 25 struct VecOpToScalarOp : public OpRewritePattern<Op> { 26 public: 27 using OpRewritePattern<Op>::OpRewritePattern; 28 29 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 30 }; 31 // Pattern to convert scalar math operations to calls to libm functions. 32 // Additionally the libm function signatures are declared. 33 template <typename Op> 34 struct ScalarOpToLibmCall : public OpRewritePattern<Op> { 35 public: 36 using OpRewritePattern<Op>::OpRewritePattern; 37 ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc, 38 StringRef doubleFunc, PatternBenefit benefit) 39 : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc), 40 doubleFunc(doubleFunc){}; 41 42 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 43 44 private: 45 std::string floatFunc, doubleFunc; 46 }; 47 } // namespace 48 49 template <typename Op> 50 LogicalResult 51 VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const { 52 auto opType = op.getType(); 53 auto loc = op.getLoc(); 54 auto vecType = opType.template dyn_cast<VectorType>(); 55 56 if (!vecType) 57 return failure(); 58 if (!vecType.hasRank()) 59 return failure(); 60 auto shape = vecType.getShape(); 61 // TODO: support multidimensional vectors 62 if (shape.size() != 1) 63 return failure(); 64 65 Value result = rewriter.create<arith::ConstantOp>( 66 loc, DenseElementsAttr::get( 67 vecType, FloatAttr::get(vecType.getElementType(), 0.0))); 68 for (auto i = 0; i < shape.front(); ++i) { 69 SmallVector<Value> operands; 70 for (auto input : op->getOperands()) 71 operands.push_back( 72 rewriter.create<vector::ExtractElementOp>(loc, input, i)); 73 Value scalarOp = 74 rewriter.create<Op>(loc, vecType.getElementType(), operands); 75 result = rewriter.create<vector::InsertElementOp>(loc, scalarOp, result, i); 76 } 77 rewriter.replaceOp(op, {result}); 78 return success(); 79 } 80 81 template <typename Op> 82 LogicalResult 83 ScalarOpToLibmCall<Op>::matchAndRewrite(Op op, 84 PatternRewriter &rewriter) const { 85 auto module = SymbolTable::getNearestSymbolTable(op); 86 auto type = op.getType(); 87 // TODO: Support Float16 by upcasting to Float32 88 if (!type.template isa<Float32Type, Float64Type>()) 89 return failure(); 90 91 auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc; 92 auto opFunc = dyn_cast_or_null<SymbolOpInterface>( 93 SymbolTable::lookupSymbolIn(module, name)); 94 // Forward declare function if it hasn't already been 95 if (!opFunc) { 96 OpBuilder::InsertionGuard guard(rewriter); 97 rewriter.setInsertionPointToStart(&module->getRegion(0).front()); 98 auto opFunctionTy = FunctionType::get( 99 rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); 100 opFunc = 101 rewriter.create<FuncOp>(rewriter.getUnknownLoc(), name, opFunctionTy); 102 opFunc.setPrivate(); 103 } 104 assert(SymbolTable::lookupSymbolIn(module, name) 105 ->template hasTrait<mlir::OpTrait::FunctionLike>()); 106 107 rewriter.replaceOpWithNewOp<CallOp>(op, name, op.getType(), 108 op->getOperands()); 109 110 return success(); 111 } 112 113 void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns, 114 PatternBenefit benefit) { 115 patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>, 116 VecOpToScalarOp<math::TanhOp>>(patterns.getContext(), benefit); 117 patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(), 118 "atan2f", "atan2", benefit); 119 patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(), 120 "expm1f", "expm1", benefit); 121 patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf", 122 "tanh", benefit); 123 } 124 125 namespace { 126 struct ConvertMathToLibmPass 127 : public ConvertMathToLibmBase<ConvertMathToLibmPass> { 128 void runOnOperation() override; 129 }; 130 } // namespace 131 132 void ConvertMathToLibmPass::runOnOperation() { 133 auto module = getOperation(); 134 135 RewritePatternSet patterns(&getContext()); 136 populateMathToLibmConversionPatterns(patterns, /*benefit=*/1); 137 138 ConversionTarget target(getContext()); 139 target.addLegalDialect<arith::ArithmeticDialect, BuiltinDialect, 140 StandardOpsDialect, vector::VectorDialect>(); 141 target.addIllegalDialect<math::MathDialect>(); 142 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 143 signalPassFailure(); 144 } 145 146 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() { 147 return std::make_unique<ConvertMathToLibmPass>(); 148 } 149