1d94426d2SEugene Zhulenev //===- AlgebraicSimplification.cpp - Simplify algebraic expressions -------===//
2d94426d2SEugene Zhulenev //
3d94426d2SEugene Zhulenev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4d94426d2SEugene Zhulenev // See https://llvm.org/LICENSE.txt for license information.
5d94426d2SEugene Zhulenev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6d94426d2SEugene Zhulenev //
7d94426d2SEugene Zhulenev //===----------------------------------------------------------------------===//
8d94426d2SEugene Zhulenev //
9d94426d2SEugene Zhulenev // This file implements rewrites based on the basic rules of algebra
10d94426d2SEugene Zhulenev // (Commutativity, associativity, etc...) and strength reductions for math
11d94426d2SEugene Zhulenev // operations.
12d94426d2SEugene Zhulenev //
13d94426d2SEugene Zhulenev //===----------------------------------------------------------------------===//
14d94426d2SEugene Zhulenev 
15a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16d94426d2SEugene Zhulenev #include "mlir/Dialect/Math/IR/Math.h"
17d94426d2SEugene Zhulenev #include "mlir/Dialect/Math/Transforms/Passes.h"
18*99ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
19d94426d2SEugene Zhulenev #include "mlir/IR/Builders.h"
20d94426d2SEugene Zhulenev #include "mlir/IR/Matchers.h"
21d94426d2SEugene Zhulenev #include "mlir/IR/TypeUtilities.h"
22d94426d2SEugene Zhulenev #include <climits>
23d94426d2SEugene Zhulenev 
24d94426d2SEugene Zhulenev using namespace mlir;
25d94426d2SEugene Zhulenev 
26d94426d2SEugene Zhulenev //----------------------------------------------------------------------------//
27d94426d2SEugene Zhulenev // PowFOp strength reduction.
28d94426d2SEugene Zhulenev //----------------------------------------------------------------------------//
29d94426d2SEugene Zhulenev 
30d94426d2SEugene Zhulenev namespace {
31d94426d2SEugene Zhulenev struct PowFStrengthReduction : public OpRewritePattern<math::PowFOp> {
32d94426d2SEugene Zhulenev public:
33d94426d2SEugene Zhulenev   using OpRewritePattern::OpRewritePattern;
34d94426d2SEugene Zhulenev 
35d94426d2SEugene Zhulenev   LogicalResult matchAndRewrite(math::PowFOp op,
36d94426d2SEugene Zhulenev                                 PatternRewriter &rewriter) const final;
37d94426d2SEugene Zhulenev };
38d94426d2SEugene Zhulenev } // namespace
39d94426d2SEugene Zhulenev 
40d94426d2SEugene Zhulenev LogicalResult
matchAndRewrite(math::PowFOp op,PatternRewriter & rewriter) const41d94426d2SEugene Zhulenev PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
42d94426d2SEugene Zhulenev                                        PatternRewriter &rewriter) const {
43d94426d2SEugene Zhulenev   Location loc = op.getLoc();
4462fea88bSJacques Pienaar   Value x = op.getLhs();
45d94426d2SEugene Zhulenev 
46d94426d2SEugene Zhulenev   FloatAttr scalarExponent;
47d94426d2SEugene Zhulenev   DenseFPElementsAttr vectorExponent;
48d94426d2SEugene Zhulenev 
4962fea88bSJacques Pienaar   bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
5062fea88bSJacques Pienaar   bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));
51d94426d2SEugene Zhulenev 
52d94426d2SEugene Zhulenev   // Returns true if exponent is a constant equal to `value`.
53d94426d2SEugene Zhulenev   auto isExponentValue = [&](double value) -> bool {
54d94426d2SEugene Zhulenev     if (isScalar)
55d94426d2SEugene Zhulenev       return scalarExponent.getValue().isExactlyValue(value);
56d94426d2SEugene Zhulenev 
57d94426d2SEugene Zhulenev     if (isVector && vectorExponent.isSplat())
58d94426d2SEugene Zhulenev       return vectorExponent.getSplatValue<FloatAttr>()
59d94426d2SEugene Zhulenev           .getValue()
60d94426d2SEugene Zhulenev           .isExactlyValue(value);
61d94426d2SEugene Zhulenev 
62d94426d2SEugene Zhulenev     return false;
63d94426d2SEugene Zhulenev   };
64d94426d2SEugene Zhulenev 
65d94426d2SEugene Zhulenev   // Maybe broadcasts scalar value into vector type compatible with `op`.
66d94426d2SEugene Zhulenev   auto bcast = [&](Value value) -> Value {
67d94426d2SEugene Zhulenev     if (auto vec = op.getType().dyn_cast<VectorType>())
68d94426d2SEugene Zhulenev       return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value);
69d94426d2SEugene Zhulenev     return value;
70d94426d2SEugene Zhulenev   };
71d94426d2SEugene Zhulenev 
72d94426d2SEugene Zhulenev   // Replace `pow(x, 1.0)` with `x`.
73d94426d2SEugene Zhulenev   if (isExponentValue(1.0)) {
74d94426d2SEugene Zhulenev     rewriter.replaceOp(op, x);
75d94426d2SEugene Zhulenev     return success();
76d94426d2SEugene Zhulenev   }
77d94426d2SEugene Zhulenev 
78d94426d2SEugene Zhulenev   // Replace `pow(x, 2.0)` with `x * x`.
79d94426d2SEugene Zhulenev   if (isExponentValue(2.0)) {
80a54f4eaeSMogball     rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, x}));
81d94426d2SEugene Zhulenev     return success();
82d94426d2SEugene Zhulenev   }
83d94426d2SEugene Zhulenev 
84391456f3Sbakhtiyar   // Replace `pow(x, 3.0)` with `x * x * x`.
85d94426d2SEugene Zhulenev   if (isExponentValue(3.0)) {
86a54f4eaeSMogball     Value square =
87a54f4eaeSMogball         rewriter.create<arith::MulFOp>(op.getLoc(), ValueRange({x, x}));
88a54f4eaeSMogball     rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square}));
89d94426d2SEugene Zhulenev     return success();
90d94426d2SEugene Zhulenev   }
91d94426d2SEugene Zhulenev 
92d94426d2SEugene Zhulenev   // Replace `pow(x, -1.0)` with `1.0 / x`.
93d94426d2SEugene Zhulenev   if (isExponentValue(-1.0)) {
94a54f4eaeSMogball     Value one = rewriter.create<arith::ConstantOp>(
95d94426d2SEugene Zhulenev         loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
96a54f4eaeSMogball     rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x}));
97d94426d2SEugene Zhulenev     return success();
98d94426d2SEugene Zhulenev   }
99d94426d2SEugene Zhulenev 
100391456f3Sbakhtiyar   // Replace `pow(x, 0.5)` with `sqrt(x)`.
101391456f3Sbakhtiyar   if (isExponentValue(0.5)) {
102d94426d2SEugene Zhulenev     rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x);
103d94426d2SEugene Zhulenev     return success();
104d94426d2SEugene Zhulenev   }
105d94426d2SEugene Zhulenev 
106391456f3Sbakhtiyar   // Replace `pow(x, -0.5)` with `rsqrt(x)`.
107391456f3Sbakhtiyar   if (isExponentValue(-0.5)) {
108391456f3Sbakhtiyar     rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x);
109391456f3Sbakhtiyar     return success();
110391456f3Sbakhtiyar   }
111391456f3Sbakhtiyar 
112d94426d2SEugene Zhulenev   return failure();
113d94426d2SEugene Zhulenev }
114d94426d2SEugene Zhulenev 
115d94426d2SEugene Zhulenev //----------------------------------------------------------------------------//
116d94426d2SEugene Zhulenev 
populateMathAlgebraicSimplificationPatterns(RewritePatternSet & patterns)117d94426d2SEugene Zhulenev void mlir::populateMathAlgebraicSimplificationPatterns(
118d94426d2SEugene Zhulenev     RewritePatternSet &patterns) {
119d94426d2SEugene Zhulenev   patterns.add<PowFStrengthReduction>(patterns.getContext());
120d94426d2SEugene Zhulenev }
121