1 //===- AlgebraicSimplification.cpp - Simplify algebraic expressions -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewrites based on the basic rules of algebra
10 // (Commutativity, associativity, etc...) and strength reductions for math
11 // operations.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/Math/IR/Math.h"
16 #include "mlir/Dialect/Math/Transforms/Passes.h"
17 #include "mlir/Dialect/Vector/VectorOps.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include <climits>
22 
23 using namespace mlir;
24 
25 //----------------------------------------------------------------------------//
26 // PowFOp strength reduction.
27 //----------------------------------------------------------------------------//
28 
29 namespace {
30 struct PowFStrengthReduction : public OpRewritePattern<math::PowFOp> {
31 public:
32   using OpRewritePattern::OpRewritePattern;
33 
34   LogicalResult matchAndRewrite(math::PowFOp op,
35                                 PatternRewriter &rewriter) const final;
36 };
37 } // namespace
38 
39 LogicalResult
40 PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
41                                        PatternRewriter &rewriter) const {
42   Location loc = op.getLoc();
43   Value x = op.lhs();
44 
45   FloatAttr scalarExponent;
46   DenseFPElementsAttr vectorExponent;
47 
48   bool isScalar = matchPattern(op.rhs(), m_Constant(&scalarExponent));
49   bool isVector = matchPattern(op.rhs(), m_Constant(&vectorExponent));
50 
51   // Returns true if exponent is a constant equal to `value`.
52   auto isExponentValue = [&](double value) -> bool {
53     if (isScalar)
54       return scalarExponent.getValue().isExactlyValue(value);
55 
56     if (isVector && vectorExponent.isSplat())
57       return vectorExponent.getSplatValue<FloatAttr>()
58           .getValue()
59           .isExactlyValue(value);
60 
61     return false;
62   };
63 
64   // Maybe broadcasts scalar value into vector type compatible with `op`.
65   auto bcast = [&](Value value) -> Value {
66     if (auto vec = op.getType().dyn_cast<VectorType>())
67       return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value);
68     return value;
69   };
70 
71   // Replace `pow(x, 1.0)` with `x`.
72   if (isExponentValue(1.0)) {
73     rewriter.replaceOp(op, x);
74     return success();
75   }
76 
77   // Replace `pow(x, 2.0)` with `x * x`.
78   if (isExponentValue(2.0)) {
79     rewriter.replaceOpWithNewOp<MulFOp>(op, ValueRange({x, x}));
80     return success();
81   }
82 
83   // Replace `pow(x, 3.0)` with `x * x * x`.
84   if (isExponentValue(3.0)) {
85     Value square = rewriter.create<MulFOp>(op.getLoc(), ValueRange({x, x}));
86     rewriter.replaceOpWithNewOp<MulFOp>(op, ValueRange({x, square}));
87     return success();
88   }
89 
90   // Replace `pow(x, -1.0)` with `1.0 / x`.
91   if (isExponentValue(-1.0)) {
92     Value one = rewriter.create<ConstantOp>(
93         loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
94     rewriter.replaceOpWithNewOp<DivFOp>(op, ValueRange({bcast(one), x}));
95     return success();
96   }
97 
98   // Replace `pow(x, 0.5)` with `sqrt(x)`.
99   if (isExponentValue(0.5)) {
100     rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x);
101     return success();
102   }
103 
104   // Replace `pow(x, -0.5)` with `rsqrt(x)`.
105   if (isExponentValue(-0.5)) {
106     rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x);
107     return success();
108   }
109 
110   return failure();
111 }
112 
113 //----------------------------------------------------------------------------//
114 
115 void mlir::populateMathAlgebraicSimplificationPatterns(
116     RewritePatternSet &patterns) {
117   patterns.add<PowFStrengthReduction>(patterns.getContext());
118 }
119