1 //===- PolynomialApproximation.cpp - Approximate math operations ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements expansion of math operations to fast approximations
10 // that do not rely on any of the library functions.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Math/IR/Math.h"
15 #include "mlir/Dialect/Math/Transforms/Passes.h"
16 #include "mlir/Dialect/Vector/VectorOps.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/Transforms/DialectConversion.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 
21 using namespace mlir;
22 using namespace mlir::vector;
23 
24 static bool isValidFloatType(Type type) {
25   if (auto vectorType = type.dyn_cast<VectorType>())
26     return vectorType.getElementType().isa<FloatType>();
27   return type.isa<FloatType>();
28 }
29 
30 //----------------------------------------------------------------------------//
// A PatternRewriter wrapper that provides a concise API for building
// expansions of operations on float scalars or vectors.
33 //----------------------------------------------------------------------------//
34 
35 namespace {
36 class FloatApproximationBuilder {
37 public:
38   FloatApproximationBuilder(Location loc, Type type, PatternRewriter &rewriter);
39 
40   Value constant(double value) const;
41 
42   Value abs(Value a) const;
43   Value min(Value a, Value b) const;
44   Value max(Value a, Value b) const;
45   Value mul(Value a, Value b) const;
46   Value div(Value a, Value b) const;
47 
48   // Fused multiple-add operation: a * b + c.
49   Value madd(Value a, Value b, Value c) const;
50 
51   // Compares values `a` and `b` with the given `predicate`.
52   Value cmp(CmpFPredicate predicate, Value a, Value b) const;
53 
54   // Selects values from `a` or `b` based on the `predicate`.
55   Value select(Value predicate, Value a, Value b) const;
56 
57 private:
58   Location loc;
59   PatternRewriter &rewriter;
60   VectorType vectorType; // can be null for scalar type
61   FloatType elementType;
62 };
63 } // namespace
64 
65 FloatApproximationBuilder::FloatApproximationBuilder(Location loc, Type type,
66                                                      PatternRewriter &rewriter)
67     : loc(loc), rewriter(rewriter) {
68   vectorType = type.dyn_cast<VectorType>();
69 
70   if (vectorType)
71     elementType = vectorType.getElementType().cast<FloatType>();
72   else
73     elementType = type.cast<FloatType>();
74 }
75 
76 Value FloatApproximationBuilder::constant(double value) const {
77   auto attr = rewriter.getFloatAttr(elementType, value);
78   Value scalar = rewriter.create<ConstantOp>(loc, attr);
79 
80   if (vectorType)
81     return rewriter.create<BroadcastOp>(loc, vectorType, scalar);
82   return scalar;
83 }
84 
85 Value FloatApproximationBuilder::abs(Value a) const {
86   return rewriter.create<AbsFOp>(loc, a);
87 }
88 
89 Value FloatApproximationBuilder::min(Value a, Value b) const {
90   return select(cmp(CmpFPredicate::OLT, a, b), a, b);
91 }
92 Value FloatApproximationBuilder::max(Value a, Value b) const {
93   return select(cmp(CmpFPredicate::OGT, a, b), a, b);
94 }
95 Value FloatApproximationBuilder::mul(Value a, Value b) const {
96   return rewriter.create<MulFOp>(loc, a, b);
97 }
98 
99 Value FloatApproximationBuilder::div(Value a, Value b) const {
100   return rewriter.create<DivFOp>(loc, a, b);
101 }
102 
103 Value FloatApproximationBuilder::madd(Value a, Value b, Value c) const {
104   return rewriter.create<FmaFOp>(loc, a, b, c);
105 }
106 
107 Value FloatApproximationBuilder::cmp(CmpFPredicate predicate, Value a,
108                                      Value b) const {
109   return rewriter.create<CmpFOp>(loc, predicate, a, b);
110 }
111 
112 Value FloatApproximationBuilder::select(Value predicate, Value a,
113                                         Value b) const {
114   return rewriter.create<SelectOp>(loc, predicate, a, b);
115 }
116 
117 //----------------------------------------------------------------------------//
118 // TanhOp approximation.
119 //----------------------------------------------------------------------------//
120 
121 namespace {
122 struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
123 public:
124   using OpRewritePattern::OpRewritePattern;
125 
126   LogicalResult matchAndRewrite(math::TanhOp op,
127                                 PatternRewriter &rewriter) const final;
128 };
129 } // namespace
130 
131 LogicalResult
132 TanhApproximation::matchAndRewrite(math::TanhOp op,
133                                    PatternRewriter &rewriter) const {
134   if (!isValidFloatType(op.operand().getType()))
135     return rewriter.notifyMatchFailure(op, "unsupported operand type");
136 
137   Value operand = op.operand();
138   FloatApproximationBuilder builder(op->getLoc(), operand.getType(), rewriter);
139 
140   // Clamp operand into [plusClamp, minusClamp] range.
141   Value plusClamp = builder.constant(7.90531110763549805);
142   Value minusClamp = builder.constant(-7.9053111076354980);
143   Value x = builder.max(builder.min(operand, plusClamp), minusClamp);
144 
145   // Mask for tiny values that are approximated with `operand`.
146   Value tiny = builder.constant(0.0004f);
147   Value tinyMask = builder.cmp(CmpFPredicate::OLT, builder.abs(operand), tiny);
148 
149   // The monomial coefficients of the numerator polynomial (odd).
150   Value alpha1 = builder.constant(4.89352455891786e-03);
151   Value alpha3 = builder.constant(6.37261928875436e-04);
152   Value alpha5 = builder.constant(1.48572235717979e-05);
153   Value alpha7 = builder.constant(5.12229709037114e-08);
154   Value alpha9 = builder.constant(-8.60467152213735e-11);
155   Value alpha11 = builder.constant(2.00018790482477e-13);
156   Value alpha13 = builder.constant(-2.76076847742355e-16);
157 
158   // The monomial coefficients of the denominator polynomial (even).
159   Value beta0 = builder.constant(4.89352518554385e-03);
160   Value beta2 = builder.constant(2.26843463243900e-03);
161   Value beta4 = builder.constant(1.18534705686654e-04);
162   Value beta6 = builder.constant(1.19825839466702e-06);
163 
164   // Since the polynomials are odd/even, we need x^2.
165   Value x2 = builder.mul(x, x);
166 
167   // Evaluate the numerator polynomial p.
168   Value p = builder.madd(x2, alpha13, alpha11);
169   p = builder.madd(x2, p, alpha9);
170   p = builder.madd(x2, p, alpha7);
171   p = builder.madd(x2, p, alpha5);
172   p = builder.madd(x2, p, alpha3);
173   p = builder.madd(x2, p, alpha1);
174   p = builder.mul(x, p);
175 
176   // Evaluate the denominator polynomial q.
177   Value q = builder.madd(x2, beta6, beta4);
178   q = builder.madd(x2, q, beta2);
179   q = builder.madd(x2, q, beta0);
180 
181   // Divide the numerator by the denominator.
182   Value res = builder.select(tinyMask, x, builder.div(p, q));
183 
184   rewriter.replaceOp(op, res);
185 
186   return success();
187 }
188 
189 //----------------------------------------------------------------------------//
190 
191 void mlir::populateMathPolynomialApproximationPatterns(
192     OwningRewritePatternList &patterns, MLIRContext *ctx) {
193   patterns.insert<TanhApproximation>(ctx);
194 }
195