1 //===- PolynomialApproximation.cpp - Approximate math operations ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements expansion of math operations to fast approximations 10 // that do not rely on any of the library functions. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Math/IR/Math.h" 15 #include "mlir/Dialect/Math/Transforms/Passes.h" 16 #include "mlir/Dialect/Vector/VectorOps.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/Transforms/DialectConversion.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 21 using namespace mlir; 22 using namespace mlir::vector; 23 24 static bool isValidFloatType(Type type) { 25 if (auto vectorType = type.dyn_cast<VectorType>()) 26 return vectorType.getElementType().isa<FloatType>(); 27 return type.isa<FloatType>(); 28 } 29 30 //----------------------------------------------------------------------------// 31 // A PatternRewriter wrapper that provides concise API for building expansions 32 // for operations on float scalars or vectors. 33 //----------------------------------------------------------------------------// 34 35 namespace { 36 class FloatApproximationBuilder { 37 public: 38 FloatApproximationBuilder(Location loc, Type type, PatternRewriter &rewriter); 39 40 Value constant(double value) const; 41 42 Value abs(Value a) const; 43 Value min(Value a, Value b) const; 44 Value max(Value a, Value b) const; 45 Value mul(Value a, Value b) const; 46 Value div(Value a, Value b) const; 47 48 // Fused multiple-add operation: a * b + c. 49 Value madd(Value a, Value b, Value c) const; 50 51 // Compares values `a` and `b` with the given `predicate`. 52 Value cmp(CmpFPredicate predicate, Value a, Value b) const; 53 54 // Selects values from `a` or `b` based on the `predicate`. 55 Value select(Value predicate, Value a, Value b) const; 56 57 private: 58 Location loc; 59 PatternRewriter &rewriter; 60 VectorType vectorType; // can be null for scalar type 61 FloatType elementType; 62 }; 63 } // namespace 64 65 FloatApproximationBuilder::FloatApproximationBuilder(Location loc, Type type, 66 PatternRewriter &rewriter) 67 : loc(loc), rewriter(rewriter) { 68 vectorType = type.dyn_cast<VectorType>(); 69 70 if (vectorType) 71 elementType = vectorType.getElementType().cast<FloatType>(); 72 else 73 elementType = type.cast<FloatType>(); 74 } 75 76 Value FloatApproximationBuilder::constant(double value) const { 77 auto attr = rewriter.getFloatAttr(elementType, value); 78 Value scalar = rewriter.create<ConstantOp>(loc, attr); 79 80 if (vectorType) 81 return rewriter.create<BroadcastOp>(loc, vectorType, scalar); 82 return scalar; 83 } 84 85 Value FloatApproximationBuilder::abs(Value a) const { 86 return rewriter.create<AbsFOp>(loc, a); 87 } 88 89 Value FloatApproximationBuilder::min(Value a, Value b) const { 90 return select(cmp(CmpFPredicate::OLT, a, b), a, b); 91 } 92 Value FloatApproximationBuilder::max(Value a, Value b) const { 93 return select(cmp(CmpFPredicate::OGT, a, b), a, b); 94 } 95 Value FloatApproximationBuilder::mul(Value a, Value b) const { 96 return rewriter.create<MulFOp>(loc, a, b); 97 } 98 99 Value FloatApproximationBuilder::div(Value a, Value b) const { 100 return rewriter.create<DivFOp>(loc, a, b); 101 } 102 103 Value FloatApproximationBuilder::madd(Value a, Value b, Value c) const { 104 return rewriter.create<FmaFOp>(loc, a, b, c); 105 } 106 107 Value FloatApproximationBuilder::cmp(CmpFPredicate predicate, Value a, 108 Value b) const { 109 return rewriter.create<CmpFOp>(loc, predicate, a, b); 110 } 111 112 Value FloatApproximationBuilder::select(Value predicate, Value a, 113 Value b) const { 114 return rewriter.create<SelectOp>(loc, predicate, a, b); 115 } 116 117 //----------------------------------------------------------------------------// 118 // TanhOp approximation. 119 //----------------------------------------------------------------------------// 120 121 namespace { 122 struct TanhApproximation : public OpRewritePattern<math::TanhOp> { 123 public: 124 using OpRewritePattern::OpRewritePattern; 125 126 LogicalResult matchAndRewrite(math::TanhOp op, 127 PatternRewriter &rewriter) const final; 128 }; 129 } // namespace 130 131 LogicalResult 132 TanhApproximation::matchAndRewrite(math::TanhOp op, 133 PatternRewriter &rewriter) const { 134 if (!isValidFloatType(op.operand().getType())) 135 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 136 137 Value operand = op.operand(); 138 FloatApproximationBuilder builder(op->getLoc(), operand.getType(), rewriter); 139 140 // Clamp operand into [plusClamp, minusClamp] range. 141 Value plusClamp = builder.constant(7.90531110763549805); 142 Value minusClamp = builder.constant(-7.9053111076354980); 143 Value x = builder.max(builder.min(operand, plusClamp), minusClamp); 144 145 // Mask for tiny values that are approximated with `operand`. 146 Value tiny = builder.constant(0.0004f); 147 Value tinyMask = builder.cmp(CmpFPredicate::OLT, builder.abs(operand), tiny); 148 149 // The monomial coefficients of the numerator polynomial (odd). 150 Value alpha1 = builder.constant(4.89352455891786e-03); 151 Value alpha3 = builder.constant(6.37261928875436e-04); 152 Value alpha5 = builder.constant(1.48572235717979e-05); 153 Value alpha7 = builder.constant(5.12229709037114e-08); 154 Value alpha9 = builder.constant(-8.60467152213735e-11); 155 Value alpha11 = builder.constant(2.00018790482477e-13); 156 Value alpha13 = builder.constant(-2.76076847742355e-16); 157 158 // The monomial coefficients of the denominator polynomial (even). 159 Value beta0 = builder.constant(4.89352518554385e-03); 160 Value beta2 = builder.constant(2.26843463243900e-03); 161 Value beta4 = builder.constant(1.18534705686654e-04); 162 Value beta6 = builder.constant(1.19825839466702e-06); 163 164 // Since the polynomials are odd/even, we need x^2. 165 Value x2 = builder.mul(x, x); 166 167 // Evaluate the numerator polynomial p. 168 Value p = builder.madd(x2, alpha13, alpha11); 169 p = builder.madd(x2, p, alpha9); 170 p = builder.madd(x2, p, alpha7); 171 p = builder.madd(x2, p, alpha5); 172 p = builder.madd(x2, p, alpha3); 173 p = builder.madd(x2, p, alpha1); 174 p = builder.mul(x, p); 175 176 // Evaluate the denominator polynomial q. 177 Value q = builder.madd(x2, beta6, beta4); 178 q = builder.madd(x2, q, beta2); 179 q = builder.madd(x2, q, beta0); 180 181 // Divide the numerator by the denominator. 182 Value res = builder.select(tinyMask, x, builder.div(p, q)); 183 184 rewriter.replaceOp(op, res); 185 186 return success(); 187 } 188 189 //----------------------------------------------------------------------------// 190 191 void mlir::populateMathPolynomialApproximationPatterns( 192 OwningRewritePatternList &patterns, MLIRContext *ctx) { 193 patterns.insert<TanhApproximation>(ctx); 194 } 195