1 //===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/MathToLibm/MathToLibm.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/Math/IR/Math.h" 13 #include "mlir/Dialect/StandardOps/IR/Ops.h" 14 #include "mlir/Dialect/Vector/VectorOps.h" 15 #include "mlir/IR/BuiltinDialect.h" 16 #include "mlir/IR/PatternMatch.h" 17 18 using namespace mlir; 19 20 namespace { 21 // Pattern to convert vector operations to scalar operations. This is needed as 22 // libm calls require scalars. 23 template <typename Op> 24 struct VecOpToScalarOp : public OpRewritePattern<Op> { 25 public: 26 using OpRewritePattern<Op>::OpRewritePattern; 27 28 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 29 }; 30 // Pattern to convert scalar math operations to calls to libm functions. 31 // Additionally the libm function signatures are declared. 32 template <typename Op> 33 struct ScalarOpToLibmCall : public OpRewritePattern<Op> { 34 public: 35 using OpRewritePattern<Op>::OpRewritePattern; 36 ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc, 37 StringRef doubleFunc, PatternBenefit benefit) 38 : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc), 39 doubleFunc(doubleFunc){}; 40 41 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 42 43 private: 44 std::string floatFunc, doubleFunc; 45 }; 46 } // namespace 47 48 template <typename Op> 49 LogicalResult 50 VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const { 51 auto opType = op.getType(); 52 auto loc = op.getLoc(); 53 auto vecType = opType.template dyn_cast<VectorType>(); 54 55 if (!vecType) 56 return failure(); 57 if (!vecType.hasRank()) 58 return failure(); 59 auto shape = vecType.getShape(); 60 // TODO: support multidimensional vectors 61 if (shape.size() != 1) 62 return failure(); 63 64 Value result = rewriter.create<ConstantOp>( 65 loc, DenseElementsAttr::get( 66 vecType, FloatAttr::get(vecType.getElementType(), 0.0))); 67 for (auto i = 0; i < shape.front(); ++i) { 68 SmallVector<Value> operands; 69 for (auto input : op->getOperands()) 70 operands.push_back( 71 rewriter.create<vector::ExtractElementOp>(loc, input, i)); 72 Value scalarOp = 73 rewriter.create<Op>(loc, vecType.getElementType(), operands); 74 result = rewriter.create<vector::InsertElementOp>(loc, scalarOp, result, i); 75 } 76 rewriter.replaceOp(op, {result}); 77 return success(); 78 } 79 80 template <typename Op> 81 LogicalResult 82 ScalarOpToLibmCall<Op>::matchAndRewrite(Op op, 83 PatternRewriter &rewriter) const { 84 auto module = op->template getParentOfType<ModuleOp>(); 85 auto type = op.getType(); 86 // TODO: Support Float16 by upcasting to Float32 87 if (!type.template isa<Float32Type, Float64Type>()) 88 return failure(); 89 90 auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc; 91 auto opFunc = module.template lookupSymbol<FuncOp>(name); 92 // Forward declare function if it hasn't already been 93 if (!opFunc) { 94 OpBuilder::InsertionGuard guard(rewriter); 95 rewriter.setInsertionPointToStart(module.getBody()); 96 auto opFunctionTy = FunctionType::get( 97 rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); 98 opFunc = 99 rewriter.create<FuncOp>(rewriter.getUnknownLoc(), name, opFunctionTy); 100 opFunc.setPrivate(); 101 } 102 assert(opFunc.getType().template cast<FunctionType>().getResults() == 103 op->getResultTypes()); 104 assert(opFunc.getType().template cast<FunctionType>().getInputs() == 105 op->getOperandTypes()); 106 107 rewriter.replaceOpWithNewOp<CallOp>(op, opFunc, op->getOperands()); 108 109 return success(); 110 } 111 112 void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns, 113 PatternBenefit benefit) { 114 patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>, 115 VecOpToScalarOp<math::TanhOp>>(patterns.getContext(), benefit); 116 patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(), 117 "atan2f", "atan2", benefit); 118 patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(), 119 "expm1f", "expm1", benefit); 120 patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf", 121 "tanh", benefit); 122 } 123 124 namespace { 125 struct ConvertMathToLibmPass 126 : public ConvertMathToLibmBase<ConvertMathToLibmPass> { 127 void runOnOperation() override; 128 }; 129 } // namespace 130 131 void ConvertMathToLibmPass::runOnOperation() { 132 auto module = getOperation(); 133 134 RewritePatternSet patterns(&getContext()); 135 populateMathToLibmConversionPatterns(patterns, /*benefit=*/1); 136 137 ConversionTarget target(getContext()); 138 target.addLegalDialect<BuiltinDialect, StandardOpsDialect, 139 vector::VectorDialect>(); 140 target.addIllegalDialect<math::MathDialect>(); 141 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 142 signalPassFailure(); 143 } 144 145 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() { 146 return std::make_unique<ConvertMathToLibmPass>(); 147 } 148