1 //===- AlgebraicSimplification.cpp - Simplify algebraic expressions -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewrites based on the basic rules of algebra
10 // (Commutativity, associativity, etc...) and strength reductions for math
11 // operations.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Math/IR/Math.h"
17 #include "mlir/Dialect/Math/Transforms/Passes.h"
18 #include "mlir/Dialect/Vector/IR/VectorOps.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include <climits>
23
24 using namespace mlir;
25
26 //----------------------------------------------------------------------------//
27 // PowFOp strength reduction.
28 //----------------------------------------------------------------------------//
29
30 namespace {
31 struct PowFStrengthReduction : public OpRewritePattern<math::PowFOp> {
32 public:
33 using OpRewritePattern::OpRewritePattern;
34
35 LogicalResult matchAndRewrite(math::PowFOp op,
36 PatternRewriter &rewriter) const final;
37 };
38 } // namespace
39
40 LogicalResult
matchAndRewrite(math::PowFOp op,PatternRewriter & rewriter) const41 PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
42 PatternRewriter &rewriter) const {
43 Location loc = op.getLoc();
44 Value x = op.getLhs();
45
46 FloatAttr scalarExponent;
47 DenseFPElementsAttr vectorExponent;
48
49 bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
50 bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));
51
52 // Returns true if exponent is a constant equal to `value`.
53 auto isExponentValue = [&](double value) -> bool {
54 if (isScalar)
55 return scalarExponent.getValue().isExactlyValue(value);
56
57 if (isVector && vectorExponent.isSplat())
58 return vectorExponent.getSplatValue<FloatAttr>()
59 .getValue()
60 .isExactlyValue(value);
61
62 return false;
63 };
64
65 // Maybe broadcasts scalar value into vector type compatible with `op`.
66 auto bcast = [&](Value value) -> Value {
67 if (auto vec = op.getType().dyn_cast<VectorType>())
68 return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value);
69 return value;
70 };
71
72 // Replace `pow(x, 1.0)` with `x`.
73 if (isExponentValue(1.0)) {
74 rewriter.replaceOp(op, x);
75 return success();
76 }
77
78 // Replace `pow(x, 2.0)` with `x * x`.
79 if (isExponentValue(2.0)) {
80 rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, x}));
81 return success();
82 }
83
84 // Replace `pow(x, 3.0)` with `x * x * x`.
85 if (isExponentValue(3.0)) {
86 Value square =
87 rewriter.create<arith::MulFOp>(op.getLoc(), ValueRange({x, x}));
88 rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square}));
89 return success();
90 }
91
92 // Replace `pow(x, -1.0)` with `1.0 / x`.
93 if (isExponentValue(-1.0)) {
94 Value one = rewriter.create<arith::ConstantOp>(
95 loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
96 rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x}));
97 return success();
98 }
99
100 // Replace `pow(x, 0.5)` with `sqrt(x)`.
101 if (isExponentValue(0.5)) {
102 rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x);
103 return success();
104 }
105
106 // Replace `pow(x, -0.5)` with `rsqrt(x)`.
107 if (isExponentValue(-0.5)) {
108 rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x);
109 return success();
110 }
111
112 return failure();
113 }
114
115 //----------------------------------------------------------------------------//
116
populateMathAlgebraicSimplificationPatterns(RewritePatternSet & patterns)117 void mlir::populateMathAlgebraicSimplificationPatterns(
118 RewritePatternSet &patterns) {
119 patterns.add<PowFStrengthReduction>(patterns.getContext());
120 }
121