1d94426d2SEugene Zhulenev //===- AlgebraicSimplification.cpp - Simplify algebraic expressions -------===// 2d94426d2SEugene Zhulenev // 3d94426d2SEugene Zhulenev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4d94426d2SEugene Zhulenev // See https://llvm.org/LICENSE.txt for license information. 5d94426d2SEugene Zhulenev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6d94426d2SEugene Zhulenev // 7d94426d2SEugene Zhulenev //===----------------------------------------------------------------------===// 8d94426d2SEugene Zhulenev // 9d94426d2SEugene Zhulenev // This file implements rewrites based on the basic rules of algebra 10d94426d2SEugene Zhulenev // (Commutativity, associativity, etc...) and strength reductions for math 11d94426d2SEugene Zhulenev // operations. 12d94426d2SEugene Zhulenev // 13d94426d2SEugene Zhulenev //===----------------------------------------------------------------------===// 14d94426d2SEugene Zhulenev 15*a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 16d94426d2SEugene Zhulenev #include "mlir/Dialect/Math/IR/Math.h" 17d94426d2SEugene Zhulenev #include "mlir/Dialect/Math/Transforms/Passes.h" 18d94426d2SEugene Zhulenev #include "mlir/Dialect/Vector/VectorOps.h" 19d94426d2SEugene Zhulenev #include "mlir/IR/Builders.h" 20d94426d2SEugene Zhulenev #include "mlir/IR/Matchers.h" 21d94426d2SEugene Zhulenev #include "mlir/IR/TypeUtilities.h" 22d94426d2SEugene Zhulenev #include <climits> 23d94426d2SEugene Zhulenev 24d94426d2SEugene Zhulenev using namespace mlir; 25d94426d2SEugene Zhulenev 26d94426d2SEugene Zhulenev //----------------------------------------------------------------------------// 27d94426d2SEugene Zhulenev // PowFOp strength reduction. 28d94426d2SEugene Zhulenev //----------------------------------------------------------------------------// 29d94426d2SEugene Zhulenev 30d94426d2SEugene Zhulenev namespace { 31d94426d2SEugene Zhulenev struct PowFStrengthReduction : public OpRewritePattern<math::PowFOp> { 32d94426d2SEugene Zhulenev public: 33d94426d2SEugene Zhulenev using OpRewritePattern::OpRewritePattern; 34d94426d2SEugene Zhulenev 35d94426d2SEugene Zhulenev LogicalResult matchAndRewrite(math::PowFOp op, 36d94426d2SEugene Zhulenev PatternRewriter &rewriter) const final; 37d94426d2SEugene Zhulenev }; 38d94426d2SEugene Zhulenev } // namespace 39d94426d2SEugene Zhulenev 40d94426d2SEugene Zhulenev LogicalResult 41d94426d2SEugene Zhulenev PowFStrengthReduction::matchAndRewrite(math::PowFOp op, 42d94426d2SEugene Zhulenev PatternRewriter &rewriter) const { 43d94426d2SEugene Zhulenev Location loc = op.getLoc(); 44d94426d2SEugene Zhulenev Value x = op.lhs(); 45d94426d2SEugene Zhulenev 46d94426d2SEugene Zhulenev FloatAttr scalarExponent; 47d94426d2SEugene Zhulenev DenseFPElementsAttr vectorExponent; 48d94426d2SEugene Zhulenev 49d94426d2SEugene Zhulenev bool isScalar = matchPattern(op.rhs(), m_Constant(&scalarExponent)); 50d94426d2SEugene Zhulenev bool isVector = matchPattern(op.rhs(), m_Constant(&vectorExponent)); 51d94426d2SEugene Zhulenev 52d94426d2SEugene Zhulenev // Returns true if exponent is a constant equal to `value`. 53d94426d2SEugene Zhulenev auto isExponentValue = [&](double value) -> bool { 54d94426d2SEugene Zhulenev if (isScalar) 55d94426d2SEugene Zhulenev return scalarExponent.getValue().isExactlyValue(value); 56d94426d2SEugene Zhulenev 57d94426d2SEugene Zhulenev if (isVector && vectorExponent.isSplat()) 58d94426d2SEugene Zhulenev return vectorExponent.getSplatValue<FloatAttr>() 59d94426d2SEugene Zhulenev .getValue() 60d94426d2SEugene Zhulenev .isExactlyValue(value); 61d94426d2SEugene Zhulenev 62d94426d2SEugene Zhulenev return false; 63d94426d2SEugene Zhulenev }; 64d94426d2SEugene Zhulenev 65d94426d2SEugene Zhulenev // Maybe broadcasts scalar value into vector type compatible with `op`. 66d94426d2SEugene Zhulenev auto bcast = [&](Value value) -> Value { 67d94426d2SEugene Zhulenev if (auto vec = op.getType().dyn_cast<VectorType>()) 68d94426d2SEugene Zhulenev return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value); 69d94426d2SEugene Zhulenev return value; 70d94426d2SEugene Zhulenev }; 71d94426d2SEugene Zhulenev 72d94426d2SEugene Zhulenev // Replace `pow(x, 1.0)` with `x`. 73d94426d2SEugene Zhulenev if (isExponentValue(1.0)) { 74d94426d2SEugene Zhulenev rewriter.replaceOp(op, x); 75d94426d2SEugene Zhulenev return success(); 76d94426d2SEugene Zhulenev } 77d94426d2SEugene Zhulenev 78d94426d2SEugene Zhulenev // Replace `pow(x, 2.0)` with `x * x`. 79d94426d2SEugene Zhulenev if (isExponentValue(2.0)) { 80*a54f4eaeSMogball rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, x})); 81d94426d2SEugene Zhulenev return success(); 82d94426d2SEugene Zhulenev } 83d94426d2SEugene Zhulenev 84391456f3Sbakhtiyar // Replace `pow(x, 3.0)` with `x * x * x`. 85d94426d2SEugene Zhulenev if (isExponentValue(3.0)) { 86*a54f4eaeSMogball Value square = 87*a54f4eaeSMogball rewriter.create<arith::MulFOp>(op.getLoc(), ValueRange({x, x})); 88*a54f4eaeSMogball rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square})); 89d94426d2SEugene Zhulenev return success(); 90d94426d2SEugene Zhulenev } 91d94426d2SEugene Zhulenev 92d94426d2SEugene Zhulenev // Replace `pow(x, -1.0)` with `1.0 / x`. 93d94426d2SEugene Zhulenev if (isExponentValue(-1.0)) { 94*a54f4eaeSMogball Value one = rewriter.create<arith::ConstantOp>( 95d94426d2SEugene Zhulenev loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0)); 96*a54f4eaeSMogball rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x})); 97d94426d2SEugene Zhulenev return success(); 98d94426d2SEugene Zhulenev } 99d94426d2SEugene Zhulenev 100391456f3Sbakhtiyar // Replace `pow(x, 0.5)` with `sqrt(x)`. 101391456f3Sbakhtiyar if (isExponentValue(0.5)) { 102d94426d2SEugene Zhulenev rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x); 103d94426d2SEugene Zhulenev return success(); 104d94426d2SEugene Zhulenev } 105d94426d2SEugene Zhulenev 106391456f3Sbakhtiyar // Replace `pow(x, -0.5)` with `rsqrt(x)`. 107391456f3Sbakhtiyar if (isExponentValue(-0.5)) { 108391456f3Sbakhtiyar rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x); 109391456f3Sbakhtiyar return success(); 110391456f3Sbakhtiyar } 111391456f3Sbakhtiyar 112d94426d2SEugene Zhulenev return failure(); 113d94426d2SEugene Zhulenev } 114d94426d2SEugene Zhulenev 115d94426d2SEugene Zhulenev //----------------------------------------------------------------------------// 116d94426d2SEugene Zhulenev 117d94426d2SEugene Zhulenev void mlir::populateMathAlgebraicSimplificationPatterns( 118d94426d2SEugene Zhulenev RewritePatternSet &patterns) { 119d94426d2SEugene Zhulenev patterns.add<PowFStrengthReduction>(patterns.getContext()); 120d94426d2SEugene Zhulenev } 121