//===- AlgebraicSimplification.cpp - Simplify algebraic expressions -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewrites based on the basic rules of algebra
// (commutativity, associativity, etc.) and strength reductions for math
// operations.
//
//===----------------------------------------------------------------------===//
14d94426d2SEugene Zhulenev
15a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16d94426d2SEugene Zhulenev #include "mlir/Dialect/Math/IR/Math.h"
17d94426d2SEugene Zhulenev #include "mlir/Dialect/Math/Transforms/Passes.h"
18*99ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
19d94426d2SEugene Zhulenev #include "mlir/IR/Builders.h"
20d94426d2SEugene Zhulenev #include "mlir/IR/Matchers.h"
21d94426d2SEugene Zhulenev #include "mlir/IR/TypeUtilities.h"
22d94426d2SEugene Zhulenev #include <climits>
23d94426d2SEugene Zhulenev
24d94426d2SEugene Zhulenev using namespace mlir;
25d94426d2SEugene Zhulenev
//----------------------------------------------------------------------------//
// PowFOp strength reduction.
//----------------------------------------------------------------------------//
29d94426d2SEugene Zhulenev
30d94426d2SEugene Zhulenev namespace {
31d94426d2SEugene Zhulenev struct PowFStrengthReduction : public OpRewritePattern<math::PowFOp> {
32d94426d2SEugene Zhulenev public:
33d94426d2SEugene Zhulenev using OpRewritePattern::OpRewritePattern;
34d94426d2SEugene Zhulenev
35d94426d2SEugene Zhulenev LogicalResult matchAndRewrite(math::PowFOp op,
36d94426d2SEugene Zhulenev PatternRewriter &rewriter) const final;
37d94426d2SEugene Zhulenev };
38d94426d2SEugene Zhulenev } // namespace
39d94426d2SEugene Zhulenev
40d94426d2SEugene Zhulenev LogicalResult
matchAndRewrite(math::PowFOp op,PatternRewriter & rewriter) const41d94426d2SEugene Zhulenev PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
42d94426d2SEugene Zhulenev PatternRewriter &rewriter) const {
43d94426d2SEugene Zhulenev Location loc = op.getLoc();
4462fea88bSJacques Pienaar Value x = op.getLhs();
45d94426d2SEugene Zhulenev
46d94426d2SEugene Zhulenev FloatAttr scalarExponent;
47d94426d2SEugene Zhulenev DenseFPElementsAttr vectorExponent;
48d94426d2SEugene Zhulenev
4962fea88bSJacques Pienaar bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
5062fea88bSJacques Pienaar bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));
51d94426d2SEugene Zhulenev
52d94426d2SEugene Zhulenev // Returns true if exponent is a constant equal to `value`.
53d94426d2SEugene Zhulenev auto isExponentValue = [&](double value) -> bool {
54d94426d2SEugene Zhulenev if (isScalar)
55d94426d2SEugene Zhulenev return scalarExponent.getValue().isExactlyValue(value);
56d94426d2SEugene Zhulenev
57d94426d2SEugene Zhulenev if (isVector && vectorExponent.isSplat())
58d94426d2SEugene Zhulenev return vectorExponent.getSplatValue<FloatAttr>()
59d94426d2SEugene Zhulenev .getValue()
60d94426d2SEugene Zhulenev .isExactlyValue(value);
61d94426d2SEugene Zhulenev
62d94426d2SEugene Zhulenev return false;
63d94426d2SEugene Zhulenev };
64d94426d2SEugene Zhulenev
65d94426d2SEugene Zhulenev // Maybe broadcasts scalar value into vector type compatible with `op`.
66d94426d2SEugene Zhulenev auto bcast = [&](Value value) -> Value {
67d94426d2SEugene Zhulenev if (auto vec = op.getType().dyn_cast<VectorType>())
68d94426d2SEugene Zhulenev return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value);
69d94426d2SEugene Zhulenev return value;
70d94426d2SEugene Zhulenev };
71d94426d2SEugene Zhulenev
72d94426d2SEugene Zhulenev // Replace `pow(x, 1.0)` with `x`.
73d94426d2SEugene Zhulenev if (isExponentValue(1.0)) {
74d94426d2SEugene Zhulenev rewriter.replaceOp(op, x);
75d94426d2SEugene Zhulenev return success();
76d94426d2SEugene Zhulenev }
77d94426d2SEugene Zhulenev
78d94426d2SEugene Zhulenev // Replace `pow(x, 2.0)` with `x * x`.
79d94426d2SEugene Zhulenev if (isExponentValue(2.0)) {
80a54f4eaeSMogball rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, x}));
81d94426d2SEugene Zhulenev return success();
82d94426d2SEugene Zhulenev }
83d94426d2SEugene Zhulenev
84391456f3Sbakhtiyar // Replace `pow(x, 3.0)` with `x * x * x`.
85d94426d2SEugene Zhulenev if (isExponentValue(3.0)) {
86a54f4eaeSMogball Value square =
87a54f4eaeSMogball rewriter.create<arith::MulFOp>(op.getLoc(), ValueRange({x, x}));
88a54f4eaeSMogball rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square}));
89d94426d2SEugene Zhulenev return success();
90d94426d2SEugene Zhulenev }
91d94426d2SEugene Zhulenev
92d94426d2SEugene Zhulenev // Replace `pow(x, -1.0)` with `1.0 / x`.
93d94426d2SEugene Zhulenev if (isExponentValue(-1.0)) {
94a54f4eaeSMogball Value one = rewriter.create<arith::ConstantOp>(
95d94426d2SEugene Zhulenev loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
96a54f4eaeSMogball rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x}));
97d94426d2SEugene Zhulenev return success();
98d94426d2SEugene Zhulenev }
99d94426d2SEugene Zhulenev
100391456f3Sbakhtiyar // Replace `pow(x, 0.5)` with `sqrt(x)`.
101391456f3Sbakhtiyar if (isExponentValue(0.5)) {
102d94426d2SEugene Zhulenev rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x);
103d94426d2SEugene Zhulenev return success();
104d94426d2SEugene Zhulenev }
105d94426d2SEugene Zhulenev
106391456f3Sbakhtiyar // Replace `pow(x, -0.5)` with `rsqrt(x)`.
107391456f3Sbakhtiyar if (isExponentValue(-0.5)) {
108391456f3Sbakhtiyar rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x);
109391456f3Sbakhtiyar return success();
110391456f3Sbakhtiyar }
111391456f3Sbakhtiyar
112d94426d2SEugene Zhulenev return failure();
113d94426d2SEugene Zhulenev }
114d94426d2SEugene Zhulenev
//----------------------------------------------------------------------------//
116d94426d2SEugene Zhulenev
populateMathAlgebraicSimplificationPatterns(RewritePatternSet & patterns)117d94426d2SEugene Zhulenev void mlir::populateMathAlgebraicSimplificationPatterns(
118d94426d2SEugene Zhulenev RewritePatternSet &patterns) {
119d94426d2SEugene Zhulenev patterns.add<PowFStrengthReduction>(patterns.getContext());
120d94426d2SEugene Zhulenev }
121