1 //===- AlgebraicSimplification.cpp - Simplify algebraic expressions -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewrites based on the basic rules of algebra
10 // (Commutativity, associativity, etc...) and strength reductions for math
11 // operations.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Math/IR/Math.h"
17 #include "mlir/Dialect/Math/Transforms/Passes.h"
18 #include "mlir/Dialect/Vector/IR/VectorOps.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include <climits>
23 
24 using namespace mlir;
25 
26 //----------------------------------------------------------------------------//
27 // PowFOp strength reduction.
28 //----------------------------------------------------------------------------//
29 
30 namespace {
31 struct PowFStrengthReduction : public OpRewritePattern<math::PowFOp> {
32 public:
33   using OpRewritePattern::OpRewritePattern;
34 
35   LogicalResult matchAndRewrite(math::PowFOp op,
36                                 PatternRewriter &rewriter) const final;
37 };
38 } // namespace
39 
40 LogicalResult
41 PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
42                                        PatternRewriter &rewriter) const {
43   Location loc = op.getLoc();
44   Value x = op.getLhs();
45 
46   FloatAttr scalarExponent;
47   DenseFPElementsAttr vectorExponent;
48 
49   bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
50   bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));
51 
52   // Returns true if exponent is a constant equal to `value`.
53   auto isExponentValue = [&](double value) -> bool {
54     if (isScalar)
55       return scalarExponent.getValue().isExactlyValue(value);
56 
57     if (isVector && vectorExponent.isSplat())
58       return vectorExponent.getSplatValue<FloatAttr>()
59           .getValue()
60           .isExactlyValue(value);
61 
62     return false;
63   };
64 
65   // Maybe broadcasts scalar value into vector type compatible with `op`.
66   auto bcast = [&](Value value) -> Value {
67     if (auto vec = op.getType().dyn_cast<VectorType>())
68       return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value);
69     return value;
70   };
71 
72   // Replace `pow(x, 1.0)` with `x`.
73   if (isExponentValue(1.0)) {
74     rewriter.replaceOp(op, x);
75     return success();
76   }
77 
78   // Replace `pow(x, 2.0)` with `x * x`.
79   if (isExponentValue(2.0)) {
80     rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, x}));
81     return success();
82   }
83 
84   // Replace `pow(x, 3.0)` with `x * x * x`.
85   if (isExponentValue(3.0)) {
86     Value square =
87         rewriter.create<arith::MulFOp>(op.getLoc(), ValueRange({x, x}));
88     rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square}));
89     return success();
90   }
91 
92   // Replace `pow(x, -1.0)` with `1.0 / x`.
93   if (isExponentValue(-1.0)) {
94     Value one = rewriter.create<arith::ConstantOp>(
95         loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
96     rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x}));
97     return success();
98   }
99 
100   // Replace `pow(x, 0.5)` with `sqrt(x)`.
101   if (isExponentValue(0.5)) {
102     rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x);
103     return success();
104   }
105 
106   // Replace `pow(x, -0.5)` with `rsqrt(x)`.
107   if (isExponentValue(-0.5)) {
108     rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x);
109     return success();
110   }
111 
112   return failure();
113 }
114 
115 //----------------------------------------------------------------------------//
116 
117 void mlir::populateMathAlgebraicSimplificationPatterns(
118     RewritePatternSet &patterns) {
119   patterns.add<PowFStrengthReduction>(patterns.getContext());
120 }
121