1 //===- AlgebraicSimplification.cpp - Simplify algebraic expressions -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements rewrites based on the basic rules of algebra 10 // (Commutativity, associativity, etc...) and strength reductions for math 11 // operations. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Dialect/Math/IR/Math.h" 16 #include "mlir/Dialect/Math/Transforms/Passes.h" 17 #include "mlir/Dialect/Vector/VectorOps.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/Matchers.h" 20 #include "mlir/IR/TypeUtilities.h" 21 #include <climits> 22 23 using namespace mlir; 24 25 //----------------------------------------------------------------------------// 26 // PowFOp strength reduction. 27 //----------------------------------------------------------------------------// 28 29 namespace { 30 struct PowFStrengthReduction : public OpRewritePattern<math::PowFOp> { 31 public: 32 using OpRewritePattern::OpRewritePattern; 33 34 LogicalResult matchAndRewrite(math::PowFOp op, 35 PatternRewriter &rewriter) const final; 36 }; 37 } // namespace 38 39 LogicalResult 40 PowFStrengthReduction::matchAndRewrite(math::PowFOp op, 41 PatternRewriter &rewriter) const { 42 Location loc = op.getLoc(); 43 Value x = op.lhs(); 44 45 FloatAttr scalarExponent; 46 DenseFPElementsAttr vectorExponent; 47 48 bool isScalar = matchPattern(op.rhs(), m_Constant(&scalarExponent)); 49 bool isVector = matchPattern(op.rhs(), m_Constant(&vectorExponent)); 50 51 // Returns true if exponent is a constant equal to `value`. 52 auto isExponentValue = [&](double value) -> bool { 53 if (isScalar) 54 return scalarExponent.getValue().isExactlyValue(value); 55 56 if (isVector && vectorExponent.isSplat()) 57 return vectorExponent.getSplatValue<FloatAttr>() 58 .getValue() 59 .isExactlyValue(value); 60 61 return false; 62 }; 63 64 // Maybe broadcasts scalar value into vector type compatible with `op`. 65 auto bcast = [&](Value value) -> Value { 66 if (auto vec = op.getType().dyn_cast<VectorType>()) 67 return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value); 68 return value; 69 }; 70 71 // Replace `pow(x, 1.0)` with `x`. 72 if (isExponentValue(1.0)) { 73 rewriter.replaceOp(op, x); 74 return success(); 75 } 76 77 // Replace `pow(x, 2.0)` with `x * x`. 78 if (isExponentValue(2.0)) { 79 rewriter.replaceOpWithNewOp<MulFOp>(op, ValueRange({x, x})); 80 return success(); 81 } 82 83 // Replace `pow(x, 3.0)` with `x * x * x`. 84 if (isExponentValue(3.0)) { 85 Value square = rewriter.create<MulFOp>(op.getLoc(), ValueRange({x, x})); 86 rewriter.replaceOpWithNewOp<MulFOp>(op, ValueRange({x, square})); 87 return success(); 88 } 89 90 // Replace `pow(x, -1.0)` with `1.0 / x`. 91 if (isExponentValue(-1.0)) { 92 Value one = rewriter.create<ConstantOp>( 93 loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0)); 94 rewriter.replaceOpWithNewOp<DivFOp>(op, ValueRange({bcast(one), x})); 95 return success(); 96 } 97 98 // Replace `pow(x, 0.5)` with `sqrt(x)`. 99 if (isExponentValue(0.5)) { 100 rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x); 101 return success(); 102 } 103 104 // Replace `pow(x, -0.5)` with `rsqrt(x)`. 105 if (isExponentValue(-0.5)) { 106 rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x); 107 return success(); 108 } 109 110 return failure(); 111 } 112 113 //----------------------------------------------------------------------------// 114 115 void mlir::populateMathAlgebraicSimplificationPatterns( 116 RewritePatternSet &patterns) { 117 patterns.add<PowFStrengthReduction>(patterns.getContext()); 118 } 119