1 //===- PolynomialApproximation.cpp - Approximate math operations ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements expansion of math operations to fast approximations
10 // that do not rely on any of the library functions.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15 #include "mlir/Dialect/Math/IR/Math.h"
16 #include "mlir/Dialect/Math/Transforms/Approximation.h"
17 #include "mlir/Dialect/Math/Transforms/Passes.h"
18 #include "mlir/Dialect/Vector/VectorOps.h"
19 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/ImplicitLocOpBuilder.h"
22 #include "mlir/Transforms/Bufferize.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25 #include "llvm/ADT/ArrayRef.h"
26 #include <climits>
27 #include <cstddef>
28 
29 using namespace mlir;
30 using namespace mlir::math;
31 using namespace mlir::vector;
32 
33 using TypePredicate = llvm::function_ref<bool(Type)>;
34 
35 // Returns vector shape if the element type is matching the predicate (scalars
36 // that do match the predicate have shape equal to `{1}`).
37 static Optional<SmallVector<int64_t, 2>> vectorShape(Type type,
38                                                      TypePredicate pred) {
39   // If the type matches the predicate then its shape is `{1}`.
40   if (pred(type))
41     return SmallVector<int64_t, 2>{1};
42 
43   // Otherwise check if the type is a vector type.
44   auto vectorType = type.dyn_cast<VectorType>();
45   if (vectorType && pred(vectorType.getElementType())) {
46     return llvm::to_vector<2>(vectorType.getShape());
47   }
48 
49   return llvm::None;
50 }
51 
52 // Returns vector shape of the type. If the type is a scalar returns `1`.
53 static SmallVector<int64_t, 2> vectorShape(Type type) {
54   auto vectorType = type.dyn_cast<VectorType>();
55   return vectorType ? llvm::to_vector<2>(vectorType.getShape())
56                     : SmallVector<int64_t, 2>{1};
57 }
58 
59 // Returns vector element type. If the type is a scalar returns the argument.
60 LLVM_ATTRIBUTE_UNUSED static Type elementType(Type type) {
61   auto vectorType = type.dyn_cast<VectorType>();
62   return vectorType ? vectorType.getElementType() : type;
63 }
64 
65 LLVM_ATTRIBUTE_UNUSED static bool isF32(Type type) { return type.isF32(); }
66 
67 LLVM_ATTRIBUTE_UNUSED static bool isI32(Type type) {
68   return type.isInteger(32);
69 }
70 
71 //----------------------------------------------------------------------------//
72 // Broadcast scalar types and values into vector types and values.
73 //----------------------------------------------------------------------------//
74 
75 // Returns true if shape != {1}.
76 static bool isNonScalarShape(ArrayRef<int64_t> shape) {
77   return shape.size() > 1 || shape[0] > 1;
78 }
79 
80 // Broadcasts scalar type into vector type (iff shape is non-scalar).
81 static Type broadcast(Type type, ArrayRef<int64_t> shape) {
82   assert(!type.isa<VectorType>() && "must be scalar type");
83   return isNonScalarShape(shape) ? VectorType::get(shape, type) : type;
84 }
85 
86 // Broadcasts scalar value into vector (iff shape is non-scalar).
87 static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
88                        ArrayRef<int64_t> shape) {
89   assert(!value.getType().isa<VectorType>() && "must be scalar value");
90   auto type = broadcast(value.getType(), shape);
91   return isNonScalarShape(shape) ? builder.create<BroadcastOp>(type, value)
92                                  : value;
93 }
94 
95 //----------------------------------------------------------------------------//
96 // Helper functions to create constants.
97 //----------------------------------------------------------------------------//
98 
99 static Value f32Cst(ImplicitLocOpBuilder &builder, float value) {
100   return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value));
101 }
102 
103 static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) {
104   return builder.create<arith::ConstantOp>(builder.getI32IntegerAttr(value));
105 }
106 
107 static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) {
108   Value i32Value = i32Cst(builder, static_cast<int32_t>(bits));
109   return builder.create<arith::BitcastOp>(builder.getF32Type(), i32Value);
110 }
111 
112 //----------------------------------------------------------------------------//
113 // Helper functions to build math functions approximations.
114 //----------------------------------------------------------------------------//
115 
116 static Value min(ImplicitLocOpBuilder &builder, Value a, Value b) {
117   return builder.create<SelectOp>(
118       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, a, b), a, b);
119 }
120 
121 static Value max(ImplicitLocOpBuilder &builder, Value a, Value b) {
122   return builder.create<SelectOp>(
123       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, a, b), a, b);
124 }
125 
126 static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
127                    Value upperBound) {
128   return max(builder, min(builder, value, upperBound), lowerBound);
129 }
130 
// Decomposes given floating point value `arg` into a normalized fraction and
// an integral power of two (see std::frexp). Returned values have float type.
//
// `arg` must be f32 or a vector of f32 (asserted). The decomposition works
// directly on the IEEE-754 bit pattern:
//   - fraction: the exponent field of `arg` is replaced by the exponent field
//     of 0.5 while the sign and mantissa bits are kept, so for normal inputs
//     the magnitude of the fraction lies in [0.5, 1);
//   - exponent: the biased exponent field of |arg| minus 126, converted to
//     f32, so that arg == fraction * 2^exponent for normal inputs.
// Zero, denormal, infinite and NaN inputs are not special-cased here.
//
// When `is_positive` is true the caller guarantees `arg` is non-negative and
// the absolute value is skipped; a negative `arg` would then leak its sign
// bit into the extracted exponent.
static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
                                     bool is_positive = false) {
  assert(isF32(elementType(arg.getType())) && "argument must be f32 type");

  auto shape = vectorShape(arg.getType());

  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  auto i32 = builder.getIntegerType(32);
  auto i32Vec = broadcast(i32, shape);
  auto f32Vec = broadcast(builder.getF32Type(), shape);

  // 126 == f32 exponent bias (127) - 1, because the fraction is normalized to
  // [0.5, 1) rather than [1, 2).
  Value cst126f = f32Cst(builder, 126.0f);
  // Only the bit pattern of 0.5 is used: it supplies the exponent field that
  // is OR-ed into the fraction.
  Value cstHalf = f32Cst(builder, 0.5f);
  // 0x7f800000 is the f32 exponent field; its complement keeps the sign and
  // mantissa bits and clears the exponent.
  Value cstInvMantMask = f32FromBits(builder, ~0x7f800000u);

  // Bitcast to i32 for bitwise operations.
  Value i32Half = builder.create<arith::BitcastOp>(i32, cstHalf);
  Value i32InvMantMask = builder.create<arith::BitcastOp>(i32, cstInvMantMask);
  Value i32Arg = builder.create<arith::BitcastOp>(i32Vec, arg);

  // Compute normalized fraction.
  Value tmp0 = builder.create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask));
  Value tmp1 = builder.create<arith::OrIOp>(tmp0, bcast(i32Half));
  Value normalizedFraction = builder.create<arith::BitcastOp>(f32Vec, tmp1);

  // Compute exponent.
  // With the sign bit clear, a logical shift right by 23 (the mantissa width)
  // leaves only the biased exponent field.
  Value arg0 = is_positive ? arg : builder.create<math::AbsOp>(arg);
  Value biasedExponentBits = builder.create<arith::ShRUIOp>(
      builder.create<arith::BitcastOp>(i32Vec, arg0),
      bcast(i32Cst(builder, 23)));
  Value biasedExponent =
      builder.create<arith::SIToFPOp>(f32Vec, biasedExponentBits);
  Value exponent =
      builder.create<arith::SubFOp>(biasedExponent, bcast(cst126f));

  return {normalizedFraction, exponent};
}
173 
174 // Computes exp2 for an i32 argument.
175 static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
176   assert(isI32(elementType(arg.getType())) && "argument must be i32 type");
177 
178   auto shape = vectorShape(arg.getType());
179 
180   auto bcast = [&](Value value) -> Value {
181     return broadcast(builder, value, shape);
182   };
183 
184   auto f32Vec = broadcast(builder.getF32Type(), shape);
185   // The exponent of f32 located at 23-bit.
186   auto exponetBitLocation = bcast(i32Cst(builder, 23));
187   // Set the exponent bias to zero.
188   auto bias = bcast(i32Cst(builder, 127));
189 
190   Value biasedArg = builder.create<arith::AddIOp>(arg, bias);
191   Value exp2ValueInt =
192       builder.create<arith::ShLIOp>(biasedArg, exponetBitLocation);
193   Value exp2ValueF32 = builder.create<arith::BitcastOp>(f32Vec, exp2ValueInt);
194 
195   return exp2ValueF32;
196 }
197 
198 namespace {
199 Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
200                                 llvm::ArrayRef<Value> coeffs, Value x) {
201   auto shape = vectorShape(x.getType(), isF32);
202   if (coeffs.size() == 0) {
203     return broadcast(builder, f32Cst(builder, 0.0f), *shape);
204   } else if (coeffs.size() == 1) {
205     return coeffs[0];
206   }
207   Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1],
208                                           coeffs[coeffs.size() - 2]);
209   for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
210     res = builder.create<math::FmaOp>(x, res, coeffs[i]);
211   }
212   return res;
213 }
214 } // namespace
215 
216 //----------------------------------------------------------------------------//
217 // TanhOp approximation.
218 //----------------------------------------------------------------------------//
219 
220 namespace {
221 struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
222 public:
223   using OpRewritePattern::OpRewritePattern;
224 
225   LogicalResult matchAndRewrite(math::TanhOp op,
226                                 PatternRewriter &rewriter) const final;
227 };
228 } // namespace
229 
230 LogicalResult
231 TanhApproximation::matchAndRewrite(math::TanhOp op,
232                                    PatternRewriter &rewriter) const {
233   auto shape = vectorShape(op.operand().getType(), isF32);
234   if (!shape.hasValue())
235     return rewriter.notifyMatchFailure(op, "unsupported operand type");
236 
237   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
238   auto bcast = [&](Value value) -> Value {
239     return broadcast(builder, value, *shape);
240   };
241 
242   // Clamp operand into [plusClamp, minusClamp] range.
243   Value minusClamp = bcast(f32Cst(builder, -7.99881172180175781f));
244   Value plusClamp = bcast(f32Cst(builder, 7.99881172180175781f));
245   Value x = clamp(builder, op.operand(), minusClamp, plusClamp);
246 
247   // Mask for tiny values that are approximated with `operand`.
248   Value tiny = bcast(f32Cst(builder, 0.0004f));
249   Value tinyMask = builder.create<arith::CmpFOp>(
250       arith::CmpFPredicate::OLT, builder.create<math::AbsOp>(op.operand()),
251       tiny);
252 
253   // The monomial coefficients of the numerator polynomial (odd).
254   Value alpha1 = bcast(f32Cst(builder, 4.89352455891786e-03f));
255   Value alpha3 = bcast(f32Cst(builder, 6.37261928875436e-04f));
256   Value alpha5 = bcast(f32Cst(builder, 1.48572235717979e-05f));
257   Value alpha7 = bcast(f32Cst(builder, 5.12229709037114e-08f));
258   Value alpha9 = bcast(f32Cst(builder, -8.60467152213735e-11f));
259   Value alpha11 = bcast(f32Cst(builder, 2.00018790482477e-13f));
260   Value alpha13 = bcast(f32Cst(builder, -2.76076847742355e-16f));
261 
262   // The monomial coefficients of the denominator polynomial (even).
263   Value beta0 = bcast(f32Cst(builder, 4.89352518554385e-03f));
264   Value beta2 = bcast(f32Cst(builder, 2.26843463243900e-03f));
265   Value beta4 = bcast(f32Cst(builder, 1.18534705686654e-04f));
266   Value beta6 = bcast(f32Cst(builder, 1.19825839466702e-06f));
267 
268   // Since the polynomials are odd/even, we need x^2.
269   Value x2 = builder.create<arith::MulFOp>(x, x);
270 
271   // Evaluate the numerator polynomial p.
272   Value p = builder.create<math::FmaOp>(x2, alpha13, alpha11);
273   p = builder.create<math::FmaOp>(x2, p, alpha9);
274   p = builder.create<math::FmaOp>(x2, p, alpha7);
275   p = builder.create<math::FmaOp>(x2, p, alpha5);
276   p = builder.create<math::FmaOp>(x2, p, alpha3);
277   p = builder.create<math::FmaOp>(x2, p, alpha1);
278   p = builder.create<arith::MulFOp>(x, p);
279 
280   // Evaluate the denominator polynomial q.
281   Value q = builder.create<math::FmaOp>(x2, beta6, beta4);
282   q = builder.create<math::FmaOp>(x2, q, beta2);
283   q = builder.create<math::FmaOp>(x2, q, beta0);
284 
285   // Divide the numerator by the denominator.
286   Value res = builder.create<SelectOp>(tinyMask, x,
287                                        builder.create<arith::DivFOp>(p, q));
288 
289   rewriter.replaceOp(op, res);
290 
291   return success();
292 }
293 
// ln(2) and log2(e) as long double literals. The uses visible in this file
// narrow them with static_cast<float> before building f32 constants.
#define LN2_VALUE                                                              \
  0.693147180559945309417232121458176568075500134360255254120680009493393621L
#define LOG2E_VALUE                                                            \
  1.442695040888963407359924681001892137426645954152985934135449406931109219L
298 
299 //----------------------------------------------------------------------------//
300 // LogOp and Log2Op approximation.
301 //----------------------------------------------------------------------------//
302 
303 namespace {
304 template <typename Op>
305 struct LogApproximationBase : public OpRewritePattern<Op> {
306   using OpRewritePattern<Op>::OpRewritePattern;
307 
308   /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise.
309   LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter,
310                                    bool base2) const;
311 };
312 } // namespace
313 
// This approximation comes from Julien Pommier's SSE math library.
// Link: http://gruntthepeon.free.fr/ssemath
//
// Shared lowering for math.log (base e) and math.log2 (`base2` == true) on
// f32 scalars and vectors of f32; any other operand type is reported as a
// match failure. Outline:
//   1. clamp the input to the smallest normal f32 and split it with frexp
//      into a fraction in [0.5, 1) and an exponent `e`;
//   2. shift the fraction into [sqrt(1/2), sqrt(2)) minus 1 and evaluate
//      log(1 + x) as x - x^2/2 + x^3 * P(x), P being a degree-8 polynomial;
//   3. add the exponent contribution in the requested base;
//   4. patch x == 0, x < 0 / NaN and x == +inf with selects.
template <typename Op>
LogicalResult
LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
                                             bool base2) const {
  auto shape = vectorShape(op.operand().getType(), isF32);
  if (!shape.hasValue())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, *shape);
  };

  Value cstZero = bcast(f32Cst(builder, 0.0f));
  Value cstOne = bcast(f32Cst(builder, 1.0f));
  Value cstNegHalf = bcast(f32Cst(builder, -0.5f));

  // The smallest non denormalized float number.
  Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
  // Bit patterns of -inf, +inf and a quiet NaN.
  Value cstMinusInf = bcast(f32FromBits(builder, 0xff800000u));
  Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
  Value cstNan = bcast(f32FromBits(builder, 0x7fc00000));

  // Polynomial coefficients.
  // cstCephesSQRTHF is sqrt(1/2), the split point used in step 2.
  Value cstCephesSQRTHF = bcast(f32Cst(builder, 0.707106781186547524f));
  Value cstCephesLogP0 = bcast(f32Cst(builder, 7.0376836292E-2f));
  Value cstCephesLogP1 = bcast(f32Cst(builder, -1.1514610310E-1f));
  Value cstCephesLogP2 = bcast(f32Cst(builder, 1.1676998740E-1f));
  Value cstCephesLogP3 = bcast(f32Cst(builder, -1.2420140846E-1f));
  Value cstCephesLogP4 = bcast(f32Cst(builder, +1.4249322787E-1f));
  Value cstCephesLogP5 = bcast(f32Cst(builder, -1.6668057665E-1f));
  Value cstCephesLogP6 = bcast(f32Cst(builder, +2.0000714765E-1f));
  Value cstCephesLogP7 = bcast(f32Cst(builder, -2.4999993993E-1f));
  Value cstCephesLogP8 = bcast(f32Cst(builder, +3.3333331174E-1f));

  Value x = op.operand();

  // Truncate input values to the minimum positive normal.
  // Zero/negative inputs therefore produce a finite intermediate; the final
  // selects below replace them with -inf / NaN.
  x = max(builder, x, cstMinNormPos);

  // Extract significant in the range [0.5,1) and exponent.
  std::pair<Value, Value> pair = frexp(builder, x, /*is_positive=*/true);
  x = pair.first;
  Value e = pair.second;

  // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift
  // by -1.0. The values are then centered around 0, which improves the
  // stability of the polynomial evaluation:
  //
  //   if( x < SQRTHF ) {
  //     e -= 1;
  //     x = x + x - 1.0;
  //   } else { x = x - 1.0; }
  Value mask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x,
                                             cstCephesSQRTHF);
  // tmp is x where the fraction is below sqrt(1/2) and 0 elsewhere; adding it
  // below realizes the `x + x - 1.0` branch without control flow.
  Value tmp = builder.create<SelectOp>(mask, x, cstZero);

  x = builder.create<arith::SubFOp>(x, cstOne);
  e = builder.create<arith::SubFOp>(
      e, builder.create<SelectOp>(mask, cstOne, cstZero));
  x = builder.create<arith::AddFOp>(x, tmp);

  Value x2 = builder.create<arith::MulFOp>(x, x);
  Value x3 = builder.create<arith::MulFOp>(x2, x);

  // Evaluate the polynomial approximant of degree 8 in three parts.
  // y0/y1/y2 are three independent degree-2 pieces (P0..P2, P3..P5, P6..P8)
  // that are combined with powers of x^3, then multiplied by x^3.
  Value y0, y1, y2;
  y0 = builder.create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1);
  y1 = builder.create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4);
  y2 = builder.create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7);
  y0 = builder.create<math::FmaOp>(y0, x, cstCephesLogP2);
  y1 = builder.create<math::FmaOp>(y1, x, cstCephesLogP5);
  y2 = builder.create<math::FmaOp>(y2, x, cstCephesLogP8);
  y0 = builder.create<math::FmaOp>(y0, x3, y1);
  y0 = builder.create<math::FmaOp>(y0, x3, y2);
  y0 = builder.create<arith::MulFOp>(y0, x3);

  // log(1 + x) ~= x - 0.5 * x^2 + x^3 * P(x).
  y0 = builder.create<math::FmaOp>(cstNegHalf, x2, y0);
  x = builder.create<arith::AddFOp>(x, y0);

  if (base2) {
    // log2(v) = ln(fraction) * log2(e) + e.
    Value cstLog2e = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
    x = builder.create<math::FmaOp>(x, cstLog2e, e);
  } else {
    // ln(v) = e * ln(2) + ln(fraction).
    Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
    x = builder.create<math::FmaOp>(e, cstLn2, x);
  }

  // ULT is also true for unordered operands, so NaN inputs are caught by
  // invalidMask as well.
  Value invalidMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
                                                    op.operand(), cstZero);
  Value zeroMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
                                                 op.operand(), cstZero);
  Value posInfMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
                                                   op.operand(), cstPosInf);

  // Filter out invalid values:
  //  • x == 0     -> -INF
  //  • x < 0      ->  NAN
  //  • x == +INF  -> +INF
  Value aproximation = builder.create<SelectOp>(
      zeroMask, cstMinusInf,
      builder.create<SelectOp>(
          invalidMask, cstNan,
          builder.create<SelectOp>(posInfMask, cstPosInf, x)));

  rewriter.replaceOp(op, aproximation);

  return success();
}
425 
426 namespace {
427 struct LogApproximation : public LogApproximationBase<math::LogOp> {
428   using LogApproximationBase::LogApproximationBase;
429 
430   LogicalResult matchAndRewrite(math::LogOp op,
431                                 PatternRewriter &rewriter) const final {
432     return logMatchAndRewrite(op, rewriter, /*base2=*/false);
433   }
434 };
435 } // namespace
436 
437 namespace {
438 struct Log2Approximation : public LogApproximationBase<math::Log2Op> {
439   using LogApproximationBase::LogApproximationBase;
440 
441   LogicalResult matchAndRewrite(math::Log2Op op,
442                                 PatternRewriter &rewriter) const final {
443     return logMatchAndRewrite(op, rewriter, /*base2=*/true);
444   }
445 };
446 } // namespace
447 
448 //----------------------------------------------------------------------------//
449 // Log1p approximation.
450 //----------------------------------------------------------------------------//
451 
452 namespace {
453 struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> {
454 public:
455   using OpRewritePattern::OpRewritePattern;
456 
457   LogicalResult matchAndRewrite(math::Log1pOp op,
458                                 PatternRewriter &rewriter) const final;
459 };
460 } // namespace
461 
462 // Approximate log(1+x).
463 LogicalResult
464 Log1pApproximation::matchAndRewrite(math::Log1pOp op,
465                                     PatternRewriter &rewriter) const {
466   auto shape = vectorShape(op.operand().getType(), isF32);
467   if (!shape.hasValue())
468     return rewriter.notifyMatchFailure(op, "unsupported operand type");
469 
470   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
471   auto bcast = [&](Value value) -> Value {
472     return broadcast(builder, value, *shape);
473   };
474 
475   // Approximate log(1+x) using the following, due to W. Kahan:
476   //   u = x + 1.0;
477   //   if (u == 1.0 || u == inf) return x;
478   //   return x * log(u) / (u - 1.0);
479   //          ^^^^^^^^^^^^^^^^^^^^^^
480   //             "logLarge" below.
481   Value cstOne = bcast(f32Cst(builder, 1.0f));
482   Value x = op.operand();
483   Value u = builder.create<arith::AddFOp>(x, cstOne);
484   Value uSmall =
485       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
486   Value logU = builder.create<math::LogOp>(u);
487   Value uInf =
488       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU);
489   Value logLarge = builder.create<arith::MulFOp>(
490       x, builder.create<arith::DivFOp>(
491              logU, builder.create<arith::SubFOp>(u, cstOne)));
492   Value approximation = builder.create<SelectOp>(
493       builder.create<arith::OrIOp>(uSmall, uInf), x, logLarge);
494   rewriter.replaceOp(op, approximation);
495   return success();
496 }
497 
498 //----------------------------------------------------------------------------//
499 // Erf approximation.
500 //----------------------------------------------------------------------------//
501 
502 // Approximates erf(x) with
503 // a - P(x)/Q(x)
504 // where P and Q are polynomials of degree 4.
505 // Different coefficients are chosen based on the value of x.
506 // The approximation error is ~2.5e-07.
507 // Boost's minimax tool that utilizes the Remez method was used to find the
508 // coefficients.
509 LogicalResult
510 ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
511                                             PatternRewriter &rewriter) const {
512   auto shape = vectorShape(op.operand().getType(), isF32);
513   if (!shape.hasValue())
514     return rewriter.notifyMatchFailure(op, "unsupported operand type");
515 
516   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
517   auto bcast = [&](Value value) -> Value {
518     return broadcast(builder, value, *shape);
519   };
520 
521   const int intervalsCount = 3;
522   const int polyDegree = 4;
523 
524   Value zero = bcast(f32Cst(builder, 0));
525   Value one = bcast(f32Cst(builder, 1));
526   Value pp[intervalsCount][polyDegree + 1];
527   pp[0][0] = bcast(f32Cst(builder, +0.00000000000000000e+00f));
528   pp[0][1] = bcast(f32Cst(builder, +1.12837916222975858e+00f));
529   pp[0][2] = bcast(f32Cst(builder, -5.23018562988006470e-01f));
530   pp[0][3] = bcast(f32Cst(builder, +2.09741709609267072e-01f));
531   pp[0][4] = bcast(f32Cst(builder, +2.58146801602987875e-02f));
532   pp[1][0] = bcast(f32Cst(builder, +0.00000000000000000e+00f));
533   pp[1][1] = bcast(f32Cst(builder, +1.12750687816789140e+00f));
534   pp[1][2] = bcast(f32Cst(builder, -3.64721408487825775e-01f));
535   pp[1][3] = bcast(f32Cst(builder, +1.18407396425136952e-01f));
536   pp[1][4] = bcast(f32Cst(builder, +3.70645533056476558e-02f));
537   pp[2][0] = bcast(f32Cst(builder, -3.30093071049483172e-03f));
538   pp[2][1] = bcast(f32Cst(builder, +3.51961938357697011e-03f));
539   pp[2][2] = bcast(f32Cst(builder, -1.41373622814988039e-03f));
540   pp[2][3] = bcast(f32Cst(builder, +2.53447094961941348e-04f));
541   pp[2][4] = bcast(f32Cst(builder, -1.71048029455037401e-05f));
542 
543   Value qq[intervalsCount][polyDegree + 1];
544   qq[0][0] = bcast(f32Cst(builder, +1.000000000000000000e+00f));
545   qq[0][1] = bcast(f32Cst(builder, -4.635138185962547255e-01f));
546   qq[0][2] = bcast(f32Cst(builder, +5.192301327279782447e-01f));
547   qq[0][3] = bcast(f32Cst(builder, -1.318089722204810087e-01f));
548   qq[0][4] = bcast(f32Cst(builder, +7.397964654672315005e-02f));
549   qq[1][0] = bcast(f32Cst(builder, +1.00000000000000000e+00f));
550   qq[1][1] = bcast(f32Cst(builder, -3.27607011824493086e-01f));
551   qq[1][2] = bcast(f32Cst(builder, +4.48369090658821977e-01f));
552   qq[1][3] = bcast(f32Cst(builder, -8.83462621207857930e-02f));
553   qq[1][4] = bcast(f32Cst(builder, +5.72442770283176093e-02f));
554   qq[2][0] = bcast(f32Cst(builder, +1.00000000000000000e+00f));
555   qq[2][1] = bcast(f32Cst(builder, -2.06069165953913769e+00f));
556   qq[2][2] = bcast(f32Cst(builder, +1.62705939945477759e+00f));
557   qq[2][3] = bcast(f32Cst(builder, -5.83389859211130017e-01f));
558   qq[2][4] = bcast(f32Cst(builder, +8.21908939856640930e-02f));
559 
560   Value offsets[intervalsCount];
561   offsets[0] = bcast(f32Cst(builder, 0.0f));
562   offsets[1] = bcast(f32Cst(builder, 0.0f));
563   offsets[2] = bcast(f32Cst(builder, 1.0f));
564 
565   Value bounds[intervalsCount];
566   bounds[0] = bcast(f32Cst(builder, 0.8f));
567   bounds[1] = bcast(f32Cst(builder, 2.0f));
568   bounds[2] = bcast(f32Cst(builder, 3.75f));
569 
570   Value isNegativeArg = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT,
571                                                       op.operand(), zero);
572   Value negArg = builder.create<arith::NegFOp>(op.operand());
573   Value x = builder.create<SelectOp>(isNegativeArg, negArg, op.operand());
574 
575   Value offset = offsets[0];
576   Value p[polyDegree + 1];
577   Value q[polyDegree + 1];
578   for (int i = 0; i <= polyDegree; ++i) {
579     p[i] = pp[0][i];
580     q[i] = qq[0][i];
581   }
582 
583   // TODO: maybe use vector stacking to reduce the number of selects.
584   Value isLessThanBound[intervalsCount];
585   for (int j = 0; j < intervalsCount - 1; ++j) {
586     isLessThanBound[j] =
587         builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[j]);
588     for (int i = 0; i <= polyDegree; ++i) {
589       p[i] = builder.create<SelectOp>(isLessThanBound[j], p[i], pp[j + 1][i]);
590       q[i] = builder.create<SelectOp>(isLessThanBound[j], q[i], qq[j + 1][i]);
591     }
592     offset =
593         builder.create<SelectOp>(isLessThanBound[j], offset, offsets[j + 1]);
594   }
595   isLessThanBound[intervalsCount - 1] = builder.create<arith::CmpFOp>(
596       arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);
597 
598   Value pPoly = makePolynomialCalculation(builder, p, x);
599   Value qPoly = makePolynomialCalculation(builder, q, x);
600   Value rationalPoly = builder.create<arith::DivFOp>(pPoly, qPoly);
601   Value formula = builder.create<arith::AddFOp>(offset, rationalPoly);
602   formula = builder.create<SelectOp>(isLessThanBound[intervalsCount - 1],
603                                      formula, one);
604 
605   // erf is odd function: erf(x) = -erf(-x).
606   Value negFormula = builder.create<arith::NegFOp>(formula);
607   Value res = builder.create<SelectOp>(isNegativeArg, negFormula, formula);
608 
609   rewriter.replaceOp(op, res);
610 
611   return success();
612 }
613 
614 //----------------------------------------------------------------------------//
615 // Exp approximation.
616 //----------------------------------------------------------------------------//
617 
618 namespace {
619 
620 struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
621 public:
622   using OpRewritePattern::OpRewritePattern;
623 
624   LogicalResult matchAndRewrite(math::ExpOp op,
625                                 PatternRewriter &rewriter) const final;
626 };
627 } // namespace
628 
629 // Approximate exp(x) using its reduced range exp(y) where y is in the range
630 // [0, ln(2)], let y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2), exp(x)
631 // = exp(y) * 2^k. exp(y).
632 LogicalResult
633 ExpApproximation::matchAndRewrite(math::ExpOp op,
634                                   PatternRewriter &rewriter) const {
635   auto shape = vectorShape(op.operand().getType(), isF32);
636   if (!shape.hasValue())
637     return rewriter.notifyMatchFailure(op, "unsupported operand type");
638   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
639 
640   // TODO: Consider a common pattern rewriter with all methods below to
641   // write the approximations.
642   auto bcast = [&](Value value) -> Value {
643     return broadcast(builder, value, *shape);
644   };
645   auto fmla = [&](Value a, Value b, Value c) {
646     return builder.create<math::FmaOp>(a, b, c);
647   };
648   auto mul = [&](Value a, Value b) -> Value {
649     return builder.create<arith::MulFOp>(a, b);
650   };
651   auto sub = [&](Value a, Value b) -> Value {
652     return builder.create<arith::SubFOp>(a, b);
653   };
654   auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
655 
656   Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
657   Value cstLog2E = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
658 
659   // Polynomial coefficients.
660   Value cstCephesExpP0 = bcast(f32Cst(builder, 1.0));
661   Value cstCephesExpP1 = bcast(f32Cst(builder, 1.0));
662   Value cstCephesExpP2 = bcast(f32Cst(builder, 0.49970514590562437052f));
663   Value cstCephesExpP3 = bcast(f32Cst(builder, 0.16873890085469545053f));
664   Value cstCephesExpP4 = bcast(f32Cst(builder, 0.03668965196652099192f));
665   Value cstCephesExpP5 = bcast(f32Cst(builder, 0.01314350012789660196f));
666 
667   Value x = op.operand();
668 
669   // Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2)
670   Value xL2Inv = mul(x, cstLog2E);
671   Value kF32 = floor(xL2Inv);
672   Value kLn2 = mul(kF32, cstLn2);
673   Value y = sub(x, kLn2);
674 
675   // Use Estrin's evaluation scheme with 3 independent parts:
676   // P(y)^y : (c0 + c1 y) + (c2 + c3 y) y^2 + (c4 + c5 y) y^4
677   Value y2 = mul(y, y);
678   Value y4 = mul(y2, y2);
679 
680   Value q0 = fmla(cstCephesExpP1, y, cstCephesExpP0);
681   Value q1 = fmla(cstCephesExpP3, y, cstCephesExpP2);
682   Value q2 = fmla(cstCephesExpP5, y, cstCephesExpP4);
683   Value expY = fmla(q1, y2, q0);
684   expY = fmla(q2, y4, expY);
685 
686   auto i32Vec = broadcast(builder.getI32Type(), *shape);
687 
688   // exp2(k)
689   Value k = builder.create<arith::FPToSIOp>(kF32, i32Vec);
690   Value exp2KValue = exp2I32(builder, k);
691 
692   // exp(x) = exp(y) * exp2(k)
693   expY = mul(expY, exp2KValue);
694 
695   // Handle overflow, inf and underflow of exp(x). exp(x) range is [0, inf], its
696   // partitioned as the following:
697   // exp(x) = 0, x <= -inf
698   // exp(x) = underflow (min_float), x <= -88
699   // exp(x) = inf (min_float), x >= 88
700   // Note: |k| = 127 is the value where the 8-bits exponent saturates.
701   Value zerof32Const = bcast(f32Cst(builder, 0));
702   auto constPosInfinity =
703       bcast(f32Cst(builder, std::numeric_limits<float>::infinity()));
704   auto constNegIfinity =
705       bcast(f32Cst(builder, -std::numeric_limits<float>::infinity()));
706   auto underflow = bcast(f32Cst(builder, std::numeric_limits<float>::min()));
707 
708   Value kMaxConst = bcast(i32Cst(builder, 127));
709   Value kMaxNegConst = bcast(i32Cst(builder, -127));
710   Value rightBound =
711       builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, k, kMaxConst);
712   Value leftBound =
713       builder.create<arith::CmpIOp>(arith::CmpIPredicate::sge, k, kMaxNegConst);
714 
715   Value isNegInfinityX = builder.create<arith::CmpFOp>(
716       arith::CmpFPredicate::OEQ, x, constNegIfinity);
717   Value isPosInfinityX = builder.create<arith::CmpFOp>(
718       arith::CmpFPredicate::OEQ, x, constPosInfinity);
719   Value isPostiveX =
720       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zerof32Const);
721   Value isComputable = builder.create<arith::AndIOp>(rightBound, leftBound);
722 
723   expY = builder.create<SelectOp>(
724       isNegInfinityX, zerof32Const,
725       builder.create<SelectOp>(
726           isPosInfinityX, constPosInfinity,
727           builder.create<SelectOp>(isComputable, expY,
728                                    builder.create<SelectOp>(isPostiveX,
729                                                             constPosInfinity,
730                                                             underflow))));
731 
732   rewriter.replaceOp(op, expY);
733 
734   return success();
735 }
736 
737 //----------------------------------------------------------------------------//
738 // ExpM1 approximation.
739 //----------------------------------------------------------------------------//
740 
741 namespace {
742 
743 struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
744 public:
745   using OpRewritePattern::OpRewritePattern;
746 
747   LogicalResult matchAndRewrite(math::ExpM1Op op,
748                                 PatternRewriter &rewriter) const final;
749 };
750 } // namespace
751 
752 LogicalResult
753 ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
754                                     PatternRewriter &rewriter) const {
755   auto shape = vectorShape(op.operand().getType(), isF32);
756   if (!shape.hasValue())
757     return rewriter.notifyMatchFailure(op, "unsupported operand type");
758 
759   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
760   auto bcast = [&](Value value) -> Value {
761     return broadcast(builder, value, *shape);
762   };
763 
764   // expm1(x) = exp(x) - 1 = u - 1.
765   // We have to handle it carefully when x is near 0, i.e. u ~= 1,
766   // and when the input is ~= -inf, i.e. u - 1 ~= -1.
767   Value cstOne = bcast(f32Cst(builder, 1.0f));
768   Value cstNegOne = bcast(f32Cst(builder, -1.0f));
769   Value x = op.operand();
770   Value u = builder.create<math::ExpOp>(x);
771   Value uEqOne =
772       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
773   Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne);
774   Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>(
775       arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
776   // logU = log(u) ~= x
777   Value logU = builder.create<math::LogOp>(u);
778 
779   // Detect exp(x) = +inf; written this way to avoid having to form +inf.
780   Value isInf =
781       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u);
782 
783   // (u - 1) * (x / ~x)
784   Value expm1 = builder.create<arith::MulFOp>(
785       uMinusOne, builder.create<arith::DivFOp>(x, logU));
786   expm1 = builder.create<SelectOp>(isInf, u, expm1);
787   Value approximation = builder.create<SelectOp>(
788       uEqOne, x, builder.create<SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
789   rewriter.replaceOp(op, approximation);
790   return success();
791 }
792 
793 //----------------------------------------------------------------------------//
794 // Sin and Cos approximation.
795 //----------------------------------------------------------------------------//
796 
797 namespace {
798 
799 template <bool isSine, typename OpTy>
800 struct SinAndCosApproximation : public OpRewritePattern<OpTy> {
801 public:
802   using OpRewritePattern<OpTy>::OpRewritePattern;
803 
804   LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
805 };
806 } // namespace
807 
// 2/pi and pi/2 in long double precision. They are converted to f32 when
// passed to f32Cst. These are typed, file-local constants rather than
// object-like macros, so they obey C++ scoping and do not leak into code that
// follows. The names are kept so existing uses continue to compile.
static constexpr long double TWO_OVER_PI =
    0.6366197723675813430755350534900574481378385829618257949906693762L;
static constexpr long double PI_OVER_2 =
    1.5707963267948966192313216916397514420985846996875529104874722961L;
812 
813 // Approximates sin(x) or cos(x) by finding the best approximation polynomial in
814 // the reduced range [0, pi/2] for both sin(x) and cos(x). Then given y in the
815 // reduced range sin(x) will be computed as sin(y), -sin(y), cos(y) or -cos(y).
816 template <bool isSine, typename OpTy>
817 LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
818     OpTy op, PatternRewriter &rewriter) const {
819   static_assert(
820       llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
821       "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
822   auto shape = vectorShape(op.operand().getType(), isF32);
823   if (!shape.hasValue())
824     return rewriter.notifyMatchFailure(op, "unsupported operand type");
825 
826   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
827   auto bcast = [&](Value value) -> Value {
828     return broadcast(builder, value, *shape);
829   };
830   auto mul = [&](Value a, Value b) -> Value {
831     return builder.create<arith::MulFOp>(a, b);
832   };
833   auto sub = [&](Value a, Value b) -> Value {
834     return builder.create<arith::SubFOp>(a, b);
835   };
836   auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
837 
838   auto i32Vec = broadcast(builder.getI32Type(), *shape);
839   auto fPToSingedInteger = [&](Value a) -> Value {
840     return builder.create<arith::FPToSIOp>(a, i32Vec);
841   };
842 
843   auto modulo4 = [&](Value a) -> Value {
844     return builder.create<arith::AndIOp>(a, bcast(i32Cst(builder, 3)));
845   };
846 
847   auto isEqualTo = [&](Value a, Value b) -> Value {
848     return builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b);
849   };
850 
851   auto isGreaterThan = [&](Value a, Value b) -> Value {
852     return builder.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b);
853   };
854 
855   auto select = [&](Value cond, Value t, Value f) -> Value {
856     return builder.create<SelectOp>(cond, t, f);
857   };
858 
859   auto fmla = [&](Value a, Value b, Value c) {
860     return builder.create<math::FmaOp>(a, b, c);
861   };
862 
863   auto bitwiseOr = [&](Value a, Value b) {
864     return builder.create<arith::OrIOp>(a, b);
865   };
866 
867   Value twoOverPi = bcast(f32Cst(builder, TWO_OVER_PI));
868   Value piOverTwo = bcast(f32Cst(builder, PI_OVER_2));
869 
870   Value x = op.operand();
871 
872   Value k = floor(mul(x, twoOverPi));
873 
874   Value y = sub(x, mul(k, piOverTwo));
875 
876   Value cstOne = bcast(f32Cst(builder, 1.0));
877   Value cstNegativeOne = bcast(f32Cst(builder, -1.0));
878 
879   Value cstSC2 = bcast(f32Cst(builder, -0.16666667163372039794921875f));
880   Value cstSC4 = bcast(f32Cst(builder, 8.333347737789154052734375e-3f));
881   Value cstSC6 = bcast(f32Cst(builder, -1.9842604524455964565277099609375e-4f));
882   Value cstSC8 =
883       bcast(f32Cst(builder, 2.760012648650445044040679931640625e-6f));
884   Value cstSC10 =
885       bcast(f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f));
886 
887   Value cstCC2 = bcast(f32Cst(builder, -0.5f));
888   Value cstCC4 = bcast(f32Cst(builder, 4.166664183139801025390625e-2f));
889   Value cstCC6 = bcast(f32Cst(builder, -1.388833043165504932403564453125e-3f));
890   Value cstCC8 = bcast(f32Cst(builder, 2.47562347794882953166961669921875e-5f));
891   Value cstCC10 =
892       bcast(f32Cst(builder, -2.59630184018533327616751194000244140625e-7f));
893 
894   Value kMod4 = modulo4(fPToSingedInteger(k));
895 
896   Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, 0)));
897   Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, 1)));
898   Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, 2)));
899   Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, 3)));
900 
901   Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
902   Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(i32Cst(builder, 1)))
903                                : bitwiseOr(kR1, kR2);
904 
905   Value y2 = mul(y, y);
906 
907   Value base = select(sinuseCos, cstOne, y);
908   Value cstC2 = select(sinuseCos, cstCC2, cstSC2);
909   Value cstC4 = select(sinuseCos, cstCC4, cstSC4);
910   Value cstC6 = select(sinuseCos, cstCC6, cstSC6);
911   Value cstC8 = select(sinuseCos, cstCC8, cstSC8);
912   Value cstC10 = select(sinuseCos, cstCC10, cstSC10);
913 
914   Value v1 = fmla(y2, cstC10, cstC8);
915   Value v2 = fmla(y2, v1, cstC6);
916   Value v3 = fmla(y2, v2, cstC4);
917   Value v4 = fmla(y2, v3, cstC2);
918   Value v5 = fmla(y2, v4, cstOne);
919   Value v6 = mul(base, v5);
920 
921   Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);
922 
923   rewriter.replaceOp(op, approximation);
924 
925   return success();
926 }
927 
928 //----------------------------------------------------------------------------//
929 // Rsqrt approximation.
930 //----------------------------------------------------------------------------//
931 
932 namespace {
933 struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> {
934   using OpRewritePattern::OpRewritePattern;
935 
936   LogicalResult matchAndRewrite(math::RsqrtOp op,
937                                 PatternRewriter &rewriter) const final;
938 };
939 } // namespace
940 
941 LogicalResult
942 RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
943                                     PatternRewriter &rewriter) const {
944   auto shape = vectorShape(op.operand().getType(), isF32);
945   // Only support already-vectorized rsqrt's.
946   if (!shape.hasValue() || (*shape)[0] != 8)
947     return rewriter.notifyMatchFailure(op, "unsupported operand type");
948 
949   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
950   auto bcast = [&](Value value) -> Value {
951     return broadcast(builder, value, *shape);
952   };
953 
954   Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
955   Value cstOnePointFive = bcast(f32Cst(builder, 1.5f));
956   Value cstNegHalf = bcast(f32Cst(builder, -0.5f));
957   Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
958 
959   Value negHalf = builder.create<arith::MulFOp>(op.operand(), cstNegHalf);
960 
961   // Select only the inverse sqrt of positive normals (denormals are
962   // flushed to zero).
963   Value ltMinMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT,
964                                                   op.operand(), cstMinNormPos);
965   Value infMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
966                                                 op.operand(), cstPosInf);
967   Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask);
968 
969   // Compute an approximate result.
970   Value yApprox = builder.create<x86vector::RsqrtOp>(op.operand());
971 
972   // Do a single step of Newton-Raphson iteration to improve the approximation.
973   // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
974   // It is essential to evaluate the inner term like this because forming
975   // y_n^2 may over- or underflow.
976   Value inner = builder.create<arith::MulFOp>(negHalf, yApprox);
977   Value fma = builder.create<math::FmaOp>(yApprox, inner, cstOnePointFive);
978   Value yNewton = builder.create<arith::MulFOp>(yApprox, fma);
979 
980   // Select the result of the Newton-Raphson step for positive normal arguments.
981   // For other arguments, choose the output of the intrinsic. This will
982   // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if
983   // x is zero or a positive denormalized float (equivalent to flushing positive
984   // denormalized inputs to zero).
985   Value res = builder.create<SelectOp>(notNormalFiniteMask, yApprox, yNewton);
986   rewriter.replaceOp(op, res);
987 
988   return success();
989 }
990 
991 //----------------------------------------------------------------------------//
992 
993 void mlir::populateMathPolynomialApproximationPatterns(
994     RewritePatternSet &patterns,
995     const MathPolynomialApproximationOptions &options) {
996   patterns.add<TanhApproximation, LogApproximation, Log2Approximation,
997                Log1pApproximation, ErfPolynomialApproximation, ExpApproximation,
998                ExpM1Approximation, SinAndCosApproximation<true, math::SinOp>,
999                SinAndCosApproximation<false, math::CosOp>>(
1000       patterns.getContext());
1001   if (options.enableAvx2)
1002     patterns.add<RsqrtApproximation>(patterns.getContext());
1003 }
1004