1 //===- PolynomialApproximation.cpp - Approximate math operations ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements expansion of math operations to fast approximations
10 // that do not rely on any of the library functions.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/Dialect/Math/IR/Math.h"
16 #include "mlir/Dialect/Math/Transforms/Passes.h"
17 #include "mlir/Dialect/Vector/VectorOps.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/ImplicitLocOpBuilder.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 #include <limits.h>
23 
24 using namespace mlir;
25 using namespace mlir::vector;
26 
27 using TypePredicate = llvm::function_ref<bool(Type)>;
28 
29 // Returns vector width if the element type is matching the predicate (scalars
30 // that do match the predicate have width equal to `1`).
31 static Optional<int> vectorWidth(Type type, TypePredicate pred) {
32   // If the type matches the predicate then its width is `1`.
33   if (pred(type))
34     return 1;
35 
36   // Otherwise check if the type is a vector type.
37   auto vectorType = type.dyn_cast<VectorType>();
38   if (vectorType && pred(vectorType.getElementType())) {
39     assert(vectorType.getRank() == 1 && "only 1d vectors are supported");
40     return vectorType.getDimSize(0);
41   }
42 
43   return llvm::None;
44 }
45 
46 // Returns vector width of the type. If the type is a scalar returns `1`.
47 static int vectorWidth(Type type) {
48   auto vectorType = type.dyn_cast<VectorType>();
49   return vectorType ? vectorType.getDimSize(0) : 1;
50 }
51 
52 // Returns vector element type. If the type is a scalar returns the argument.
53 LLVM_ATTRIBUTE_UNUSED static Type elementType(Type type) {
54   auto vectorType = type.dyn_cast<VectorType>();
55   return vectorType ? vectorType.getElementType() : type;
56 }
57 
58 LLVM_ATTRIBUTE_UNUSED static bool isF32(Type type) { return type.isF32(); }
59 
60 LLVM_ATTRIBUTE_UNUSED static bool isI32(Type type) {
61   return type.isInteger(32);
62 }
63 
64 //----------------------------------------------------------------------------//
65 // Broadcast scalar types and values into vector types and values.
66 //----------------------------------------------------------------------------//
67 
68 // Broadcasts scalar type into vector type (iff width is greater then 1).
69 static Type broadcast(Type type, int width) {
70   assert(!type.isa<VectorType>() && "must be scalar type");
71   return width > 1 ? VectorType::get({width}, type) : type;
72 }
73 
74 // Broadcasts scalar value into vector (iff width is greater then 1).
75 static Value broadcast(ImplicitLocOpBuilder &builder, Value value, int width) {
76   assert(!value.getType().isa<VectorType>() && "must be scalar value");
77   auto type = broadcast(value.getType(), width);
78   return width > 1 ? builder.create<BroadcastOp>(type, value) : value;
79 }
80 
81 //----------------------------------------------------------------------------//
82 // Helper functions to create constants.
83 //----------------------------------------------------------------------------//
84 
85 static Value f32Cst(ImplicitLocOpBuilder &builder, float value) {
86   return builder.create<ConstantOp>(builder.getF32Type(),
87                                     builder.getF32FloatAttr(value));
88 }
89 
90 static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) {
91   return builder.create<ConstantOp>(builder.getI32Type(),
92                                     builder.getI32IntegerAttr(value));
93 }
94 
95 static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) {
96   Value i32Value = i32Cst(builder, static_cast<int32_t>(bits));
97   return builder.create<LLVM::BitcastOp>(builder.getF32Type(), i32Value);
98 }
99 
100 //----------------------------------------------------------------------------//
101 // Helper functions to build math functions approximations.
102 //----------------------------------------------------------------------------//
103 
104 static Value min(ImplicitLocOpBuilder &builder, Value a, Value b) {
105   return builder.create<SelectOp>(
106       builder.create<CmpFOp>(CmpFPredicate::OLT, a, b), a, b);
107 }
108 
109 static Value max(ImplicitLocOpBuilder &builder, Value a, Value b) {
110   return builder.create<SelectOp>(
111       builder.create<CmpFOp>(CmpFPredicate::OGT, a, b), a, b);
112 }
113 
114 static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
115                    Value upperBound) {
116   return max(builder, min(builder, value, upperBound), lowerBound);
117 }
118 
119 // Decomposes given floating point value `arg` into a normalized fraction and
120 // an integral power of two (see std::frexp). Returned values have float type.
121 static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
122                                      bool is_positive = false) {
123   assert(isF32(elementType(arg.getType())) && "argument must be f32 type");
124 
125   int width = vectorWidth(arg.getType());
126 
127   auto bcast = [&](Value value) -> Value {
128     return broadcast(builder, value, width);
129   };
130 
131   auto i32 = builder.getIntegerType(32);
132   auto i32Vec = broadcast(i32, width);
133   auto f32Vec = broadcast(builder.getF32Type(), width);
134 
135   Value cst126f = f32Cst(builder, 126.0f);
136   Value cstHalf = f32Cst(builder, 0.5f);
137   Value cstInvMantMask = f32FromBits(builder, ~0x7f800000u);
138 
139   // Bitcast to i32 for bitwise operations.
140   Value i32Half = builder.create<LLVM::BitcastOp>(i32, cstHalf);
141   Value i32InvMantMask = builder.create<LLVM::BitcastOp>(i32, cstInvMantMask);
142   Value i32Arg = builder.create<LLVM::BitcastOp>(i32Vec, arg);
143 
144   // Compute normalized fraction.
145   Value tmp0 = builder.create<LLVM::AndOp>(i32Arg, bcast(i32InvMantMask));
146   Value tmp1 = builder.create<LLVM::OrOp>(tmp0, bcast(i32Half));
147   Value normalizedFraction = builder.create<LLVM::BitcastOp>(f32Vec, tmp1);
148 
149   // Compute exponent.
150   Value arg0 = is_positive ? arg : builder.create<AbsFOp>(arg);
151   Value biasedExponentBits = builder.create<UnsignedShiftRightOp>(
152       builder.create<LLVM::BitcastOp>(i32Vec, arg0),
153       bcast(i32Cst(builder, 23)));
154   Value biasedExponent = builder.create<SIToFPOp>(f32Vec, biasedExponentBits);
155   Value exponent = builder.create<SubFOp>(biasedExponent, bcast(cst126f));
156 
157   return {normalizedFraction, exponent};
158 }
159 
160 // Computes exp2 for an i32 argument.
161 static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
162   assert(isI32(elementType(arg.getType())) && "argument must be i32 type");
163 
164   int width = vectorWidth(arg.getType());
165 
166   auto bcast = [&](Value value) -> Value {
167     return broadcast(builder, value, width);
168   };
169 
170   auto f32Vec = broadcast(builder.getF32Type(), width);
171   // The exponent of f32 located at 23-bit.
172   auto exponetBitLocation = bcast(i32Cst(builder, 23));
173   // Set the exponent bias to zero.
174   auto bias = bcast(i32Cst(builder, 127));
175 
176   Value biasedArg = builder.create<AddIOp>(arg, bias);
177   Value exp2ValueInt =
178       builder.create<ShiftLeftOp>(biasedArg, exponetBitLocation);
179   Value exp2ValueF32 = builder.create<LLVM::BitcastOp>(f32Vec, exp2ValueInt);
180 
181   return exp2ValueF32;
182 }
183 
184 //----------------------------------------------------------------------------//
185 // TanhOp approximation.
186 //----------------------------------------------------------------------------//
187 
188 namespace {
189 struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
190 public:
191   using OpRewritePattern::OpRewritePattern;
192 
193   LogicalResult matchAndRewrite(math::TanhOp op,
194                                 PatternRewriter &rewriter) const final;
195 };
196 } // namespace
197 
198 LogicalResult
199 TanhApproximation::matchAndRewrite(math::TanhOp op,
200                                    PatternRewriter &rewriter) const {
201   auto width = vectorWidth(op.operand().getType(), isF32);
202   if (!width.hasValue())
203     return rewriter.notifyMatchFailure(op, "unsupported operand type");
204 
205   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
206   auto bcast = [&](Value value) -> Value {
207     return broadcast(builder, value, *width);
208   };
209 
210   // Clamp operand into [plusClamp, minusClamp] range.
211   Value minusClamp = bcast(f32Cst(builder, -7.9053111076354980f));
212   Value plusClamp = bcast(f32Cst(builder, 7.90531110763549805f));
213   Value x = clamp(builder, op.operand(), minusClamp, plusClamp);
214 
215   // Mask for tiny values that are approximated with `operand`.
216   Value tiny = bcast(f32Cst(builder, 0.0004f));
217   Value tinyMask = builder.create<CmpFOp>(
218       CmpFPredicate::OLT, builder.create<AbsFOp>(op.operand()), tiny);
219 
220   // The monomial coefficients of the numerator polynomial (odd).
221   Value alpha1 = bcast(f32Cst(builder, 4.89352455891786e-03f));
222   Value alpha3 = bcast(f32Cst(builder, 6.37261928875436e-04f));
223   Value alpha5 = bcast(f32Cst(builder, 1.48572235717979e-05f));
224   Value alpha7 = bcast(f32Cst(builder, 5.12229709037114e-08f));
225   Value alpha9 = bcast(f32Cst(builder, -8.60467152213735e-11f));
226   Value alpha11 = bcast(f32Cst(builder, 2.00018790482477e-13f));
227   Value alpha13 = bcast(f32Cst(builder, -2.76076847742355e-16f));
228 
229   // The monomial coefficients of the denominator polynomial (even).
230   Value beta0 = bcast(f32Cst(builder, 4.89352518554385e-03f));
231   Value beta2 = bcast(f32Cst(builder, 2.26843463243900e-03f));
232   Value beta4 = bcast(f32Cst(builder, 1.18534705686654e-04f));
233   Value beta6 = bcast(f32Cst(builder, 1.19825839466702e-06f));
234 
235   // Since the polynomials are odd/even, we need x^2.
236   Value x2 = builder.create<MulFOp>(x, x);
237 
238   // Evaluate the numerator polynomial p.
239   Value p = builder.create<FmaFOp>(x2, alpha13, alpha11);
240   p = builder.create<FmaFOp>(x2, p, alpha9);
241   p = builder.create<FmaFOp>(x2, p, alpha7);
242   p = builder.create<FmaFOp>(x2, p, alpha5);
243   p = builder.create<FmaFOp>(x2, p, alpha3);
244   p = builder.create<FmaFOp>(x2, p, alpha1);
245   p = builder.create<MulFOp>(x, p);
246 
247   // Evaluate the denominator polynomial q.
248   Value q = builder.create<FmaFOp>(x2, beta6, beta4);
249   q = builder.create<FmaFOp>(x2, q, beta2);
250   q = builder.create<FmaFOp>(x2, q, beta0);
251 
252   // Divide the numerator by the denominator.
253   Value res =
254       builder.create<SelectOp>(tinyMask, x, builder.create<DivFOp>(p, q));
255 
256   rewriter.replaceOp(op, res);
257 
258   return success();
259 }
260 
// Natural logarithm of 2, ln(2), spelled as a `long double` literal; users
// narrow it with `static_cast<float>`.
#define LN2_VALUE                                                              \
  0.693147180559945309417232121458176568075500134360255254120680009493393621L
// Base-2 logarithm of e, log2(e) = 1 / ln(2), spelled as a `long double`
// literal; users narrow it with `static_cast<float>`.
#define LOG2E_VALUE                                                            \
  1.442695040888963407359924681001892137426645954152985934135449406931109219L
265 
266 //----------------------------------------------------------------------------//
267 // LogOp and Log2Op approximation.
268 //----------------------------------------------------------------------------//
269 
270 namespace {
271 template <typename Op>
272 struct LogApproximationBase : public OpRewritePattern<Op> {
273   using OpRewritePattern<Op>::OpRewritePattern;
274 
275   /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise.
276   LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter,
277                                    bool base2) const;
278 };
279 } // namespace
280 
281 // This approximation comes from Julien Pommier's SSE math library.
282 // Link: http://gruntthepeon.free.fr/ssemath
283 template <typename Op>
284 LogicalResult
285 LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
286                                              bool base2) const {
287   auto width = vectorWidth(op.operand().getType(), isF32);
288   if (!width.hasValue())
289     return rewriter.notifyMatchFailure(op, "unsupported operand type");
290 
291   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
292   auto bcast = [&](Value value) -> Value {
293     return broadcast(builder, value, *width);
294   };
295 
296   Value cstZero = bcast(f32Cst(builder, 0.0f));
297   Value cstOne = bcast(f32Cst(builder, 1.0f));
298   Value cstNegHalf = bcast(f32Cst(builder, -0.5f));
299 
300   // The smallest non denormalized float number.
301   Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
302   Value cstMinusInf = bcast(f32FromBits(builder, 0xff800000u));
303   Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
304   Value cstNan = bcast(f32FromBits(builder, 0x7fc00000));
305 
306   // Polynomial coefficients.
307   Value cstCephesSQRTHF = bcast(f32Cst(builder, 0.707106781186547524f));
308   Value cstCephesLogP0 = bcast(f32Cst(builder, 7.0376836292E-2f));
309   Value cstCephesLogP1 = bcast(f32Cst(builder, -1.1514610310E-1f));
310   Value cstCephesLogP2 = bcast(f32Cst(builder, 1.1676998740E-1f));
311   Value cstCephesLogP3 = bcast(f32Cst(builder, -1.2420140846E-1f));
312   Value cstCephesLogP4 = bcast(f32Cst(builder, +1.4249322787E-1f));
313   Value cstCephesLogP5 = bcast(f32Cst(builder, -1.6668057665E-1f));
314   Value cstCephesLogP6 = bcast(f32Cst(builder, +2.0000714765E-1f));
315   Value cstCephesLogP7 = bcast(f32Cst(builder, -2.4999993993E-1f));
316   Value cstCephesLogP8 = bcast(f32Cst(builder, +3.3333331174E-1f));
317 
318   Value x = op.operand();
319 
320   // Truncate input values to the minimum positive normal.
321   x = max(builder, x, cstMinNormPos);
322 
323   // Extract significant in the range [0.5,1) and exponent.
324   std::pair<Value, Value> pair = frexp(builder, x, /*is_positive=*/true);
325   x = pair.first;
326   Value e = pair.second;
327 
328   // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift
329   // by -1.0. The values are then centered around 0, which improves the
330   // stability of the polynomial evaluation:
331   //
332   //   if( x < SQRTHF ) {
333   //     e -= 1;
334   //     x = x + x - 1.0;
335   //   } else { x = x - 1.0; }
336   Value mask = builder.create<CmpFOp>(CmpFPredicate::OLT, x, cstCephesSQRTHF);
337   Value tmp = builder.create<SelectOp>(mask, x, cstZero);
338 
339   x = builder.create<SubFOp>(x, cstOne);
340   e = builder.create<SubFOp>(e,
341                              builder.create<SelectOp>(mask, cstOne, cstZero));
342   x = builder.create<AddFOp>(x, tmp);
343 
344   Value x2 = builder.create<MulFOp>(x, x);
345   Value x3 = builder.create<MulFOp>(x2, x);
346 
347   // Evaluate the polynomial approximant of degree 8 in three parts.
348   Value y0, y1, y2;
349   y0 = builder.create<FmaFOp>(cstCephesLogP0, x, cstCephesLogP1);
350   y1 = builder.create<FmaFOp>(cstCephesLogP3, x, cstCephesLogP4);
351   y2 = builder.create<FmaFOp>(cstCephesLogP6, x, cstCephesLogP7);
352   y0 = builder.create<FmaFOp>(y0, x, cstCephesLogP2);
353   y1 = builder.create<FmaFOp>(y1, x, cstCephesLogP5);
354   y2 = builder.create<FmaFOp>(y2, x, cstCephesLogP8);
355   y0 = builder.create<FmaFOp>(y0, x3, y1);
356   y0 = builder.create<FmaFOp>(y0, x3, y2);
357   y0 = builder.create<MulFOp>(y0, x3);
358 
359   y0 = builder.create<FmaFOp>(cstNegHalf, x2, y0);
360   x = builder.create<AddFOp>(x, y0);
361 
362   if (base2) {
363     Value cstLog2e = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
364     x = builder.create<FmaFOp>(x, cstLog2e, e);
365   } else {
366     Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
367     x = builder.create<FmaFOp>(e, cstLn2, x);
368   }
369 
370   Value invalidMask =
371       builder.create<CmpFOp>(CmpFPredicate::ULT, op.operand(), cstZero);
372   Value zeroMask =
373       builder.create<CmpFOp>(CmpFPredicate::OEQ, op.operand(), cstZero);
374   Value posInfMask =
375       builder.create<CmpFOp>(CmpFPredicate::OEQ, op.operand(), cstPosInf);
376 
377   // Filter out invalid values:
378   //  • x == 0     -> -INF
379   //  • x < 0      ->  NAN
380   //  • x == +INF  -> +INF
381   Value aproximation = builder.create<SelectOp>(
382       zeroMask, cstMinusInf,
383       builder.create<SelectOp>(
384           invalidMask, cstNan,
385           builder.create<SelectOp>(posInfMask, cstPosInf, x)));
386 
387   rewriter.replaceOp(op, aproximation);
388 
389   return success();
390 }
391 
392 namespace {
393 struct LogApproximation : public LogApproximationBase<math::LogOp> {
394   using LogApproximationBase::LogApproximationBase;
395 
396   LogicalResult matchAndRewrite(math::LogOp op,
397                                 PatternRewriter &rewriter) const final {
398     return logMatchAndRewrite(op, rewriter, /*base2=*/false);
399   }
400 };
401 } // namespace
402 
403 namespace {
404 struct Log2Approximation : public LogApproximationBase<math::Log2Op> {
405   using LogApproximationBase::LogApproximationBase;
406 
407   LogicalResult matchAndRewrite(math::Log2Op op,
408                                 PatternRewriter &rewriter) const final {
409     return logMatchAndRewrite(op, rewriter, /*base2=*/true);
410   }
411 };
412 } // namespace
413 
414 //----------------------------------------------------------------------------//
415 // Exp approximation.
416 //----------------------------------------------------------------------------//
417 
418 namespace {
419 
420 struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
421 public:
422   using OpRewritePattern::OpRewritePattern;
423 
424   LogicalResult matchAndRewrite(math::ExpOp op,
425                                 PatternRewriter &rewriter) const final;
426 };
427 } // namespace
428 
429 // Approximate exp(x) using its reduced range exp(y) where y is in the range
430 // [0, ln(2)], let y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2), exp(x)
431 // = exp(y) * 2^k. exp(y).
432 LogicalResult
433 ExpApproximation::matchAndRewrite(math::ExpOp op,
434                                   PatternRewriter &rewriter) const {
435   auto width = vectorWidth(op.operand().getType(), isF32);
436   if (!width.hasValue())
437     return rewriter.notifyMatchFailure(op, "unsupported operand type");
438   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
439 
440   // TODO: Consider a common pattern rewriter with all methods below to
441   // write the approximations.
442   auto bcast = [&](Value value) -> Value {
443     return broadcast(builder, value, *width);
444   };
445   auto fmla = [&](Value a, Value b, Value c) {
446     return builder.create<FmaFOp>(a, b, c);
447   };
448   auto mul = [&](Value a, Value b) -> Value {
449     return builder.create<MulFOp>(a, b);
450   };
451   auto sub = [&](Value a, Value b) -> Value {
452     return builder.create<SubFOp>(a, b);
453   };
454   auto floor = [&](Value a) { return builder.create<FloorFOp>(a); };
455 
456   Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
457   Value cstLog2E = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
458 
459   // Polynomial coefficients.
460   Value cstCephesExpP0 = bcast(f32Cst(builder, 1.0));
461   Value cstCephesExpP1 = bcast(f32Cst(builder, 1.0));
462   Value cstCephesExpP2 = bcast(f32Cst(builder, 0.49970514590562437052f));
463   Value cstCephesExpP3 = bcast(f32Cst(builder, 0.16873890085469545053f));
464   Value cstCephesExpP4 = bcast(f32Cst(builder, 0.03668965196652099192f));
465   Value cstCephesExpP5 = bcast(f32Cst(builder, 0.01314350012789660196f));
466 
467   Value x = op.operand();
468 
469   // Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2)
470   Value xL2Inv = mul(x, cstLog2E);
471   Value kF32 = floor(xL2Inv);
472   Value kLn2 = mul(kF32, cstLn2);
473   Value y = sub(x, kLn2);
474 
475   // Use Estrin's evaluation scheme with 3 independent parts:
476   // P(y)^y : (c0 + c1 y) + (c2 + c3 y) y^2 + (c4 + c5 y) y^4
477   Value y2 = mul(y, y);
478   Value y4 = mul(y2, y2);
479 
480   Value q0 = fmla(cstCephesExpP1, y, cstCephesExpP0);
481   Value q1 = fmla(cstCephesExpP3, y, cstCephesExpP2);
482   Value q2 = fmla(cstCephesExpP5, y, cstCephesExpP4);
483   Value expY = fmla(q1, y2, q0);
484   expY = fmla(q2, y4, expY);
485 
486   auto i32Vec = broadcast(builder.getI32Type(), *width);
487 
488   // exp2(k)
489   Value k = builder.create<FPToSIOp>(kF32, i32Vec);
490   Value exp2KValue = exp2I32(builder, k);
491 
492   // exp(x) = exp(y) * exp2(k)
493   expY = mul(expY, exp2KValue);
494 
495   // Handle overflow, inf and underflow of exp(x). exp(x) range is [0, inf], its
496   // partitioned as the following:
497   // exp(x) = 0, x <= -inf
498   // exp(x) = underflow (min_float), x <= -88
499   // exp(x) = inf (min_float), x >= 88
500   // Note: |k| = 127 is the value where the 8-bits exponent saturates.
501   Value zerof32Const = bcast(f32Cst(builder, 0));
502   auto constPosInfinity =
503       bcast(f32Cst(builder, std::numeric_limits<float>::infinity()));
504   auto constNegIfinity =
505       bcast(f32Cst(builder, -std::numeric_limits<float>::infinity()));
506   auto underflow = bcast(f32Cst(builder, std::numeric_limits<float>::min()));
507 
508   Value kMaxConst = bcast(i32Cst(builder, 127));
509   Value kMaxNegConst = bcast(i32Cst(builder, -127));
510   Value rightBound = builder.create<CmpIOp>(CmpIPredicate::sle, k, kMaxConst);
511   Value leftBound = builder.create<CmpIOp>(CmpIPredicate::sge, k, kMaxNegConst);
512 
513   Value isNegInfinityX =
514       builder.create<CmpFOp>(CmpFPredicate::OEQ, x, constNegIfinity);
515   Value isPostiveX =
516       builder.create<CmpFOp>(CmpFPredicate::OGT, x, zerof32Const);
517   Value isComputable = builder.create<AndOp>(rightBound, leftBound);
518 
519   expY = builder.create<SelectOp>(
520       isComputable, expY,
521       builder.create<SelectOp>(
522           isPostiveX, constPosInfinity,
523           builder.create<SelectOp>(isNegInfinityX, zerof32Const, underflow)));
524 
525   rewriter.replaceOp(op, expY);
526 
527   return success();
528 }
529 
530 //----------------------------------------------------------------------------//
531 
532 void mlir::populateMathPolynomialApproximationPatterns(
533     OwningRewritePatternList &patterns, MLIRContext *ctx) {
534   patterns.insert<TanhApproximation, LogApproximation, Log2Approximation,
535                   ExpApproximation>(ctx);
536 }
537