1 //===- PolynomialApproximation.cpp - Approximate math operations ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements expansion of math operations to fast approximations
10 // that do not rely on any of the library functions.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include <climits>
15 #include <cstddef>
16 
17 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
18 #include "mlir/Dialect/Math/IR/Math.h"
19 #include "mlir/Dialect/Math/Transforms/Approximation.h"
20 #include "mlir/Dialect/Math/Transforms/Passes.h"
21 #include "mlir/Dialect/Utils/IndexingUtils.h"
22 #include "mlir/Dialect/Vector/IR/VectorOps.h"
23 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
24 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
25 #include "mlir/IR/Builders.h"
26 #include "mlir/IR/BuiltinTypes.h"
27 #include "mlir/IR/ImplicitLocOpBuilder.h"
28 #include "mlir/IR/OpDefinition.h"
29 #include "mlir/IR/PatternMatch.h"
30 #include "mlir/IR/TypeUtilities.h"
31 #include "mlir/Transforms/DialectConversion.h"
32 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
33 #include "llvm/ADT/ArrayRef.h"
34 #include "llvm/ADT/STLExtras.h"
35 
36 using namespace mlir;
37 using namespace mlir::math;
38 using namespace mlir::vector;
39 
40 // Returns vector shape if the type is a vector. Returns an empty shape if it is
41 // not a vector.
vectorShape(Type type)42 static ArrayRef<int64_t> vectorShape(Type type) {
43   auto vectorType = type.dyn_cast<VectorType>();
44   return vectorType ? vectorType.getShape() : ArrayRef<int64_t>();
45 }
46 
vectorShape(Value value)47 static ArrayRef<int64_t> vectorShape(Value value) {
48   return vectorShape(value.getType());
49 }
50 
51 //----------------------------------------------------------------------------//
52 // Broadcast scalar types and values into vector types and values.
53 //----------------------------------------------------------------------------//
54 
55 // Broadcasts scalar type into vector type (iff shape is non-scalar).
broadcast(Type type,ArrayRef<int64_t> shape)56 static Type broadcast(Type type, ArrayRef<int64_t> shape) {
57   assert(!type.isa<VectorType>() && "must be scalar type");
58   return !shape.empty() ? VectorType::get(shape, type) : type;
59 }
60 
61 // Broadcasts scalar value into vector (iff shape is non-scalar).
broadcast(ImplicitLocOpBuilder & builder,Value value,ArrayRef<int64_t> shape)62 static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
63                        ArrayRef<int64_t> shape) {
64   assert(!value.getType().isa<VectorType>() && "must be scalar value");
65   auto type = broadcast(value.getType(), shape);
66   return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value;
67 }
68 
69 //----------------------------------------------------------------------------//
70 // Helper function to handle n-D vectors with 1-D operations.
71 //----------------------------------------------------------------------------//
72 
// Expands and unrolls n-D vector operands into multiple fixed size 1-D vectors
// and calls the compute function with 1-D vector operands. Stitches back all
// results into the original n-D vector result.
//
// Examples: vectorWidth = 8
//   - vector<4x8xf32> unrolled 4 times
//   - vector<16xf32> expanded to vector<2x8xf32> and unrolled 2 times
//   - vector<4x16xf32> expanded to vector<4x2x8xf32> and unrolled 4*2 times
//
// Some math approximations rely on ISA-specific operations that only accept
// fixed size 1-D vectors (e.g. AVX expects vectors of width 8).
//
// It is the caller's responsibility to verify that the inner dimension is
// divisible by the vectorWidth, and that all operands have the same vector
// shape.
//
// Parameters:
//   builder     - builder (with implicit location) used for every created op.
//   operands    - non-empty list of vector values; the shape of operands[0]
//                 is used for all of them.
//   vectorWidth - length of the 1-D slices handed to `compute` (must be > 0).
//   compute     - invoked once per slice with the matching 1-D slice of every
//                 operand. It must return a vector of `vectorWidth` elements,
//                 because each result is inserted into a vector whose
//                 innermost dimension is `vectorWidth`.
// Returns a vector with the original input shape and the element type of the
// values produced by `compute`. When the input already is a single 1-D vector
// of `vectorWidth` elements, `compute` is called directly on `operands`.
static Value
handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
                              ValueRange operands, int64_t vectorWidth,
                              llvm::function_ref<Value(ValueRange)> compute) {
  assert(!operands.empty() && "operands must be not empty");
  assert(vectorWidth > 0 && "vector width must be larger than 0");

  VectorType inputType = operands[0].getType().cast<VectorType>();
  ArrayRef<int64_t> inputShape = inputType.getShape();

  // If input shape matches target vector width, we can just call the
  // user-provided compute function with the operands.
  if (inputShape == llvm::makeArrayRef(vectorWidth))
    return compute(operands);

  // Check if the inner dimension has to be expanded, or we can directly iterate
  // over the outer dimensions of the vector.
  int64_t innerDim = inputShape.back();
  int64_t expansionDim = innerDim / vectorWidth;
  assert((innerDim % vectorWidth == 0) && "invalid inner dimension size");

  // Maybe expand operands to the higher rank vector shape that we'll use to
  // iterate over and extract one dimensional vectors.
  SmallVector<int64_t> expandedShape(inputShape.begin(), inputShape.end());
  SmallVector<Value> expandedOperands(operands);

  if (expansionDim > 1) {
    // Expand shape from [..., innerDim] to [..., expansionDim, vectorWidth].
    expandedShape.insert(expandedShape.end() - 1, expansionDim);
    expandedShape.back() = vectorWidth;

    for (unsigned i = 0; i < operands.size(); ++i) {
      auto operand = operands[i];
      auto eltType = operand.getType().cast<VectorType>().getElementType();
      auto expandedType = VectorType::get(expandedShape, eltType);
      expandedOperands[i] =
          builder.create<vector::ShapeCastOp>(expandedType, operand);
    }
  }

  // Iterate over all outer dimensions of the compute shape vector type.
  auto iterationDims = ArrayRef<int64_t>(expandedShape).drop_back();
  int64_t maxLinearIndex = computeMaxLinearIndex(iterationDims);

  // Row-major strides over the outer dims, used to map a linear slice index
  // back to a multi-dimensional position (and vice versa).
  SmallVector<int64_t> ones(iterationDims.size(), 1);
  auto strides = computeStrides(iterationDims, ones);

  // Compute results for each one dimensional vector.
  SmallVector<Value> results(maxLinearIndex);

  for (int64_t i = 0; i < maxLinearIndex; ++i) {
    auto offsets = delinearize(strides, i);

    // Pull out the i-th 1-D slice of every (possibly expanded) operand.
    SmallVector<Value> extracted(expandedOperands.size());
    for (const auto &tuple : llvm::enumerate(expandedOperands))
      extracted[tuple.index()] =
          builder.create<vector::ExtractOp>(tuple.value(), offsets);

    results[i] = compute(extracted);
  }

  // Stitch results together into one large vector.
  // The zero constant is only a placeholder: every slice is overwritten by
  // the vector.insert loop below.
  Type resultEltType = results[0].getType().cast<VectorType>().getElementType();
  Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
  Value result = builder.create<arith::ConstantOp>(
      resultExpandedType, builder.getZeroAttr(resultExpandedType));

  for (int64_t i = 0; i < maxLinearIndex; ++i)
    result = builder.create<vector::InsertOp>(results[i], result,
                                              delinearize(strides, i));

  // Reshape back to the original vector shape.
  return builder.create<vector::ShapeCastOp>(
      VectorType::get(inputShape, resultEltType), result);
}
163 
164 //----------------------------------------------------------------------------//
165 // Helper functions to create constants.
166 //----------------------------------------------------------------------------//
167 
f32Cst(ImplicitLocOpBuilder & builder,float value)168 static Value f32Cst(ImplicitLocOpBuilder &builder, float value) {
169   return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value));
170 }
171 
i32Cst(ImplicitLocOpBuilder & builder,int32_t value)172 static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) {
173   return builder.create<arith::ConstantOp>(builder.getI32IntegerAttr(value));
174 }
175 
f32FromBits(ImplicitLocOpBuilder & builder,uint32_t bits)176 static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) {
177   Value i32Value = i32Cst(builder, static_cast<int32_t>(bits));
178   return builder.create<arith::BitcastOp>(builder.getF32Type(), i32Value);
179 }
180 
181 //----------------------------------------------------------------------------//
182 // Helper functions to build math functions approximations.
183 //----------------------------------------------------------------------------//
184 
185 // Return the minimum of the two values or NaN if value is NaN
min(ImplicitLocOpBuilder & builder,Value value,Value bound)186 static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound) {
187   return builder.create<arith::SelectOp>(
188       builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, value, bound),
189       value, bound);
190 }
191 
192 // Return the maximum of the two values or NaN if value is NaN
max(ImplicitLocOpBuilder & builder,Value value,Value bound)193 static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound) {
194   return builder.create<arith::SelectOp>(
195       builder.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, value, bound),
196       value, bound);
197 }
198 
199 // Return the clamped value or NaN if value is NaN
clamp(ImplicitLocOpBuilder & builder,Value value,Value lowerBound,Value upperBound)200 static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
201                    Value upperBound) {
202   return max(builder, min(builder, value, upperBound), lowerBound);
203 }
204 
// Decomposes given floating point value `arg` into a normalized fraction and
// an integral power of two (see std::frexp). Returned values have float type.
//
// Operates on f32 scalars or vectors by manipulating the IEEE-754 bits:
//   - fraction: the exponent field of `arg` is overwritten with the exponent
//     of 0.5, keeping the sign and mantissa bits. A normal input therefore
//     maps to a value whose magnitude is in [0.5, 1).
//   - exponent: the biased exponent field of |arg| minus 126, as f32. For
//     normal inputs this gives arg == fraction * 2^exponent.
// Zero, denormals, Inf and NaN are not special-cased, so callers must
// guarantee a normal input. For example, the log approximation in this file
// first clamps its input to the smallest positive normal.
// Pass `isPositive` = true when `arg` is known to be non-negative; this skips
// the math::AbsOp applied before the exponent is extracted.
static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
                                     bool isPositive = false) {
  assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
  ArrayRef<int64_t> shape = vectorShape(arg);

  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  auto i32 = builder.getIntegerType(32);
  auto i32Vec = broadcast(i32, shape);
  auto f32Vec = broadcast(builder.getF32Type(), shape);

  Value cst126f = f32Cst(builder, 126.0f);
  Value cstHalf = f32Cst(builder, 0.5f);
  // Despite its name, this mask clears the 8 exponent bits (0x7f800000) and
  // keeps the sign and mantissa bits.
  Value cstInvMantMask = f32FromBits(builder, ~0x7f800000u);

  // Bitcast to i32 for bitwise operations.
  Value i32Half = builder.create<arith::BitcastOp>(i32, cstHalf);
  Value i32InvMantMask = builder.create<arith::BitcastOp>(i32, cstInvMantMask);
  Value i32Arg = builder.create<arith::BitcastOp>(i32Vec, arg);

  // Compute normalized fraction: drop the exponent bits, then OR in the bits
  // of 0.5. Since 0.5 has a zero mantissa, this only sets the exponent field.
  Value tmp0 = builder.create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask));
  Value tmp1 = builder.create<arith::OrIOp>(tmp0, bcast(i32Half));
  Value normalizedFraction = builder.create<arith::BitcastOp>(f32Vec, tmp1);

  // Compute exponent: shifting right by 23 (the mantissa width) leaves the
  // biased exponent, because the sign bit of |arg| is zero.
  Value arg0 = isPositive ? arg : builder.create<math::AbsOp>(arg);
  Value biasedExponentBits = builder.create<arith::ShRUIOp>(
      builder.create<arith::BitcastOp>(i32Vec, arg0),
      bcast(i32Cst(builder, 23)));
  Value biasedExponent =
      builder.create<arith::SIToFPOp>(f32Vec, biasedExponentBits);
  // Subtract 126 rather than the IEEE bias of 127 because the fraction is in
  // [0.5, 1) instead of [1, 2).
  Value exponent =
      builder.create<arith::SubFOp>(biasedExponent, bcast(cst126f));

  return {normalizedFraction, exponent};
}
246 
247 // Computes exp2 for an i32 argument.
exp2I32(ImplicitLocOpBuilder & builder,Value arg)248 static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
249   assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type");
250   ArrayRef<int64_t> shape = vectorShape(arg);
251 
252   auto bcast = [&](Value value) -> Value {
253     return broadcast(builder, value, shape);
254   };
255 
256   auto f32Vec = broadcast(builder.getF32Type(), shape);
257   // The exponent of f32 located at 23-bit.
258   auto exponetBitLocation = bcast(i32Cst(builder, 23));
259   // Set the exponent bias to zero.
260   auto bias = bcast(i32Cst(builder, 127));
261 
262   Value biasedArg = builder.create<arith::AddIOp>(arg, bias);
263   Value exp2ValueInt =
264       builder.create<arith::ShLIOp>(biasedArg, exponetBitLocation);
265   Value exp2ValueF32 = builder.create<arith::BitcastOp>(f32Vec, exp2ValueInt);
266 
267   return exp2ValueF32;
268 }
269 
270 namespace {
makePolynomialCalculation(ImplicitLocOpBuilder & builder,llvm::ArrayRef<Value> coeffs,Value x)271 Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
272                                 llvm::ArrayRef<Value> coeffs, Value x) {
273   assert(getElementTypeOrSelf(x).isF32() && "x must be f32 type");
274   ArrayRef<int64_t> shape = vectorShape(x);
275 
276   if (coeffs.empty())
277     return broadcast(builder, f32Cst(builder, 0.0f), shape);
278 
279   if (coeffs.size() == 1)
280     return coeffs[0];
281 
282   Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1],
283                                           coeffs[coeffs.size() - 2]);
284   for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
285     res = builder.create<math::FmaOp>(x, res, coeffs[i]);
286   }
287   return res;
288 }
289 } // namespace
290 
291 //----------------------------------------------------------------------------//
292 // Helper function/pattern to insert casts for reusing F32 bit expansion.
293 //----------------------------------------------------------------------------//
294 
295 template <typename T>
insertCasts(Operation * op,PatternRewriter & rewriter)296 LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) {
297   // Conservatively only allow where the operand and result types are exactly 1.
298   Type origType = op->getResultTypes().front();
299   for (Type t : llvm::drop_begin(op->getResultTypes()))
300     if (origType != t)
301       return rewriter.notifyMatchFailure(op, "required all types to match");
302   for (Type t : op->getOperandTypes())
303     if (origType != t)
304       return rewriter.notifyMatchFailure(op, "required all types to match");
305 
306   // Skip if already F32  or larger than 32 bits.
307   if (getElementTypeOrSelf(origType).isF32() ||
308       getElementTypeOrSelf(origType).getIntOrFloatBitWidth() > 32)
309     return failure();
310 
311   // Create F32 equivalent type.
312   Type newType;
313   if (auto shaped = origType.dyn_cast<ShapedType>()) {
314     newType = shaped.clone(rewriter.getF32Type());
315   } else if (origType.isa<FloatType>()) {
316     newType = rewriter.getF32Type();
317   } else {
318     return rewriter.notifyMatchFailure(op,
319                                        "unable to find F32 equivalent type");
320   }
321 
322   Location loc = op->getLoc();
323   SmallVector<Value> operands;
324   for (auto operand : op->getOperands())
325     operands.push_back(rewriter.create<arith::ExtFOp>(loc, newType, operand));
326   auto result = rewriter.create<math::Atan2Op>(loc, newType, operands);
327   rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, origType, result);
328   return success();
329 }
330 
331 namespace {
332 // Pattern to cast to F32 to reuse F32 expansion as fallback for single-result
333 // op.
334 // TODO: Consider revising to avoid adding multiple casts for a subgraph that is
335 // all in lower precision. Currently this is only fallback support and performs
336 // simplistic casting.
337 template <typename T>
338 struct ReuseF32Expansion : public OpRewritePattern<T> {
339 public:
340   using OpRewritePattern<T>::OpRewritePattern;
matchAndRewrite__anond27a14240411::ReuseF32Expansion341   LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const final {
342     static_assert(
343         T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(),
344         "requires same operands and result types");
345     return insertCasts<T>(op, rewriter);
346   }
347 };
348 } // namespace
349 
350 //----------------------------------------------------------------------------//
351 // AtanOp approximation.
352 //----------------------------------------------------------------------------//
353 
354 namespace {
355 struct AtanApproximation : public OpRewritePattern<math::AtanOp> {
356 public:
357   using OpRewritePattern::OpRewritePattern;
358 
359   LogicalResult matchAndRewrite(math::AtanOp op,
360                                 PatternRewriter &rewriter) const final;
361 };
362 } // namespace
363 
// Expands f32 math::AtanOp. It uses the identity atan(a) = pi/2 - atan(1/a)
// for a > 0, so the polynomial only has to be accurate on [0, 1]:
//   x      = |a| < 1/|a| ? |a| : 1/|a|        (in [0, 1] for non-NaN a)
//   p      = x * (n4 + x * (n3 + x * (n2 + x * n1)))
//   result = copysign(|a| < 1/|a| ? p : pi/2 - p, a)
// Non-f32 operands are rejected with a match failure.
LogicalResult
AtanApproximation::matchAndRewrite(math::AtanOp op,
                                   PatternRewriter &rewriter) const {
  auto operand = op.getOperand();
  if (!getElementTypeOrSelf(operand).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto one = broadcast(builder, f32Cst(builder, 1.0f), shape);

  // Remap the problem onto [0.0, 1.0] by taking the absolute value and
  // handling the symmetry: x = min(|operand|, 1 / |operand|).
  Value abs = builder.create<math::AbsOp>(operand);
  Value reciprocal = builder.create<arith::DivFOp>(one, abs);
  Value compare =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, abs, reciprocal);
  Value x = builder.create<arith::SelectOp>(compare, abs, reciprocal);

  // Evaluate a fitted degree-4 polynomial for atan over [0.0, 1.0]. It is not
  // a Taylor series: it contains even powers of x.
  auto n1 = broadcast(builder, f32Cst(builder, 0.14418283f), shape);
  auto n2 = broadcast(builder, f32Cst(builder, -0.34999234f), shape);
  auto n3 = broadcast(builder, f32Cst(builder, -0.01067831f), shape);
  auto n4 = broadcast(builder, f32Cst(builder, 1.00209986f), shape);

  Value p = builder.create<math::FmaOp>(x, n1, n2);
  p = builder.create<math::FmaOp>(x, p, n3);
  p = builder.create<math::FmaOp>(x, p, n4);
  p = builder.create<arith::MulFOp>(x, p);

  // Remap the solution from [0.0, 1.0] to [0.0, inf]: when the reciprocal was
  // used, atan(|operand|) = pi/2 - atan(1 / |operand|).
  auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape);
  Value sub = builder.create<arith::SubFOp>(halfPi, p);
  Value select = builder.create<arith::SelectOp>(compare, p, sub);

  // Correct for signing of the input (atan is odd).
  rewriter.replaceOpWithNewOp<math::CopySignOp>(op, select, operand);
  return success();
}
405 
406 //----------------------------------------------------------------------------//
// Atan2Op approximation.
408 //----------------------------------------------------------------------------//
409 
410 namespace {
411 struct Atan2Approximation : public OpRewritePattern<math::Atan2Op> {
412 public:
413   using OpRewritePattern::OpRewritePattern;
414 
415   LogicalResult matchAndRewrite(math::Atan2Op op,
416                                 PatternRewriter &rewriter) const final;
417 };
418 } // namespace
419 
// Expands f32 math::Atan2Op(y, x) using a = math::AtanOp(y / x):
//   - x > 0:               a
//   - otherwise (x <= 0 or NaN):
//                          a - pi if a > 0, else a + pi
//   - x == 0, y > 0:       +pi/2
//   - x == 0, y < 0:       -pi/2
//   - x == 0, y == 0:      NaN (C's atan2 returns a signed zero or +/-pi here)
// The emitted math::AtanOp is not expanded here. Presumably AtanApproximation
// expands it when both patterns run together; verify where the patterns are
// populated.
// NOTE(review): for y == -0 and x < 0 this yields +pi rather than -pi. The
// flip direction is chosen from the sign of atan(y / x), not from y.
LogicalResult
Atan2Approximation::matchAndRewrite(math::Atan2Op op,
                                    PatternRewriter &rewriter) const {
  auto y = op.getOperand(0);
  auto x = op.getOperand(1);
  if (!getElementTypeOrSelf(x).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  ArrayRef<int64_t> shape = vectorShape(op.getResult());

  // Compute atan in the valid range.
  auto div = builder.create<arith::DivFOp>(y, x);
  auto atan = builder.create<math::AtanOp>(div);

  // Determine what the atan would be for a 180 degree rotation.
  auto zero = broadcast(builder, f32Cst(builder, 0.0f), shape);
  auto pi = broadcast(builder, f32Cst(builder, 3.14159265359f), shape);
  auto addPi = builder.create<arith::AddFOp>(atan, pi);
  auto subPi = builder.create<arith::SubFOp>(atan, pi);
  auto atanGt =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, atan, zero);
  auto flippedAtan = builder.create<arith::SelectOp>(atanGt, subPi, addPi);

  // Determine whether to directly use atan or use the 180 degree flip.
  // OGT is false for NaN x, so a NaN x takes the flipped branch.
  auto xGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zero);
  Value result = builder.create<arith::SelectOp>(xGt, atan, flippedAtan);

  // Handle x = 0, y > 0
  Value xZero =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, x, zero);
  Value yGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, y, zero);
  Value isHalfPi = builder.create<arith::AndIOp>(xZero, yGt);
  auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape);
  result = builder.create<arith::SelectOp>(isHalfPi, halfPi, result);

  // Handle x = 0, y < 0
  Value yLt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, y, zero);
  Value isNegativeHalfPiPi = builder.create<arith::AndIOp>(xZero, yLt);
  auto negativeHalfPiPi =
      broadcast(builder, f32Cst(builder, -1.57079632679f), shape);
  result = builder.create<arith::SelectOp>(isNegativeHalfPiPi, negativeHalfPiPi,
                                           result);

  // Handle x = 0, y = 0; the result is a quiet NaN (bit pattern 0x7fc00000).
  Value yZero =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, y, zero);
  Value isNan = builder.create<arith::AndIOp>(xZero, yZero);
  Value cstNan = broadcast(builder, f32FromBits(builder, 0x7fc00000), shape);
  result = builder.create<arith::SelectOp>(isNan, cstNan, result);

  rewriter.replaceOp(op, result);
  return success();
}
474 
475 //----------------------------------------------------------------------------//
476 // TanhOp approximation.
477 //----------------------------------------------------------------------------//
478 
479 namespace {
480 struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
481 public:
482   using OpRewritePattern::OpRewritePattern;
483 
484   LogicalResult matchAndRewrite(math::TanhOp op,
485                                 PatternRewriter &rewriter) const final;
486 };
487 } // namespace
488 
// Expands f32 math::TanhOp as a rational approximation p(x) / q(x):
//   - x is the operand clamped to [-7.99881172180175781, 7.99881172180175781];
//   - p is an odd polynomial of degree 13 (x times a polynomial in x^2);
//   - q is an even polynomial of degree 6 (a polynomial in x^2);
//   - when |operand| < 0.0004 the (clamped) input is returned directly,
//     i.e. tanh(x) ~= x.
// Non-f32 operands are rejected with a match failure.
LogicalResult
TanhApproximation::matchAndRewrite(math::TanhOp op,
                                   PatternRewriter &rewriter) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // Clamp operand into [minusClamp, plusClamp] range.
  Value minusClamp = bcast(f32Cst(builder, -7.99881172180175781f));
  Value plusClamp = bcast(f32Cst(builder, 7.99881172180175781f));
  Value x = clamp(builder, op.getOperand(), minusClamp, plusClamp);

  // Mask for tiny values that are approximated with `operand`.
  Value tiny = bcast(f32Cst(builder, 0.0004f));
  Value tinyMask = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::OLT, builder.create<math::AbsOp>(op.getOperand()),
      tiny);

  // The monomial coefficients of the numerator polynomial (odd).
  Value alpha1 = bcast(f32Cst(builder, 4.89352455891786e-03f));
  Value alpha3 = bcast(f32Cst(builder, 6.37261928875436e-04f));
  Value alpha5 = bcast(f32Cst(builder, 1.48572235717979e-05f));
  Value alpha7 = bcast(f32Cst(builder, 5.12229709037114e-08f));
  Value alpha9 = bcast(f32Cst(builder, -8.60467152213735e-11f));
  Value alpha11 = bcast(f32Cst(builder, 2.00018790482477e-13f));
  Value alpha13 = bcast(f32Cst(builder, -2.76076847742355e-16f));

  // The monomial coefficients of the denominator polynomial (even).
  Value beta0 = bcast(f32Cst(builder, 4.89352518554385e-03f));
  Value beta2 = bcast(f32Cst(builder, 2.26843463243900e-03f));
  Value beta4 = bcast(f32Cst(builder, 1.18534705686654e-04f));
  Value beta6 = bcast(f32Cst(builder, 1.19825839466702e-06f));

  // Since the polynomials are odd/even, we need x^2.
  Value x2 = builder.create<arith::MulFOp>(x, x);

  // Evaluate the numerator polynomial p (Horner in x^2, then times x).
  Value p = builder.create<math::FmaOp>(x2, alpha13, alpha11);
  p = builder.create<math::FmaOp>(x2, p, alpha9);
  p = builder.create<math::FmaOp>(x2, p, alpha7);
  p = builder.create<math::FmaOp>(x2, p, alpha5);
  p = builder.create<math::FmaOp>(x2, p, alpha3);
  p = builder.create<math::FmaOp>(x2, p, alpha1);
  p = builder.create<arith::MulFOp>(x, p);

  // Evaluate the denominator polynomial q (Horner in x^2).
  Value q = builder.create<math::FmaOp>(x2, beta6, beta4);
  q = builder.create<math::FmaOp>(x2, q, beta2);
  q = builder.create<math::FmaOp>(x2, q, beta0);

  // Divide the numerator by the denominator, except for tiny inputs.
  Value res = builder.create<arith::SelectOp>(
      tinyMask, x, builder.create<arith::DivFOp>(p, q));

  rewriter.replaceOp(op, res);

  return success();
}
553 
// ln(2), written as a long double literal with more digits than any
// floating-point type can hold. Uses in this file narrow it explicitly, e.g.
// static_cast<float>(LN2_VALUE).
#define LN2_VALUE                                                              \
  0.693147180559945309417232121458176568075500134360255254120680009493393621L
// log2(e) == 1 / ln(2), following the same long double literal convention.
#define LOG2E_VALUE                                                            \
  1.442695040888963407359924681001892137426645954152985934135449406931109219L
558 
559 //----------------------------------------------------------------------------//
560 // LogOp and Log2Op approximation.
561 //----------------------------------------------------------------------------//
562 
563 namespace {
564 template <typename Op>
565 struct LogApproximationBase : public OpRewritePattern<Op> {
566   using OpRewritePattern<Op>::OpRewritePattern;
567 
568   /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise.
569   LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter,
570                                    bool base2) const;
571 };
572 } // namespace
573 
// This approximation comes from Julien Pommier's SSE math library.
// Link: http://gruntthepeon.free.fr/ssemath
//
// Expands an f32 (scalar or vector) logarithm op in five steps:
//   1. Clamp the input from below to the smallest positive normal f32, so that
//      frexp always sees a normal number.
//   2. Split it into a fraction in [0.5, 1) and an exponent e (frexp).
//   3. Re-center the fraction into [sqrt(1/2), sqrt(2)) and subtract 1, then
//      evaluate the Cephes polynomial approximation of log(1 + x).
//   4. Combine the parts: ln(arg) = x + e * ln(2), or
//      log2(arg) = x * log2(e) + e when `base2` is set.
//   5. Patch special inputs: +/-0 -> -Inf, negative or NaN -> NaN,
//      +Inf -> +Inf.
// Non-f32 operands are rejected with a match failure.
template <typename Op>
LogicalResult
LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
                                             bool base2) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  Value cstZero = bcast(f32Cst(builder, 0.0f));
  Value cstOne = bcast(f32Cst(builder, 1.0f));
  Value cstNegHalf = bcast(f32Cst(builder, -0.5f));

  // The smallest non denormalized float number.
  Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
  Value cstMinusInf = bcast(f32FromBits(builder, 0xff800000u));
  Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
  Value cstNan = bcast(f32FromBits(builder, 0x7fc00000));

  // Polynomial coefficients.
  Value cstCephesSQRTHF = bcast(f32Cst(builder, 0.707106781186547524f));
  Value cstCephesLogP0 = bcast(f32Cst(builder, 7.0376836292E-2f));
  Value cstCephesLogP1 = bcast(f32Cst(builder, -1.1514610310E-1f));
  Value cstCephesLogP2 = bcast(f32Cst(builder, 1.1676998740E-1f));
  Value cstCephesLogP3 = bcast(f32Cst(builder, -1.2420140846E-1f));
  Value cstCephesLogP4 = bcast(f32Cst(builder, +1.4249322787E-1f));
  Value cstCephesLogP5 = bcast(f32Cst(builder, -1.6668057665E-1f));
  Value cstCephesLogP6 = bcast(f32Cst(builder, +2.0000714765E-1f));
  Value cstCephesLogP7 = bcast(f32Cst(builder, -2.4999993993E-1f));
  Value cstCephesLogP8 = bcast(f32Cst(builder, +3.3333331174E-1f));

  Value x = op.getOperand();

  // Clamp input values from below to the minimum positive normal. Zero and
  // negative inputs are overridden by the special-value selects at the end;
  // NaN propagates through `max`.
  x = max(builder, x, cstMinNormPos);

  // Extract the significand in the range [0.5,1) and the exponent.
  std::pair<Value, Value> pair = frexp(builder, x, /*isPositive=*/true);
  x = pair.first;
  Value e = pair.second;

  // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift
  // by -1.0. The values are then centered around 0, which improves the
  // stability of the polynomial evaluation:
  //
  //   if( x < SQRTHF ) {
  //     e -= 1;
  //     x = x + x - 1.0;
  //   } else { x = x - 1.0; }
  Value mask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x,
                                             cstCephesSQRTHF);
  Value tmp = builder.create<arith::SelectOp>(mask, x, cstZero);

  x = builder.create<arith::SubFOp>(x, cstOne);
  e = builder.create<arith::SubFOp>(
      e, builder.create<arith::SelectOp>(mask, cstOne, cstZero));
  x = builder.create<arith::AddFOp>(x, tmp);

  Value x2 = builder.create<arith::MulFOp>(x, x);
  Value x3 = builder.create<arith::MulFOp>(x2, x);

  // Evaluate the polynomial approximant of degree 8 in three parts, then
  // recombine them with x^3:
  //   y0 = x^3 * (((P0 x + P1) x + P2) x^6 + ((P3 x + P4) x + P5) x^3
  //               + (P6 x + P7) x + P8)
  Value y0, y1, y2;
  y0 = builder.create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1);
  y1 = builder.create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4);
  y2 = builder.create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7);
  y0 = builder.create<math::FmaOp>(y0, x, cstCephesLogP2);
  y1 = builder.create<math::FmaOp>(y1, x, cstCephesLogP5);
  y2 = builder.create<math::FmaOp>(y2, x, cstCephesLogP8);
  y0 = builder.create<math::FmaOp>(y0, x3, y1);
  y0 = builder.create<math::FmaOp>(y0, x3, y2);
  y0 = builder.create<arith::MulFOp>(y0, x3);

  // log(1 + x) ~= x - x^2 / 2 + y0.
  y0 = builder.create<math::FmaOp>(cstNegHalf, x2, y0);
  x = builder.create<arith::AddFOp>(x, y0);

  if (base2) {
    // log2(m * 2^e) = ln(m) * log2(e) + e.
    Value cstLog2e = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
    x = builder.create<math::FmaOp>(x, cstLog2e, e);
  } else {
    // ln(m * 2^e) = e * ln(2) + ln(m).
    Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
    x = builder.create<math::FmaOp>(e, cstLn2, x);
  }

  // ULT is also true for NaN, so NaN inputs are routed to cstNan as well.
  Value invalidMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
                                                    op.getOperand(), cstZero);
  Value zeroMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
                                                 op.getOperand(), cstZero);
  Value posInfMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
                                                   op.getOperand(), cstPosInf);

  // Filter out invalid values:
  //  • x == 0         -> -INF
  //  • x < 0 or NaN   ->  NAN
  //  • x == +INF      -> +INF
  Value aproximation = builder.create<arith::SelectOp>(
      zeroMask, cstMinusInf,
      builder.create<arith::SelectOp>(
          invalidMask, cstNan,
          builder.create<arith::SelectOp>(posInfMask, cstPosInf, x)));

  rewriter.replaceOp(op, aproximation);

  return success();
}
686 
687 namespace {
688 struct LogApproximation : public LogApproximationBase<math::LogOp> {
689   using LogApproximationBase::LogApproximationBase;
690 
matchAndRewrite__anond27a14240b11::LogApproximation691   LogicalResult matchAndRewrite(math::LogOp op,
692                                 PatternRewriter &rewriter) const final {
693     return logMatchAndRewrite(op, rewriter, /*base2=*/false);
694   }
695 };
696 } // namespace
697 
698 namespace {
699 struct Log2Approximation : public LogApproximationBase<math::Log2Op> {
700   using LogApproximationBase::LogApproximationBase;
701 
matchAndRewrite__anond27a14240c11::Log2Approximation702   LogicalResult matchAndRewrite(math::Log2Op op,
703                                 PatternRewriter &rewriter) const final {
704     return logMatchAndRewrite(op, rewriter, /*base2=*/true);
705   }
706 };
707 } // namespace
708 
709 //----------------------------------------------------------------------------//
710 // Log1p approximation.
711 //----------------------------------------------------------------------------//
712 
713 namespace {
714 struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> {
715 public:
716   using OpRewritePattern::OpRewritePattern;
717 
718   LogicalResult matchAndRewrite(math::Log1pOp op,
719                                 PatternRewriter &rewriter) const final;
720 };
721 } // namespace
722 
723 // Approximate log(1+x).
724 LogicalResult
matchAndRewrite(math::Log1pOp op,PatternRewriter & rewriter) const725 Log1pApproximation::matchAndRewrite(math::Log1pOp op,
726                                     PatternRewriter &rewriter) const {
727   if (!getElementTypeOrSelf(op.getOperand()).isF32())
728     return rewriter.notifyMatchFailure(op, "unsupported operand type");
729 
730   ArrayRef<int64_t> shape = vectorShape(op.getOperand());
731 
732   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
733   auto bcast = [&](Value value) -> Value {
734     return broadcast(builder, value, shape);
735   };
736 
737   // Approximate log(1+x) using the following, due to W. Kahan:
738   //   u = x + 1.0;
739   //   if (u == 1.0 || u == inf) return x;
740   //   return x * log(u) / (u - 1.0);
741   //          ^^^^^^^^^^^^^^^^^^^^^^
742   //             "logLarge" below.
743   Value cstOne = bcast(f32Cst(builder, 1.0f));
744   Value x = op.getOperand();
745   Value u = builder.create<arith::AddFOp>(x, cstOne);
746   Value uSmall =
747       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
748   Value logU = builder.create<math::LogOp>(u);
749   Value uInf =
750       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU);
751   Value logLarge = builder.create<arith::MulFOp>(
752       x, builder.create<arith::DivFOp>(
753              logU, builder.create<arith::SubFOp>(u, cstOne)));
754   Value approximation = builder.create<arith::SelectOp>(
755       builder.create<arith::OrIOp>(uSmall, uInf), x, logLarge);
756   rewriter.replaceOp(op, approximation);
757   return success();
758 }
759 
760 //----------------------------------------------------------------------------//
761 // Erf approximation.
762 //----------------------------------------------------------------------------//
763 
// Approximates erf(x) with
// a + P(x)/Q(x)
// where P and Q are polynomials of degree 4 and a is a constant offset.
// Different coefficients (and offsets) are chosen based on the value of |x|:
//   |x| in [0, 0.8)    -> interval 0, a = 0
//   |x| in [0.8, 2.0)  -> interval 1, a = 0
//   |x| in [2.0, 3.75) -> interval 2, a = 1
//   |x| >= 3.75        -> erf(|x|) is taken to be exactly 1
// Negative arguments use the odd symmetry erf(-x) = -erf(x).
// The approximation error is ~2.5e-07.
// Boost's minimax tool that utilizes the Remez method was used to find the
// coefficients.
LogicalResult
ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
                                            PatternRewriter &rewriter) const {
  // Only f32 (scalar or vector) operands are expanded.
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  // Empty for scalars; used to broadcast scalar constants to the operand shape.
  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  const int intervalsCount = 3;
  const int polyDegree = 4;

  Value zero = bcast(f32Cst(builder, 0));
  Value one = bcast(f32Cst(builder, 1));
  // Numerator coefficients pp[interval][i]. The ordering convention
  // (presumably pp[*][i] multiplies x^i) is defined by
  // makePolynomialCalculation.
  Value pp[intervalsCount][polyDegree + 1];
  pp[0][0] = bcast(f32Cst(builder, +0.00000000000000000e+00f));
  pp[0][1] = bcast(f32Cst(builder, +1.12837916222975858e+00f));
  pp[0][2] = bcast(f32Cst(builder, -5.23018562988006470e-01f));
  pp[0][3] = bcast(f32Cst(builder, +2.09741709609267072e-01f));
  pp[0][4] = bcast(f32Cst(builder, +2.58146801602987875e-02f));
  pp[1][0] = bcast(f32Cst(builder, +0.00000000000000000e+00f));
  pp[1][1] = bcast(f32Cst(builder, +1.12750687816789140e+00f));
  pp[1][2] = bcast(f32Cst(builder, -3.64721408487825775e-01f));
  pp[1][3] = bcast(f32Cst(builder, +1.18407396425136952e-01f));
  pp[1][4] = bcast(f32Cst(builder, +3.70645533056476558e-02f));
  pp[2][0] = bcast(f32Cst(builder, -3.30093071049483172e-03f));
  pp[2][1] = bcast(f32Cst(builder, +3.51961938357697011e-03f));
  pp[2][2] = bcast(f32Cst(builder, -1.41373622814988039e-03f));
  pp[2][3] = bcast(f32Cst(builder, +2.53447094961941348e-04f));
  pp[2][4] = bcast(f32Cst(builder, -1.71048029455037401e-05f));

  // Denominator coefficients qq[interval][i]; qq[*][0] is 1 in every interval.
  Value qq[intervalsCount][polyDegree + 1];
  qq[0][0] = bcast(f32Cst(builder, +1.000000000000000000e+00f));
  qq[0][1] = bcast(f32Cst(builder, -4.635138185962547255e-01f));
  qq[0][2] = bcast(f32Cst(builder, +5.192301327279782447e-01f));
  qq[0][3] = bcast(f32Cst(builder, -1.318089722204810087e-01f));
  qq[0][4] = bcast(f32Cst(builder, +7.397964654672315005e-02f));
  qq[1][0] = bcast(f32Cst(builder, +1.00000000000000000e+00f));
  qq[1][1] = bcast(f32Cst(builder, -3.27607011824493086e-01f));
  qq[1][2] = bcast(f32Cst(builder, +4.48369090658821977e-01f));
  qq[1][3] = bcast(f32Cst(builder, -8.83462621207857930e-02f));
  qq[1][4] = bcast(f32Cst(builder, +5.72442770283176093e-02f));
  qq[2][0] = bcast(f32Cst(builder, +1.00000000000000000e+00f));
  qq[2][1] = bcast(f32Cst(builder, -2.06069165953913769e+00f));
  qq[2][2] = bcast(f32Cst(builder, +1.62705939945477759e+00f));
  qq[2][3] = bcast(f32Cst(builder, -5.83389859211130017e-01f));
  qq[2][4] = bcast(f32Cst(builder, +8.21908939856640930e-02f));

  // Constant `a` added to P/Q in each interval.
  Value offsets[intervalsCount];
  offsets[0] = bcast(f32Cst(builder, 0.0f));
  offsets[1] = bcast(f32Cst(builder, 0.0f));
  offsets[2] = bcast(f32Cst(builder, 1.0f));

  // Exclusive upper bound of |x| for each interval.
  Value bounds[intervalsCount];
  bounds[0] = bcast(f32Cst(builder, 0.8f));
  bounds[1] = bcast(f32Cst(builder, 2.0f));
  bounds[2] = bcast(f32Cst(builder, 3.75f));

  // Work on |x|; the sign is restored at the end since erf is odd.
  Value isNegativeArg = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT,
                                                      op.getOperand(), zero);
  Value negArg = builder.create<arith::NegFOp>(op.getOperand());
  Value x =
      builder.create<arith::SelectOp>(isNegativeArg, negArg, op.getOperand());

  // Start from interval 0's coefficients.
  Value offset = offsets[0];
  Value p[polyDegree + 1];
  Value q[polyDegree + 1];
  for (int i = 0; i <= polyDegree; ++i) {
    p[i] = pp[0][i];
    q[i] = qq[0][i];
  }

  // For each inner bound, move on to the next interval's coefficients when
  // |x| is not below it. OLT is false for NaN, so NaN lanes end up using the
  // last interval.
  // TODO: maybe use vector stacking to reduce the number of selects.
  Value isLessThanBound[intervalsCount];
  for (int j = 0; j < intervalsCount - 1; ++j) {
    isLessThanBound[j] =
        builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[j]);
    for (int i = 0; i <= polyDegree; ++i) {
      p[i] = builder.create<arith::SelectOp>(isLessThanBound[j], p[i],
                                             pp[j + 1][i]);
      q[i] = builder.create<arith::SelectOp>(isLessThanBound[j], q[i],
                                             qq[j + 1][i]);
    }
    offset = builder.create<arith::SelectOp>(isLessThanBound[j], offset,
                                             offsets[j + 1]);
  }
  // Unordered compare: true for NaN, so NaN flows through the rational
  // approximation below instead of being replaced by 1.
  isLessThanBound[intervalsCount - 1] = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);

  // erf(|x|) ~= a + P(|x|) / Q(|x|), saturated to 1 beyond the last bound.
  Value pPoly = makePolynomialCalculation(builder, p, x);
  Value qPoly = makePolynomialCalculation(builder, q, x);
  Value rationalPoly = builder.create<arith::DivFOp>(pPoly, qPoly);
  Value formula = builder.create<arith::AddFOp>(offset, rationalPoly);
  formula = builder.create<arith::SelectOp>(isLessThanBound[intervalsCount - 1],
                                            formula, one);

  // erf is odd function: erf(x) = -erf(-x).
  Value negFormula = builder.create<arith::NegFOp>(formula);
  Value res =
      builder.create<arith::SelectOp>(isNegativeArg, negFormula, formula);

  rewriter.replaceOp(op, res);

  return success();
}
880 
881 //----------------------------------------------------------------------------//
882 // Exp approximation.
883 //----------------------------------------------------------------------------//
884 
885 namespace {
886 
887 struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
888 public:
889   using OpRewritePattern::OpRewritePattern;
890 
891   LogicalResult matchAndRewrite(math::ExpOp op,
892                                 PatternRewriter &rewriter) const final;
893 };
894 } // namespace
895 
896 // Approximate exp(x) using its reduced range exp(y) where y is in the range
897 // [0, ln(2)], let y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2), exp(x)
898 // = exp(y) * 2^k. exp(y).
899 LogicalResult
matchAndRewrite(math::ExpOp op,PatternRewriter & rewriter) const900 ExpApproximation::matchAndRewrite(math::ExpOp op,
901                                   PatternRewriter &rewriter) const {
902   if (!getElementTypeOrSelf(op.getOperand()).isF32())
903     return rewriter.notifyMatchFailure(op, "unsupported operand type");
904 
905   ArrayRef<int64_t> shape = vectorShape(op.getOperand());
906 
907   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
908 
909   // TODO: Consider a common pattern rewriter with all methods below to
910   // write the approximations.
911   auto bcast = [&](Value value) -> Value {
912     return broadcast(builder, value, shape);
913   };
914   auto fmla = [&](Value a, Value b, Value c) {
915     return builder.create<math::FmaOp>(a, b, c);
916   };
917   auto mul = [&](Value a, Value b) -> Value {
918     return builder.create<arith::MulFOp>(a, b);
919   };
920   auto sub = [&](Value a, Value b) -> Value {
921     return builder.create<arith::SubFOp>(a, b);
922   };
923   auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
924 
925   Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
926   Value cstLog2E = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
927 
928   // Polynomial coefficients.
929   Value cstCephesExpP0 = bcast(f32Cst(builder, 1.0));
930   Value cstCephesExpP1 = bcast(f32Cst(builder, 1.0));
931   Value cstCephesExpP2 = bcast(f32Cst(builder, 0.49970514590562437052f));
932   Value cstCephesExpP3 = bcast(f32Cst(builder, 0.16873890085469545053f));
933   Value cstCephesExpP4 = bcast(f32Cst(builder, 0.03668965196652099192f));
934   Value cstCephesExpP5 = bcast(f32Cst(builder, 0.01314350012789660196f));
935 
936   Value x = op.getOperand();
937 
938   Value isNan = builder.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, x, x);
939 
940   // Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2)
941   Value xL2Inv = mul(x, cstLog2E);
942   Value kF32 = floor(xL2Inv);
943   Value kLn2 = mul(kF32, cstLn2);
944   Value y = sub(x, kLn2);
945 
946   // Use Estrin's evaluation scheme with 3 independent parts:
947   // P(y)^y : (c0 + c1 y) + (c2 + c3 y) y^2 + (c4 + c5 y) y^4
948   Value y2 = mul(y, y);
949   Value y4 = mul(y2, y2);
950 
951   Value q0 = fmla(cstCephesExpP1, y, cstCephesExpP0);
952   Value q1 = fmla(cstCephesExpP3, y, cstCephesExpP2);
953   Value q2 = fmla(cstCephesExpP5, y, cstCephesExpP4);
954   Value expY = fmla(q1, y2, q0);
955   expY = fmla(q2, y4, expY);
956 
957   auto i32Vec = broadcast(builder.getI32Type(), shape);
958 
959   // exp2(k)
960   Value k = builder.create<arith::FPToSIOp>(i32Vec, kF32);
961   Value exp2KValue = exp2I32(builder, k);
962 
963   // exp(x) = exp(y) * exp2(k)
964   expY = mul(expY, exp2KValue);
965 
966   // Handle overflow, inf and underflow of exp(x). exp(x) range is [0, inf], its
967   // partitioned as the following:
968   // exp(x) = 0, x <= -inf
969   // exp(x) = underflow (min_float), x <= -88
970   // exp(x) = inf (min_float), x >= 88
971   // Note: |k| = 127 is the value where the 8-bits exponent saturates.
972   Value zerof32Const = bcast(f32Cst(builder, 0));
973   auto constPosInfinity =
974       bcast(f32Cst(builder, std::numeric_limits<float>::infinity()));
975   auto constNegIfinity =
976       bcast(f32Cst(builder, -std::numeric_limits<float>::infinity()));
977   auto underflow = bcast(f32Cst(builder, std::numeric_limits<float>::min()));
978 
979   Value kMaxConst = bcast(i32Cst(builder, 127));
980   Value kMaxNegConst = bcast(i32Cst(builder, -127));
981   Value rightBound =
982       builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, k, kMaxConst);
983   Value leftBound =
984       builder.create<arith::CmpIOp>(arith::CmpIPredicate::sge, k, kMaxNegConst);
985 
986   Value isNegInfinityX = builder.create<arith::CmpFOp>(
987       arith::CmpFPredicate::OEQ, x, constNegIfinity);
988   Value isPosInfinityX = builder.create<arith::CmpFOp>(
989       arith::CmpFPredicate::OEQ, x, constPosInfinity);
990   Value isPostiveX =
991       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zerof32Const);
992   Value isComputable = builder.create<arith::AndIOp>(rightBound, leftBound);
993 
994   expY = builder.create<arith::SelectOp>(
995       isNan, x,
996       builder.create<arith::SelectOp>(
997           isNegInfinityX, zerof32Const,
998           builder.create<arith::SelectOp>(
999               isPosInfinityX, constPosInfinity,
1000               builder.create<arith::SelectOp>(
1001                   isComputable, expY,
1002                   builder.create<arith::SelectOp>(isPostiveX, constPosInfinity,
1003                                                   underflow)))));
1004 
1005   rewriter.replaceOp(op, expY);
1006 
1007   return success();
1008 }
1009 
1010 //----------------------------------------------------------------------------//
1011 // ExpM1 approximation.
1012 //----------------------------------------------------------------------------//
1013 
1014 namespace {
1015 
1016 struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
1017 public:
1018   using OpRewritePattern::OpRewritePattern;
1019 
1020   LogicalResult matchAndRewrite(math::ExpM1Op op,
1021                                 PatternRewriter &rewriter) const final;
1022 };
1023 } // namespace
1024 
1025 LogicalResult
matchAndRewrite(math::ExpM1Op op,PatternRewriter & rewriter) const1026 ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
1027                                     PatternRewriter &rewriter) const {
1028   if (!getElementTypeOrSelf(op.getOperand()).isF32())
1029     return rewriter.notifyMatchFailure(op, "unsupported operand type");
1030 
1031   ArrayRef<int64_t> shape = vectorShape(op.getOperand());
1032 
1033   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1034   auto bcast = [&](Value value) -> Value {
1035     return broadcast(builder, value, shape);
1036   };
1037 
1038   // expm1(x) = exp(x) - 1 = u - 1.
1039   // We have to handle it carefully when x is near 0, i.e. u ~= 1,
1040   // and when the input is ~= -inf, i.e. u - 1 ~= -1.
1041   Value cstOne = bcast(f32Cst(builder, 1.0f));
1042   Value cstNegOne = bcast(f32Cst(builder, -1.0f));
1043   Value x = op.getOperand();
1044   Value u = builder.create<math::ExpOp>(x);
1045   Value uEqOneOrNaN =
1046       builder.create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
1047   Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne);
1048   Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>(
1049       arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
1050   // logU = log(u) ~= x
1051   Value logU = builder.create<math::LogOp>(u);
1052 
1053   // Detect exp(x) = +inf; written this way to avoid having to form +inf.
1054   Value isInf =
1055       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u);
1056 
1057   // (u - 1) * (x / ~x)
1058   Value expm1 = builder.create<arith::MulFOp>(
1059       uMinusOne, builder.create<arith::DivFOp>(x, logU));
1060   expm1 = builder.create<arith::SelectOp>(isInf, u, expm1);
1061   Value approximation = builder.create<arith::SelectOp>(
1062       uEqOneOrNaN, x,
1063       builder.create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
1064   rewriter.replaceOp(op, approximation);
1065   return success();
1066 }
1067 
1068 //----------------------------------------------------------------------------//
1069 // Sin and Cos approximation.
1070 //----------------------------------------------------------------------------//
1071 
1072 namespace {
1073 
1074 template <bool isSine, typename OpTy>
1075 struct SinAndCosApproximation : public OpRewritePattern<OpTy> {
1076 public:
1077   using OpRewritePattern<OpTy>::OpRewritePattern;
1078 
1079   LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
1080 };
1081 } // namespace
1082 
// 2/pi and pi/2 as typed long double constants (instead of macros). They are
// narrowed to float where they are used.
constexpr long double TWO_OVER_PI =
    0.6366197723675813430755350534900574481378385829618257949906693762L;
constexpr long double PI_OVER_2 =
    1.5707963267948966192313216916397514420985846996875529104874722961L;
1087 
// Approximates sin(x) or cos(x) by finding the best approximation polynomial in
// the reduced range [0, pi/2] for both sin(x) and cos(x). Then given y in the
// reduced range sin(x) will be computed as sin(y), -sin(y), cos(y) or -cos(y).
//
// Range reduction: k = floor(x * 2/pi), y = x - k * pi/2. The quadrant
// q = k mod 4 selects the result:
//   q | sin(x)  | cos(x)
//   0 |  sin(y) |  cos(y)
//   1 |  cos(y) | -sin(y)
//   2 | -sin(y) | -cos(y)
//   3 | -cos(y) |  sin(y)
// NOTE(review): the reduction is done entirely in f32. pi/2 is a single float
// constant and k is converted to i32 with FPToSI, so accuracy degrades as |x|
// grows. Confirm the intended input range.
template <bool isSine, typename OpTy>
LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
    OpTy op, PatternRewriter &rewriter) const {
  static_assert(
      llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
      "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");

  // Only f32 (scalar or vector) operands are expanded.
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };
  auto mul = [&](Value a, Value b) -> Value {
    return builder.create<arith::MulFOp>(a, b);
  };
  auto sub = [&](Value a, Value b) -> Value {
    return builder.create<arith::SubFOp>(a, b);
  };
  auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };

  // i32 type with the operand's shape, used for the quadrant index.
  auto i32Vec = broadcast(builder.getI32Type(), shape);
  auto fPToSingedInteger = [&](Value a) -> Value {
    return builder.create<arith::FPToSIOp>(i32Vec, a);
  };

  // k & 3 equals k mod 4 in [0, 3], also for negative k (two's complement).
  auto modulo4 = [&](Value a) -> Value {
    return builder.create<arith::AndIOp>(a, bcast(i32Cst(builder, 3)));
  };

  auto isEqualTo = [&](Value a, Value b) -> Value {
    return builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b);
  };

  auto isGreaterThan = [&](Value a, Value b) -> Value {
    return builder.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b);
  };

  auto select = [&](Value cond, Value t, Value f) -> Value {
    return builder.create<arith::SelectOp>(cond, t, f);
  };

  auto fmla = [&](Value a, Value b, Value c) {
    return builder.create<math::FmaOp>(a, b, c);
  };

  auto bitwiseOr = [&](Value a, Value b) {
    return builder.create<arith::OrIOp>(a, b);
  };

  Value twoOverPi = bcast(f32Cst(builder, (float)TWO_OVER_PI));
  Value piOverTwo = bcast(f32Cst(builder, (float)PI_OVER_2));

  Value x = op.getOperand();

  // Range reduction: k = floor(x * 2/pi), y = x - k * pi/2.
  Value k = floor(mul(x, twoOverPi));

  Value y = sub(x, mul(k, piOverTwo));

  Value cstOne = bcast(f32Cst(builder, 1.0));
  Value cstNegativeOne = bcast(f32Cst(builder, -1.0));

  // sin(y) ~= y * (1 + SC2 y^2 + SC4 y^4 + SC6 y^6 + SC8 y^8 + SC10 y^10).
  Value cstSC2 = bcast(f32Cst(builder, -0.16666667163372039794921875f));
  Value cstSC4 = bcast(f32Cst(builder, 8.333347737789154052734375e-3f));
  Value cstSC6 = bcast(f32Cst(builder, -1.9842604524455964565277099609375e-4f));
  Value cstSC8 =
      bcast(f32Cst(builder, 2.760012648650445044040679931640625e-6f));
  Value cstSC10 =
      bcast(f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f));

  // cos(y) ~= 1 + CC2 y^2 + CC4 y^4 + CC6 y^6 + CC8 y^8 + CC10 y^10.
  Value cstCC2 = bcast(f32Cst(builder, -0.5f));
  Value cstCC4 = bcast(f32Cst(builder, 4.166664183139801025390625e-2f));
  Value cstCC6 = bcast(f32Cst(builder, -1.388833043165504932403564453125e-3f));
  Value cstCC8 = bcast(f32Cst(builder, 2.47562347794882953166961669921875e-5f));
  Value cstCC10 =
      bcast(f32Cst(builder, -2.59630184018533327616751194000244140625e-7f));

  // Quadrant q = k mod 4 and one-hot masks for each value of q.
  Value kMod4 = modulo4(fPToSingedInteger(k));

  Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, 0)));
  Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, 1)));
  Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, 2)));
  Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, 3)));

  // Per the quadrant table in the header comment: whether the cosine
  // polynomial must be evaluated, and whether the result must be negated.
  Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
  Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(i32Cst(builder, 1)))
                               : bitwiseOr(kR1, kR2);

  Value y2 = mul(y, y);

  // Per-lane choice of base (1 for cosine, y for sine) and coefficients.
  Value base = select(sinuseCos, cstOne, y);
  Value cstC2 = select(sinuseCos, cstCC2, cstSC2);
  Value cstC4 = select(sinuseCos, cstCC4, cstSC4);
  Value cstC6 = select(sinuseCos, cstCC6, cstSC6);
  Value cstC8 = select(sinuseCos, cstCC8, cstSC8);
  Value cstC10 = select(sinuseCos, cstCC10, cstSC10);

  // Horner evaluation in y^2 of 1 + C2 y^2 + ... + C10 y^10, times the base.
  Value v1 = fmla(y2, cstC10, cstC8);
  Value v2 = fmla(y2, v1, cstC6);
  Value v3 = fmla(y2, v2, cstC4);
  Value v4 = fmla(y2, v3, cstC2);
  Value v5 = fmla(y2, v4, cstOne);
  Value v6 = mul(base, v5);

  Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);

  rewriter.replaceOp(op, approximation);

  return success();
}
1204 
1205 //----------------------------------------------------------------------------//
1206 // Rsqrt approximation.
1207 //----------------------------------------------------------------------------//
1208 
1209 namespace {
1210 struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> {
1211   using OpRewritePattern::OpRewritePattern;
1212 
1213   LogicalResult matchAndRewrite(math::RsqrtOp op,
1214                                 PatternRewriter &rewriter) const final;
1215 };
1216 } // namespace
1217 
1218 LogicalResult
matchAndRewrite(math::RsqrtOp op,PatternRewriter & rewriter) const1219 RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
1220                                     PatternRewriter &rewriter) const {
1221   if (!getElementTypeOrSelf(op.getOperand()).isF32())
1222     return rewriter.notifyMatchFailure(op, "unsupported operand type");
1223 
1224   ArrayRef<int64_t> shape = vectorShape(op.getOperand());
1225 
1226   // Only support already-vectorized rsqrt's.
1227   if (shape.empty() || shape.back() % 8 != 0)
1228     return rewriter.notifyMatchFailure(op, "unsupported operand type");
1229 
1230   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1231   auto bcast = [&](Value value) -> Value {
1232     return broadcast(builder, value, shape);
1233   };
1234 
1235   Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
1236   Value cstOnePointFive = bcast(f32Cst(builder, 1.5f));
1237   Value cstNegHalf = bcast(f32Cst(builder, -0.5f));
1238   Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
1239 
1240   Value negHalf = builder.create<arith::MulFOp>(op.getOperand(), cstNegHalf);
1241 
1242   // Select only the inverse sqrt of positive normals (denormals are
1243   // flushed to zero).
1244   Value ltMinMask = builder.create<arith::CmpFOp>(
1245       arith::CmpFPredicate::OLT, op.getOperand(), cstMinNormPos);
1246   Value infMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
1247                                                 op.getOperand(), cstPosInf);
1248   Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask);
1249 
1250   // Compute an approximate result.
1251   Value yApprox = handleMultidimensionalVectors(
1252       builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value {
1253         return builder.create<x86vector::RsqrtOp>(operands);
1254       });
1255 
1256   // Do a single step of Newton-Raphson iteration to improve the approximation.
1257   // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
1258   // It is essential to evaluate the inner term like this because forming
1259   // y_n^2 may over- or underflow.
1260   Value inner = builder.create<arith::MulFOp>(negHalf, yApprox);
1261   Value fma = builder.create<math::FmaOp>(yApprox, inner, cstOnePointFive);
1262   Value yNewton = builder.create<arith::MulFOp>(yApprox, fma);
1263 
1264   // Select the result of the Newton-Raphson step for positive normal arguments.
1265   // For other arguments, choose the output of the intrinsic. This will
1266   // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if
1267   // x is zero or a positive denormalized float (equivalent to flushing positive
1268   // denormalized inputs to zero).
1269   Value res =
1270       builder.create<arith::SelectOp>(notNormalFiniteMask, yApprox, yNewton);
1271   rewriter.replaceOp(op, res);
1272 
1273   return success();
1274 }
1275 
1276 //----------------------------------------------------------------------------//
1277 
populateMathPolynomialApproximationPatterns(RewritePatternSet & patterns,const MathPolynomialApproximationOptions & options)1278 void mlir::populateMathPolynomialApproximationPatterns(
1279     RewritePatternSet &patterns,
1280     const MathPolynomialApproximationOptions &options) {
1281   patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
1282                LogApproximation, Log2Approximation, Log1pApproximation,
1283                ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
1284                ReuseF32Expansion<math::Atan2Op>,
1285                SinAndCosApproximation<true, math::SinOp>,
1286                SinAndCosApproximation<false, math::CosOp>>(
1287       patterns.getContext());
1288   if (options.enableAvx2)
1289     patterns.add<RsqrtApproximation>(patterns.getContext());
1290 }
1291