1 //===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
10
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/Math/IR/Math.h"
15 #include "mlir/Dialect/Utils/IndexingUtils.h"
16 #include "mlir/Dialect/Vector/IR/VectorOps.h"
17 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
18 #include "mlir/IR/BuiltinDialect.h"
19 #include "mlir/IR/PatternMatch.h"
20
21 using namespace mlir;
22
23 namespace {
24 // Pattern to convert vector operations to scalar operations. This is needed as
25 // libm calls require scalars.
26 template <typename Op>
27 struct VecOpToScalarOp : public OpRewritePattern<Op> {
28 public:
29 using OpRewritePattern<Op>::OpRewritePattern;
30
31 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
32 };
33 // Pattern to promote an op of a smaller floating point type to F32.
34 template <typename Op>
35 struct PromoteOpToF32 : public OpRewritePattern<Op> {
36 public:
37 using OpRewritePattern<Op>::OpRewritePattern;
38
39 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
40 };
41 // Pattern to convert scalar math operations to calls to libm functions.
42 // Additionally the libm function signatures are declared.
43 template <typename Op>
44 struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
45 public:
46 using OpRewritePattern<Op>::OpRewritePattern;
ScalarOpToLibmCall__anon05e0e9df0111::ScalarOpToLibmCall47 ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc,
48 StringRef doubleFunc, PatternBenefit benefit)
49 : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
50 doubleFunc(doubleFunc){};
51
52 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
53
54 private:
55 std::string floatFunc, doubleFunc;
56 };
57 } // namespace
58
59 template <typename Op>
60 LogicalResult
matchAndRewrite(Op op,PatternRewriter & rewriter) const61 VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
62 auto opType = op.getType();
63 auto loc = op.getLoc();
64 auto vecType = opType.template dyn_cast<VectorType>();
65
66 if (!vecType)
67 return failure();
68 if (!vecType.hasRank())
69 return failure();
70 auto shape = vecType.getShape();
71 int64_t numElements = vecType.getNumElements();
72
73 Value result = rewriter.create<arith::ConstantOp>(
74 loc, DenseElementsAttr::get(
75 vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
76 SmallVector<int64_t> ones(shape.size(), 1);
77 SmallVector<int64_t> strides = computeStrides(shape, ones);
78 for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
79 SmallVector<int64_t> positions = delinearize(strides, linearIndex);
80 SmallVector<Value> operands;
81 for (auto input : op->getOperands())
82 operands.push_back(
83 rewriter.create<vector::ExtractOp>(loc, input, positions));
84 Value scalarOp =
85 rewriter.create<Op>(loc, vecType.getElementType(), operands);
86 result =
87 rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
88 }
89 rewriter.replaceOp(op, {result});
90 return success();
91 }
92
93 template <typename Op>
94 LogicalResult
matchAndRewrite(Op op,PatternRewriter & rewriter) const95 PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
96 auto opType = op.getType();
97 if (!opType.template isa<Float16Type, BFloat16Type>())
98 return failure();
99
100 auto loc = op.getLoc();
101 auto f32 = rewriter.getF32Type();
102 auto extendedOperands = llvm::to_vector(
103 llvm::map_range(op->getOperands(), [&](Value operand) -> Value {
104 return rewriter.create<arith::ExtFOp>(loc, f32, operand);
105 }));
106 auto newOp = rewriter.create<Op>(loc, f32, extendedOperands);
107 rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, opType, newOp);
108 return success();
109 }
110
111 template <typename Op>
112 LogicalResult
matchAndRewrite(Op op,PatternRewriter & rewriter) const113 ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
114 PatternRewriter &rewriter) const {
115 auto module = SymbolTable::getNearestSymbolTable(op);
116 auto type = op.getType();
117 if (!type.template isa<Float32Type, Float64Type>())
118 return failure();
119
120 auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
121 auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
122 SymbolTable::lookupSymbolIn(module, name));
123 // Forward declare function if it hasn't already been
124 if (!opFunc) {
125 OpBuilder::InsertionGuard guard(rewriter);
126 rewriter.setInsertionPointToStart(&module->getRegion(0).front());
127 auto opFunctionTy = FunctionType::get(
128 rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
129 opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name,
130 opFunctionTy);
131 opFunc.setPrivate();
132 }
133 assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
134
135 rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
136 op->getOperands());
137
138 return success();
139 }
140
populateMathToLibmConversionPatterns(RewritePatternSet & patterns,PatternBenefit benefit)141 void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
142 PatternBenefit benefit) {
143 patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
144 VecOpToScalarOp<math::TanhOp>, VecOpToScalarOp<math::CosOp>,
145 VecOpToScalarOp<math::SinOp>, VecOpToScalarOp<math::ErfOp>,
146 VecOpToScalarOp<math::RoundOp>, VecOpToScalarOp<math::AtanOp>,
147 VecOpToScalarOp<math::TanOp>>(patterns.getContext(), benefit);
148 patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::ExpM1Op>,
149 PromoteOpToF32<math::TanhOp>, PromoteOpToF32<math::CosOp>,
150 PromoteOpToF32<math::SinOp>, PromoteOpToF32<math::ErfOp>,
151 PromoteOpToF32<math::RoundOp>, PromoteOpToF32<math::AtanOp>,
152 PromoteOpToF32<math::TanOp>>(patterns.getContext(), benefit);
153 patterns.add<ScalarOpToLibmCall<math::AtanOp>>(patterns.getContext(), "atanf",
154 "atan", benefit);
155 patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
156 "atan2f", "atan2", benefit);
157 patterns.add<ScalarOpToLibmCall<math::ErfOp>>(patterns.getContext(), "erff",
158 "erf", benefit);
159 patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(),
160 "expm1f", "expm1", benefit);
161 patterns.add<ScalarOpToLibmCall<math::TanOp>>(patterns.getContext(), "tanf",
162 "tan", benefit);
163 patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf",
164 "tanh", benefit);
165 patterns.add<ScalarOpToLibmCall<math::RoundOp>>(patterns.getContext(),
166 "roundf", "round", benefit);
167 patterns.add<ScalarOpToLibmCall<math::CosOp>>(patterns.getContext(), "cosf",
168 "cos", benefit);
169 patterns.add<ScalarOpToLibmCall<math::SinOp>>(patterns.getContext(), "sinf",
170 "sin", benefit);
171 }
172
173 namespace {
174 struct ConvertMathToLibmPass
175 : public ConvertMathToLibmBase<ConvertMathToLibmPass> {
176 void runOnOperation() override;
177 };
178 } // namespace
179
runOnOperation()180 void ConvertMathToLibmPass::runOnOperation() {
181 auto module = getOperation();
182
183 RewritePatternSet patterns(&getContext());
184 populateMathToLibmConversionPatterns(patterns, /*benefit=*/1);
185
186 ConversionTarget target(getContext());
187 target.addLegalDialect<arith::ArithmeticDialect, BuiltinDialect,
188 func::FuncDialect, vector::VectorDialect>();
189 target.addIllegalDialect<math::MathDialect>();
190 if (failed(applyPartialConversion(module, target, std::move(patterns))))
191 signalPassFailure();
192 }
193
createConvertMathToLibmPass()194 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() {
195 return std::make_unique<ConvertMathToLibmPass>();
196 }
197