1 //===- AlgebraicSimplification.cpp - Simplify algebraic expressions -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements rewrites based on the basic rules of algebra 10 // (Commutativity, associativity, etc...) and strength reductions for math 11 // operations. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 16 #include "mlir/Dialect/Math/IR/Math.h" 17 #include "mlir/Dialect/Math/Transforms/Passes.h" 18 #include "mlir/Dialect/Vector/IR/VectorOps.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/Matchers.h" 21 #include "mlir/IR/TypeUtilities.h" 22 #include <climits> 23 24 using namespace mlir; 25 26 //----------------------------------------------------------------------------// 27 // PowFOp strength reduction. 28 //----------------------------------------------------------------------------// 29 30 namespace { 31 struct PowFStrengthReduction : public OpRewritePattern<math::PowFOp> { 32 public: 33 using OpRewritePattern::OpRewritePattern; 34 35 LogicalResult matchAndRewrite(math::PowFOp op, 36 PatternRewriter &rewriter) const final; 37 }; 38 } // namespace 39 40 LogicalResult 41 PowFStrengthReduction::matchAndRewrite(math::PowFOp op, 42 PatternRewriter &rewriter) const { 43 Location loc = op.getLoc(); 44 Value x = op.getLhs(); 45 46 FloatAttr scalarExponent; 47 DenseFPElementsAttr vectorExponent; 48 49 bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent)); 50 bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent)); 51 52 // Returns true if exponent is a constant equal to `value`. 53 auto isExponentValue = [&](double value) -> bool { 54 if (isScalar) 55 return scalarExponent.getValue().isExactlyValue(value); 56 57 if (isVector && vectorExponent.isSplat()) 58 return vectorExponent.getSplatValue<FloatAttr>() 59 .getValue() 60 .isExactlyValue(value); 61 62 return false; 63 }; 64 65 // Maybe broadcasts scalar value into vector type compatible with `op`. 66 auto bcast = [&](Value value) -> Value { 67 if (auto vec = op.getType().dyn_cast<VectorType>()) 68 return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value); 69 return value; 70 }; 71 72 // Replace `pow(x, 1.0)` with `x`. 73 if (isExponentValue(1.0)) { 74 rewriter.replaceOp(op, x); 75 return success(); 76 } 77 78 // Replace `pow(x, 2.0)` with `x * x`. 79 if (isExponentValue(2.0)) { 80 rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, x})); 81 return success(); 82 } 83 84 // Replace `pow(x, 3.0)` with `x * x * x`. 85 if (isExponentValue(3.0)) { 86 Value square = 87 rewriter.create<arith::MulFOp>(op.getLoc(), ValueRange({x, x})); 88 rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square})); 89 return success(); 90 } 91 92 // Replace `pow(x, -1.0)` with `1.0 / x`. 93 if (isExponentValue(-1.0)) { 94 Value one = rewriter.create<arith::ConstantOp>( 95 loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0)); 96 rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x})); 97 return success(); 98 } 99 100 // Replace `pow(x, 0.5)` with `sqrt(x)`. 101 if (isExponentValue(0.5)) { 102 rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x); 103 return success(); 104 } 105 106 // Replace `pow(x, -0.5)` with `rsqrt(x)`. 107 if (isExponentValue(-0.5)) { 108 rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x); 109 return success(); 110 } 111 112 return failure(); 113 } 114 115 //----------------------------------------------------------------------------// 116 117 void mlir::populateMathAlgebraicSimplificationPatterns( 118 RewritePatternSet &patterns) { 119 patterns.add<PowFStrengthReduction>(patterns.getContext()); 120 } 121