1 //===- PolynomialApproximation.cpp - Approximate math operations ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements expansion of math operations to fast approximations
10 // that do not rely on any of the library functions.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include <climits>
15 #include <cstddef>
16
17 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
18 #include "mlir/Dialect/Math/IR/Math.h"
19 #include "mlir/Dialect/Math/Transforms/Approximation.h"
20 #include "mlir/Dialect/Math/Transforms/Passes.h"
21 #include "mlir/Dialect/Utils/IndexingUtils.h"
22 #include "mlir/Dialect/Vector/IR/VectorOps.h"
23 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
24 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
25 #include "mlir/IR/Builders.h"
26 #include "mlir/IR/BuiltinTypes.h"
27 #include "mlir/IR/ImplicitLocOpBuilder.h"
28 #include "mlir/IR/OpDefinition.h"
29 #include "mlir/IR/PatternMatch.h"
30 #include "mlir/IR/TypeUtilities.h"
31 #include "mlir/Transforms/DialectConversion.h"
32 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
33 #include "llvm/ADT/ArrayRef.h"
34 #include "llvm/ADT/STLExtras.h"
35
36 using namespace mlir;
37 using namespace mlir::math;
38 using namespace mlir::vector;
39
40 // Returns vector shape if the type is a vector. Returns an empty shape if it is
41 // not a vector.
vectorShape(Type type)42 static ArrayRef<int64_t> vectorShape(Type type) {
43 auto vectorType = type.dyn_cast<VectorType>();
44 return vectorType ? vectorType.getShape() : ArrayRef<int64_t>();
45 }
46
vectorShape(Value value)47 static ArrayRef<int64_t> vectorShape(Value value) {
48 return vectorShape(value.getType());
49 }
50
51 //----------------------------------------------------------------------------//
52 // Broadcast scalar types and values into vector types and values.
53 //----------------------------------------------------------------------------//
54
55 // Broadcasts scalar type into vector type (iff shape is non-scalar).
broadcast(Type type,ArrayRef<int64_t> shape)56 static Type broadcast(Type type, ArrayRef<int64_t> shape) {
57 assert(!type.isa<VectorType>() && "must be scalar type");
58 return !shape.empty() ? VectorType::get(shape, type) : type;
59 }
60
61 // Broadcasts scalar value into vector (iff shape is non-scalar).
broadcast(ImplicitLocOpBuilder & builder,Value value,ArrayRef<int64_t> shape)62 static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
63 ArrayRef<int64_t> shape) {
64 assert(!value.getType().isa<VectorType>() && "must be scalar value");
65 auto type = broadcast(value.getType(), shape);
66 return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value;
67 }
68
69 //----------------------------------------------------------------------------//
70 // Helper function to handle n-D vectors with 1-D operations.
71 //----------------------------------------------------------------------------//
72
// Expands and unrolls n-D vector operands into multiple fixed size 1-D vectors
// and calls the compute function with 1-D vector operands. Stitches back all
// results into the original n-D vector result.
//
// Examples: vectorWidth = 8
// - vector<4x8xf32> unrolled 4 times
// - vector<16xf32> expanded to vector<2x8xf32> and unrolled 2 times
// - vector<4x16xf32> expanded to vector<4x2x8xf32> and unrolled 4*2 times
//
// Some math approximations rely on ISA-specific operations that only accept
// fixed size 1-D vectors (e.g. AVX expects vectors of width 8).
//
// It is the caller's responsibility to verify that the inner dimension is
// divisible by the vectorWidth, and that all operands have the same vector
// shape.
//
// Parameters:
//   builder     - builder used to create every intermediate op.
//   operands    - non-empty list of vector-typed operands. The shape of the
//                 first operand drives the unrolling.
//   vectorWidth - width of the 1-D slices passed to `compute` (must be > 0).
//   compute     - callback invoked once per 1-D slice with the matching slice
//                 of every operand. Each returned value is inserted as one
//                 row of the expanded result, so it is expected to be a 1-D
//                 vector of `vectorWidth` elements.
// Returns a vector with the input shape whose element type is taken from the
// values returned by `compute`.
static Value
handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
                              ValueRange operands, int64_t vectorWidth,
                              llvm::function_ref<Value(ValueRange)> compute) {
  assert(!operands.empty() && "operands must be not empty");
  assert(vectorWidth > 0 && "vector width must be larger than 0");

  VectorType inputType = operands[0].getType().cast<VectorType>();
  ArrayRef<int64_t> inputShape = inputType.getShape();

  // If input shape matches target vector width, we can just call the
  // user-provided compute function with the operands.
  if (inputShape == llvm::makeArrayRef(vectorWidth))
    return compute(operands);

  // Check if the inner dimension has to be expanded, or we can directly iterate
  // over the outer dimensions of the vector.
  int64_t innerDim = inputShape.back();
  int64_t expansionDim = innerDim / vectorWidth;
  assert((innerDim % vectorWidth == 0) && "invalid inner dimension size");

  // Maybe expand operands to the higher rank vector shape that we'll use to
  // iterate over and extract one dimensional vectors.
  SmallVector<int64_t> expandedShape(inputShape.begin(), inputShape.end());
  SmallVector<Value> expandedOperands(operands);

  if (expansionDim > 1) {
    // Expand shape from [..., innerDim] to [..., expansionDim, vectorWidth].
    expandedShape.insert(expandedShape.end() - 1, expansionDim);
    expandedShape.back() = vectorWidth;

    // Each operand keeps its own element type; only the shape is expanded.
    for (unsigned i = 0; i < operands.size(); ++i) {
      auto operand = operands[i];
      auto eltType = operand.getType().cast<VectorType>().getElementType();
      auto expandedType = VectorType::get(expandedShape, eltType);
      expandedOperands[i] =
          builder.create<vector::ShapeCastOp>(expandedType, operand);
    }
  }

  // Iterate over all outer dimensions of the compute shape vector type.
  // Every dimension except the innermost one is unrolled.
  auto iterationDims = ArrayRef<int64_t>(expandedShape).drop_back();
  int64_t maxLinearIndex = computeMaxLinearIndex(iterationDims);

  // Strides used to map a linear index back to per-dimension offsets.
  SmallVector<int64_t> ones(iterationDims.size(), 1);
  auto strides = computeStrides(iterationDims, ones);

  // Compute results for each one dimensional vector.
  SmallVector<Value> results(maxLinearIndex);

  for (int64_t i = 0; i < maxLinearIndex; ++i) {
    auto offsets = delinearize(strides, i);

    // Extract the i-th 1-D slice of every operand.
    SmallVector<Value> extracted(expandedOperands.size());
    for (const auto &tuple : llvm::enumerate(expandedOperands))
      extracted[tuple.index()] =
          builder.create<vector::ExtractOp>(tuple.value(), offsets);

    results[i] = compute(extracted);
  }

  // Stitch results together into one large vector.
  // The zero constant is only an insertion base: the loop below overwrites
  // every 1-D row of it.
  Type resultEltType = results[0].getType().cast<VectorType>().getElementType();
  Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
  Value result = builder.create<arith::ConstantOp>(
      resultExpandedType, builder.getZeroAttr(resultExpandedType));

  for (int64_t i = 0; i < maxLinearIndex; ++i)
    result = builder.create<vector::InsertOp>(results[i], result,
                                              delinearize(strides, i));

  // Reshape back to the original vector shape. When no expansion happened,
  // this shape_cast is between identical shapes.
  return builder.create<vector::ShapeCastOp>(
      VectorType::get(inputShape, resultEltType), result);
}
163
164 //----------------------------------------------------------------------------//
165 // Helper functions to create constants.
166 //----------------------------------------------------------------------------//
167
f32Cst(ImplicitLocOpBuilder & builder,float value)168 static Value f32Cst(ImplicitLocOpBuilder &builder, float value) {
169 return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value));
170 }
171
i32Cst(ImplicitLocOpBuilder & builder,int32_t value)172 static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) {
173 return builder.create<arith::ConstantOp>(builder.getI32IntegerAttr(value));
174 }
175
f32FromBits(ImplicitLocOpBuilder & builder,uint32_t bits)176 static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) {
177 Value i32Value = i32Cst(builder, static_cast<int32_t>(bits));
178 return builder.create<arith::BitcastOp>(builder.getF32Type(), i32Value);
179 }
180
181 //----------------------------------------------------------------------------//
182 // Helper functions to build math functions approximations.
183 //----------------------------------------------------------------------------//
184
185 // Return the minimum of the two values or NaN if value is NaN
min(ImplicitLocOpBuilder & builder,Value value,Value bound)186 static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound) {
187 return builder.create<arith::SelectOp>(
188 builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, value, bound),
189 value, bound);
190 }
191
192 // Return the maximum of the two values or NaN if value is NaN
max(ImplicitLocOpBuilder & builder,Value value,Value bound)193 static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound) {
194 return builder.create<arith::SelectOp>(
195 builder.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, value, bound),
196 value, bound);
197 }
198
199 // Return the clamped value or NaN if value is NaN
clamp(ImplicitLocOpBuilder & builder,Value value,Value lowerBound,Value upperBound)200 static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
201 Value upperBound) {
202 return max(builder, min(builder, value, upperBound), lowerBound);
203 }
204
// Decomposes given floating point value `arg` into a normalized fraction and
// an integral power of two (see std::frexp). Returned values have float type.
//
// Implemented with bit manipulation on the IEEE-754 f32 encoding:
//   - fraction: the exponent field of `arg` is replaced by the exponent field
//     of 0.5, keeping the sign and mantissa bits. For normal inputs this gives
//     a magnitude in [0.5, 1).
//   - exponent: the biased exponent field of |arg| minus 126, returned as f32.
// Subnormal inputs are not handled specially: only the raw exponent field is
// read.
//
// Parameters:
//   arg        - f32 scalar or vector of f32.
//   isPositive - set when the caller knows `arg` has a clear sign bit. This
//                skips the math.abs used before extracting the exponent.
// Returns {normalizedFraction, exponent}, both with the same type as `arg`.
static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
                                     bool isPositive = false) {
  assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
  ArrayRef<int64_t> shape = vectorShape(arg);

  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  auto i32 = builder.getIntegerType(32);
  auto i32Vec = broadcast(i32, shape);
  auto f32Vec = broadcast(builder.getF32Type(), shape);

  Value cst126f = f32Cst(builder, 126.0f);
  Value cstHalf = f32Cst(builder, 0.5f);
  // Bit pattern ~0x7f800000: clears the 8 exponent bits and keeps the sign
  // and 23 mantissa bits.
  Value cstInvMantMask = f32FromBits(builder, ~0x7f800000u);

  // Bitcast to i32 for bitwise operations.
  Value i32Half = builder.create<arith::BitcastOp>(i32, cstHalf);
  Value i32InvMantMask = builder.create<arith::BitcastOp>(i32, cstInvMantMask);
  Value i32Arg = builder.create<arith::BitcastOp>(i32Vec, arg);

  // Compute normalized fraction.
  // Clear the exponent, then OR in the exponent bits of 0.5 (0.5 has no
  // mantissa bits set), yielding sign * 1.mantissa * 2^-1.
  Value tmp0 = builder.create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask));
  Value tmp1 = builder.create<arith::OrIOp>(tmp0, bcast(i32Half));
  Value normalizedFraction = builder.create<arith::BitcastOp>(f32Vec, tmp1);

  // Compute exponent.
  // The sign bit must be clear so that shifting right by 23 (the mantissa
  // width) leaves only the biased exponent field.
  Value arg0 = isPositive ? arg : builder.create<math::AbsOp>(arg);
  Value biasedExponentBits = builder.create<arith::ShRUIOp>(
      builder.create<arith::BitcastOp>(i32Vec, arg0),
      bcast(i32Cst(builder, 23)));
  Value biasedExponent =
      builder.create<arith::SIToFPOp>(f32Vec, biasedExponentBits);
  // Subtract 126 (bias 127 minus 1) because the fraction is in [0.5, 1)
  // rather than [1, 2).
  Value exponent =
      builder.create<arith::SubFOp>(biasedExponent, bcast(cst126f));

  return {normalizedFraction, exponent};
}
246
247 // Computes exp2 for an i32 argument.
exp2I32(ImplicitLocOpBuilder & builder,Value arg)248 static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
249 assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type");
250 ArrayRef<int64_t> shape = vectorShape(arg);
251
252 auto bcast = [&](Value value) -> Value {
253 return broadcast(builder, value, shape);
254 };
255
256 auto f32Vec = broadcast(builder.getF32Type(), shape);
257 // The exponent of f32 located at 23-bit.
258 auto exponetBitLocation = bcast(i32Cst(builder, 23));
259 // Set the exponent bias to zero.
260 auto bias = bcast(i32Cst(builder, 127));
261
262 Value biasedArg = builder.create<arith::AddIOp>(arg, bias);
263 Value exp2ValueInt =
264 builder.create<arith::ShLIOp>(biasedArg, exponetBitLocation);
265 Value exp2ValueF32 = builder.create<arith::BitcastOp>(f32Vec, exp2ValueInt);
266
267 return exp2ValueF32;
268 }
269
270 namespace {
makePolynomialCalculation(ImplicitLocOpBuilder & builder,llvm::ArrayRef<Value> coeffs,Value x)271 Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
272 llvm::ArrayRef<Value> coeffs, Value x) {
273 assert(getElementTypeOrSelf(x).isF32() && "x must be f32 type");
274 ArrayRef<int64_t> shape = vectorShape(x);
275
276 if (coeffs.empty())
277 return broadcast(builder, f32Cst(builder, 0.0f), shape);
278
279 if (coeffs.size() == 1)
280 return coeffs[0];
281
282 Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1],
283 coeffs[coeffs.size() - 2]);
284 for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
285 res = builder.create<math::FmaOp>(x, res, coeffs[i]);
286 }
287 return res;
288 }
289 } // namespace
290
291 //----------------------------------------------------------------------------//
292 // Helper function/pattern to insert casts for reusing F32 bit expansion.
293 //----------------------------------------------------------------------------//
294
295 template <typename T>
insertCasts(Operation * op,PatternRewriter & rewriter)296 LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) {
297 // Conservatively only allow where the operand and result types are exactly 1.
298 Type origType = op->getResultTypes().front();
299 for (Type t : llvm::drop_begin(op->getResultTypes()))
300 if (origType != t)
301 return rewriter.notifyMatchFailure(op, "required all types to match");
302 for (Type t : op->getOperandTypes())
303 if (origType != t)
304 return rewriter.notifyMatchFailure(op, "required all types to match");
305
306 // Skip if already F32 or larger than 32 bits.
307 if (getElementTypeOrSelf(origType).isF32() ||
308 getElementTypeOrSelf(origType).getIntOrFloatBitWidth() > 32)
309 return failure();
310
311 // Create F32 equivalent type.
312 Type newType;
313 if (auto shaped = origType.dyn_cast<ShapedType>()) {
314 newType = shaped.clone(rewriter.getF32Type());
315 } else if (origType.isa<FloatType>()) {
316 newType = rewriter.getF32Type();
317 } else {
318 return rewriter.notifyMatchFailure(op,
319 "unable to find F32 equivalent type");
320 }
321
322 Location loc = op->getLoc();
323 SmallVector<Value> operands;
324 for (auto operand : op->getOperands())
325 operands.push_back(rewriter.create<arith::ExtFOp>(loc, newType, operand));
326 auto result = rewriter.create<math::Atan2Op>(loc, newType, operands);
327 rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, origType, result);
328 return success();
329 }
330
331 namespace {
332 // Pattern to cast to F32 to reuse F32 expansion as fallback for single-result
333 // op.
334 // TODO: Consider revising to avoid adding multiple casts for a subgraph that is
335 // all in lower precision. Currently this is only fallback support and performs
336 // simplistic casting.
337 template <typename T>
338 struct ReuseF32Expansion : public OpRewritePattern<T> {
339 public:
340 using OpRewritePattern<T>::OpRewritePattern;
matchAndRewrite__anond27a14240411::ReuseF32Expansion341 LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const final {
342 static_assert(
343 T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(),
344 "requires same operands and result types");
345 return insertCasts<T>(op, rewriter);
346 }
347 };
348 } // namespace
349
350 //----------------------------------------------------------------------------//
351 // AtanOp approximation.
352 //----------------------------------------------------------------------------//
353
354 namespace {
355 struct AtanApproximation : public OpRewritePattern<math::AtanOp> {
356 public:
357 using OpRewritePattern::OpRewritePattern;
358
359 LogicalResult matchAndRewrite(math::AtanOp op,
360 PatternRewriter &rewriter) const final;
361 };
362 } // namespace
363
364 LogicalResult
matchAndRewrite(math::AtanOp op,PatternRewriter & rewriter) const365 AtanApproximation::matchAndRewrite(math::AtanOp op,
366 PatternRewriter &rewriter) const {
367 auto operand = op.getOperand();
368 if (!getElementTypeOrSelf(operand).isF32())
369 return rewriter.notifyMatchFailure(op, "unsupported operand type");
370
371 ArrayRef<int64_t> shape = vectorShape(op.getOperand());
372
373 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
374 auto one = broadcast(builder, f32Cst(builder, 1.0f), shape);
375
376 // Remap the problem over [0.0, 1.0] by looking at the absolute value and the
377 // handling symmetry.
378 Value abs = builder.create<math::AbsOp>(operand);
379 Value reciprocal = builder.create<arith::DivFOp>(one, abs);
380 Value compare =
381 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, abs, reciprocal);
382 Value x = builder.create<arith::SelectOp>(compare, abs, reciprocal);
383
384 // Perform the Taylor series approximation for atan over the range
385 // [-1.0, 1.0].
386 auto n1 = broadcast(builder, f32Cst(builder, 0.14418283f), shape);
387 auto n2 = broadcast(builder, f32Cst(builder, -0.34999234f), shape);
388 auto n3 = broadcast(builder, f32Cst(builder, -0.01067831f), shape);
389 auto n4 = broadcast(builder, f32Cst(builder, 1.00209986f), shape);
390
391 Value p = builder.create<math::FmaOp>(x, n1, n2);
392 p = builder.create<math::FmaOp>(x, p, n3);
393 p = builder.create<math::FmaOp>(x, p, n4);
394 p = builder.create<arith::MulFOp>(x, p);
395
396 // Remap the solution for over [0.0, 1.0] to [0.0, inf]
397 auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape);
398 Value sub = builder.create<arith::SubFOp>(halfPi, p);
399 Value select = builder.create<arith::SelectOp>(compare, p, sub);
400
401 // Correct for signing of the input.
402 rewriter.replaceOpWithNewOp<math::CopySignOp>(op, select, operand);
403 return success();
404 }
405
406 //----------------------------------------------------------------------------//
// Atan2Op approximation.
408 //----------------------------------------------------------------------------//
409
410 namespace {
411 struct Atan2Approximation : public OpRewritePattern<math::Atan2Op> {
412 public:
413 using OpRewritePattern::OpRewritePattern;
414
415 LogicalResult matchAndRewrite(math::Atan2Op op,
416 PatternRewriter &rewriter) const final;
417 };
418 } // namespace
419
420 LogicalResult
matchAndRewrite(math::Atan2Op op,PatternRewriter & rewriter) const421 Atan2Approximation::matchAndRewrite(math::Atan2Op op,
422 PatternRewriter &rewriter) const {
423 auto y = op.getOperand(0);
424 auto x = op.getOperand(1);
425 if (!getElementTypeOrSelf(x).isF32())
426 return rewriter.notifyMatchFailure(op, "unsupported operand type");
427
428 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
429 ArrayRef<int64_t> shape = vectorShape(op.getResult());
430
431 // Compute atan in the valid range.
432 auto div = builder.create<arith::DivFOp>(y, x);
433 auto atan = builder.create<math::AtanOp>(div);
434
435 // Determine what the atan would be for a 180 degree rotation.
436 auto zero = broadcast(builder, f32Cst(builder, 0.0f), shape);
437 auto pi = broadcast(builder, f32Cst(builder, 3.14159265359f), shape);
438 auto addPi = builder.create<arith::AddFOp>(atan, pi);
439 auto subPi = builder.create<arith::SubFOp>(atan, pi);
440 auto atanGt =
441 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, atan, zero);
442 auto flippedAtan = builder.create<arith::SelectOp>(atanGt, subPi, addPi);
443
444 // Determine whether to directly use atan or use the 180 degree flip
445 auto xGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zero);
446 Value result = builder.create<arith::SelectOp>(xGt, atan, flippedAtan);
447
448 // Handle x = 0, y > 0
449 Value xZero =
450 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, x, zero);
451 Value yGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, y, zero);
452 Value isHalfPi = builder.create<arith::AndIOp>(xZero, yGt);
453 auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape);
454 result = builder.create<arith::SelectOp>(isHalfPi, halfPi, result);
455
456 // Handle x = 0, y < 0
457 Value yLt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, y, zero);
458 Value isNegativeHalfPiPi = builder.create<arith::AndIOp>(xZero, yLt);
459 auto negativeHalfPiPi =
460 broadcast(builder, f32Cst(builder, -1.57079632679f), shape);
461 result = builder.create<arith::SelectOp>(isNegativeHalfPiPi, negativeHalfPiPi,
462 result);
463
464 // Handle x = 0, y = 0;
465 Value yZero =
466 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, y, zero);
467 Value isNan = builder.create<arith::AndIOp>(xZero, yZero);
468 Value cstNan = broadcast(builder, f32FromBits(builder, 0x7fc00000), shape);
469 result = builder.create<arith::SelectOp>(isNan, cstNan, result);
470
471 rewriter.replaceOp(op, result);
472 return success();
473 }
474
475 //----------------------------------------------------------------------------//
476 // TanhOp approximation.
477 //----------------------------------------------------------------------------//
478
namespace {
// Expands math.tanh on f32 (scalar or vector) operands into a rational
// polynomial approximation. Other element types fail to match.
struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::TanhOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Approximates tanh(x) as x * P(x^2) / Q(x^2), where P has degree 6 and Q has
// degree 3 in x^2. The monomial coefficients are alpha* and beta* below.
// The input is first clamped to [-7.99881172180175781, 7.99881172180175781].
// Inputs with |x| < 0.0004 are returned unchanged.
LogicalResult
TanhApproximation::matchAndRewrite(math::TanhOp op,
                                   PatternRewriter &rewriter) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // Clamp operand into [minusClamp, plusClamp] range.
  Value minusClamp = bcast(f32Cst(builder, -7.99881172180175781f));
  Value plusClamp = bcast(f32Cst(builder, 7.99881172180175781f));
  Value x = clamp(builder, op.getOperand(), minusClamp, plusClamp);

  // Mask for tiny values that are approximated with `operand`.
  // The final select returns the clamped `x`, which equals the operand for
  // such small magnitudes.
  Value tiny = bcast(f32Cst(builder, 0.0004f));
  Value tinyMask = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::OLT, builder.create<math::AbsOp>(op.getOperand()),
      tiny);

  // The monomial coefficients of the numerator polynomial (odd).
  Value alpha1 = bcast(f32Cst(builder, 4.89352455891786e-03f));
  Value alpha3 = bcast(f32Cst(builder, 6.37261928875436e-04f));
  Value alpha5 = bcast(f32Cst(builder, 1.48572235717979e-05f));
  Value alpha7 = bcast(f32Cst(builder, 5.12229709037114e-08f));
  Value alpha9 = bcast(f32Cst(builder, -8.60467152213735e-11f));
  Value alpha11 = bcast(f32Cst(builder, 2.00018790482477e-13f));
  Value alpha13 = bcast(f32Cst(builder, -2.76076847742355e-16f));

  // The monomial coefficients of the denominator polynomial (even).
  Value beta0 = bcast(f32Cst(builder, 4.89352518554385e-03f));
  Value beta2 = bcast(f32Cst(builder, 2.26843463243900e-03f));
  Value beta4 = bcast(f32Cst(builder, 1.18534705686654e-04f));
  Value beta6 = bcast(f32Cst(builder, 1.19825839466702e-06f));

  // Since the polynomials are odd/even, we need x^2.
  Value x2 = builder.create<arith::MulFOp>(x, x);

  // Evaluate the numerator polynomial p.
  // Horner in x^2 from alpha13 down to alpha1, then one multiply by x.
  Value p = builder.create<math::FmaOp>(x2, alpha13, alpha11);
  p = builder.create<math::FmaOp>(x2, p, alpha9);
  p = builder.create<math::FmaOp>(x2, p, alpha7);
  p = builder.create<math::FmaOp>(x2, p, alpha5);
  p = builder.create<math::FmaOp>(x2, p, alpha3);
  p = builder.create<math::FmaOp>(x2, p, alpha1);
  p = builder.create<arith::MulFOp>(x, p);

  // Evaluate the denominator polynomial q.
  Value q = builder.create<math::FmaOp>(x2, beta6, beta4);
  q = builder.create<math::FmaOp>(x2, q, beta2);
  q = builder.create<math::FmaOp>(x2, q, beta0);

  // Divide the numerator by the denominator.
  Value res = builder.create<arith::SelectOp>(
      tinyMask, x, builder.create<arith::DivFOp>(p, q));

  rewriter.replaceOp(op, res);

  return success();
}
553
// ln(2) and log2(e) as long double literals. The uses visible in the log
// expansion narrow them with static_cast<float> before building constants.
#define LN2_VALUE \
  0.693147180559945309417232121458176568075500134360255254120680009493393621L
#define LOG2E_VALUE \
  1.442695040888963407359924681001892137426645954152985934135449406931109219L
558
559 //----------------------------------------------------------------------------//
560 // LogOp and Log2Op approximation.
561 //----------------------------------------------------------------------------//
562
namespace {
// Shared implementation of the logarithm expansions. Concrete patterns call
// logMatchAndRewrite with the desired base.
template <typename Op>
struct LogApproximationBase : public OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;

  /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise.
  LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter,
                                   bool base2) const;
};
} // namespace

// This approximation comes from Julien Pommier's SSE math library.
// Link: http://gruntthepeon.free.fr/ssemath
//
// Only f32 (scalar or vector) operands are handled; other element types fail
// to match. Outline:
//   1. Clamp the input from below to the smallest positive normal float and
//      split it with frexp into a fraction in [0.5, 1) and an exponent e.
//   2. Shift the fraction to [sqrt(1/2), sqrt(2)), subtract 1 and adjust e,
//      so that log(input) = log(1 + x) + e * ln(2) with x close to 0.
//   3. Evaluate log(1 + x) ~= x - x^2 / 2 + x^3 * P(x), where P is the
//      degree-8 Cephes polynomial.
//   4. Natural log: add e * ln(2). Base 2: compute log(1 + x) * log2(e) + e.
//   5. Patch special inputs: +-0 -> -inf, negative or NaN -> NaN,
//      +inf -> +inf.
template <typename Op>
LogicalResult
LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
                                             bool base2) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  Value cstZero = bcast(f32Cst(builder, 0.0f));
  Value cstOne = bcast(f32Cst(builder, 1.0f));
  Value cstNegHalf = bcast(f32Cst(builder, -0.5f));

  // The smallest non denormalized float number.
  Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
  // -inf, +inf and a quiet NaN, used to patch special inputs at the end.
  Value cstMinusInf = bcast(f32FromBits(builder, 0xff800000u));
  Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
  Value cstNan = bcast(f32FromBits(builder, 0x7fc00000));

  // Polynomial coefficients.
  Value cstCephesSQRTHF = bcast(f32Cst(builder, 0.707106781186547524f));
  Value cstCephesLogP0 = bcast(f32Cst(builder, 7.0376836292E-2f));
  Value cstCephesLogP1 = bcast(f32Cst(builder, -1.1514610310E-1f));
  Value cstCephesLogP2 = bcast(f32Cst(builder, 1.1676998740E-1f));
  Value cstCephesLogP3 = bcast(f32Cst(builder, -1.2420140846E-1f));
  Value cstCephesLogP4 = bcast(f32Cst(builder, +1.4249322787E-1f));
  Value cstCephesLogP5 = bcast(f32Cst(builder, -1.6668057665E-1f));
  Value cstCephesLogP6 = bcast(f32Cst(builder, +2.0000714765E-1f));
  Value cstCephesLogP7 = bcast(f32Cst(builder, -2.4999993993E-1f));
  Value cstCephesLogP8 = bcast(f32Cst(builder, +3.3333331174E-1f));

  Value x = op.getOperand();

  // Truncate input values to the minimum positive normal.
  // Zero, negative and subnormal inputs all become the minimum normal here.
  // Zero and negative inputs are overridden by the masks at the end;
  // subnormal inputs therefore yield log(minimum normal). NaN passes through.
  x = max(builder, x, cstMinNormPos);

  // Extract significand in the range [0.5,1) and exponent.
  std::pair<Value, Value> pair = frexp(builder, x, /*isPositive=*/true);
  x = pair.first;
  Value e = pair.second;

  // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift
  // by -1.0. The values are then centered around 0, which improves the
  // stability of the polynomial evaluation:
  //
  //   if( x < SQRTHF ) {
  //     e -= 1;
  //     x = x + x - 1.0;
  //   } else { x = x - 1.0; }
  Value mask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x,
                                             cstCephesSQRTHF);
  Value tmp = builder.create<arith::SelectOp>(mask, x, cstZero);

  x = builder.create<arith::SubFOp>(x, cstOne);
  e = builder.create<arith::SubFOp>(
      e, builder.create<arith::SelectOp>(mask, cstOne, cstZero));
  x = builder.create<arith::AddFOp>(x, tmp);

  Value x2 = builder.create<arith::MulFOp>(x, x);
  Value x3 = builder.create<arith::MulFOp>(x2, x);

  // Evaluate the polynomial approximant of degree 8 in three parts.
  // y0, y1 and y2 are three interleaved degree-2 Horner chains over
  // P0..P2, P3..P5 and P6..P8. They are combined as
  //   P(x) = (y0 * x^3 + y1) * x^3 + y2
  // and the result is multiplied by x^3.
  Value y0, y1, y2;
  y0 = builder.create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1);
  y1 = builder.create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4);
  y2 = builder.create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7);
  y0 = builder.create<math::FmaOp>(y0, x, cstCephesLogP2);
  y1 = builder.create<math::FmaOp>(y1, x, cstCephesLogP5);
  y2 = builder.create<math::FmaOp>(y2, x, cstCephesLogP8);
  y0 = builder.create<math::FmaOp>(y0, x3, y1);
  y0 = builder.create<math::FmaOp>(y0, x3, y2);
  y0 = builder.create<arith::MulFOp>(y0, x3);

  // log(1 + x) ~= x - 0.5 * x^2 + x^3 * P(x).
  y0 = builder.create<math::FmaOp>(cstNegHalf, x2, y0);
  x = builder.create<arith::AddFOp>(x, y0);

  if (base2) {
    // log2(input) = log(1 + x) * log2(e) + e.
    Value cstLog2e = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
    x = builder.create<math::FmaOp>(x, cstLog2e, e);
  } else {
    // log(input) = e * ln(2) + log(1 + x).
    Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
    x = builder.create<math::FmaOp>(e, cstLn2, x);
  }

  // Masks are computed on the original operand, not the clamped value.
  // ULT is also true for NaN operands, so NaN maps to NaN below.
  Value invalidMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
                                                    op.getOperand(), cstZero);
  Value zeroMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
                                                 op.getOperand(), cstZero);
  Value posInfMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
                                                   op.getOperand(), cstPosInf);

  // Filter out invalid values:
  //  • x == 0     -> -INF
  //  • x < 0      ->  NAN
  //  • x == +INF  -> +INF
  Value aproximation = builder.create<arith::SelectOp>(
      zeroMask, cstMinusInf,
      builder.create<arith::SelectOp>(
          invalidMask, cstNan,
          builder.create<arith::SelectOp>(posInfMask, cstPosInf, x)));

  rewriter.replaceOp(op, aproximation);

  return success();
}
686
687 namespace {
688 struct LogApproximation : public LogApproximationBase<math::LogOp> {
689 using LogApproximationBase::LogApproximationBase;
690
matchAndRewrite__anond27a14240b11::LogApproximation691 LogicalResult matchAndRewrite(math::LogOp op,
692 PatternRewriter &rewriter) const final {
693 return logMatchAndRewrite(op, rewriter, /*base2=*/false);
694 }
695 };
696 } // namespace
697
698 namespace {
699 struct Log2Approximation : public LogApproximationBase<math::Log2Op> {
700 using LogApproximationBase::LogApproximationBase;
701
matchAndRewrite__anond27a14240c11::Log2Approximation702 LogicalResult matchAndRewrite(math::Log2Op op,
703 PatternRewriter &rewriter) const final {
704 return logMatchAndRewrite(op, rewriter, /*base2=*/true);
705 }
706 };
707 } // namespace
708
709 //----------------------------------------------------------------------------//
710 // Log1p approximation.
711 //----------------------------------------------------------------------------//
712
713 namespace {
714 struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> {
715 public:
716 using OpRewritePattern::OpRewritePattern;
717
718 LogicalResult matchAndRewrite(math::Log1pOp op,
719 PatternRewriter &rewriter) const final;
720 };
721 } // namespace
722
// Approximates log(1 + x) for f32 operands (scalar or vector of f32).
//
// Uses W. Kahan's formulation, which compensates for the rounding error made
// when forming u = x + 1. It falls back to returning x when u rounds to
// exactly 1 (tiny x) or when u is +inf.
LogicalResult
Log1pApproximation::matchAndRewrite(math::Log1pOp op,
                                    PatternRewriter &rewriter) const {
  // Only f32 element types are handled; anything else is left untouched.
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  // Empty for scalar operands; used to splat scalar constants to the operand
  // shape.
  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // Approximate log(1+x) using the following, due to W. Kahan:
  //   u = x + 1.0;
  //   if (u == 1.0 || u == inf) return x;
  //   return x * log(u) / (u - 1.0);
  //          ^^^^^^^^^^^^^^^^^^^^^^
  //             "logLarge" below.
  Value cstOne = bcast(f32Cst(builder, 1.0f));
  Value x = op.getOperand();
  Value u = builder.create<arith::AddFOp>(x, cstOne);
  Value uSmall =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
  // Emitted as math.log; it may in turn be expanded by another pattern.
  Value logU = builder.create<math::LogOp>(u);
  // log(u) == u has no finite solution, so this is true only for u == +inf
  // (log(+inf) == +inf). It detects u == inf without materializing +inf.
  Value uInf =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU);
  // x * (log(u) / (u - 1)): the ratio corrects for the rounding error in u.
  Value logLarge = builder.create<arith::MulFOp>(
      x, builder.create<arith::DivFOp>(
             logU, builder.create<arith::SubFOp>(u, cstOne)));
  Value approximation = builder.create<arith::SelectOp>(
      builder.create<arith::OrIOp>(uSmall, uInf), x, logLarge);
  rewriter.replaceOp(op, approximation);
  return success();
}
759
760 //----------------------------------------------------------------------------//
761 // Erf approximation.
762 //----------------------------------------------------------------------------//
763
// Approximates erf(x) for f32 operands (scalar or vector of f32) with
//   erf(x) ~= offset + P(|x|) / Q(|x|)
// where P and Q are polynomials of degree 4. The coefficients and offset are
// chosen per interval of |x|:
//   |x| in [0, 0.8)    : pp[0]/qq[0], offset 0
//   |x| in [0.8, 2.0)  : pp[1]/qq[1], offset 0
//   |x| in [2.0, 3.75) : pp[2]/qq[2], offset 1
//   |x| >= 3.75        : result is exactly 1
// Negative inputs use the odd symmetry erf(x) = -erf(-x).
// The approximation error is ~2.5e-07.
// Boost's minimax tool that utilizes the Remez method was used to find the
// coefficients.
LogicalResult
ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
                                            PatternRewriter &rewriter) const {
  // Only f32 element types are handled; anything else is left untouched.
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  // Empty for scalar operands; used to splat scalar constants to the operand
  // shape.
  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  const int intervalsCount = 3;
  const int polyDegree = 4;

  Value zero = bcast(f32Cst(builder, 0));
  Value one = bcast(f32Cst(builder, 1));
  // Numerator coefficients: pp[interval][power of x].
  Value pp[intervalsCount][polyDegree + 1];
  pp[0][0] = bcast(f32Cst(builder, +0.00000000000000000e+00f));
  pp[0][1] = bcast(f32Cst(builder, +1.12837916222975858e+00f));
  pp[0][2] = bcast(f32Cst(builder, -5.23018562988006470e-01f));
  pp[0][3] = bcast(f32Cst(builder, +2.09741709609267072e-01f));
  pp[0][4] = bcast(f32Cst(builder, +2.58146801602987875e-02f));
  pp[1][0] = bcast(f32Cst(builder, +0.00000000000000000e+00f));
  pp[1][1] = bcast(f32Cst(builder, +1.12750687816789140e+00f));
  pp[1][2] = bcast(f32Cst(builder, -3.64721408487825775e-01f));
  pp[1][3] = bcast(f32Cst(builder, +1.18407396425136952e-01f));
  pp[1][4] = bcast(f32Cst(builder, +3.70645533056476558e-02f));
  pp[2][0] = bcast(f32Cst(builder, -3.30093071049483172e-03f));
  pp[2][1] = bcast(f32Cst(builder, +3.51961938357697011e-03f));
  pp[2][2] = bcast(f32Cst(builder, -1.41373622814988039e-03f));
  pp[2][3] = bcast(f32Cst(builder, +2.53447094961941348e-04f));
  pp[2][4] = bcast(f32Cst(builder, -1.71048029455037401e-05f));

  // Denominator coefficients: qq[interval][power of x].
  Value qq[intervalsCount][polyDegree + 1];
  qq[0][0] = bcast(f32Cst(builder, +1.000000000000000000e+00f));
  qq[0][1] = bcast(f32Cst(builder, -4.635138185962547255e-01f));
  qq[0][2] = bcast(f32Cst(builder, +5.192301327279782447e-01f));
  qq[0][3] = bcast(f32Cst(builder, -1.318089722204810087e-01f));
  qq[0][4] = bcast(f32Cst(builder, +7.397964654672315005e-02f));
  qq[1][0] = bcast(f32Cst(builder, +1.00000000000000000e+00f));
  qq[1][1] = bcast(f32Cst(builder, -3.27607011824493086e-01f));
  qq[1][2] = bcast(f32Cst(builder, +4.48369090658821977e-01f));
  qq[1][3] = bcast(f32Cst(builder, -8.83462621207857930e-02f));
  qq[1][4] = bcast(f32Cst(builder, +5.72442770283176093e-02f));
  qq[2][0] = bcast(f32Cst(builder, +1.00000000000000000e+00f));
  qq[2][1] = bcast(f32Cst(builder, -2.06069165953913769e+00f));
  qq[2][2] = bcast(f32Cst(builder, +1.62705939945477759e+00f));
  qq[2][3] = bcast(f32Cst(builder, -5.83389859211130017e-01f));
  qq[2][4] = bcast(f32Cst(builder, +8.21908939856640930e-02f));

  // Constant term added to P/Q in each interval.
  Value offsets[intervalsCount];
  offsets[0] = bcast(f32Cst(builder, 0.0f));
  offsets[1] = bcast(f32Cst(builder, 0.0f));
  offsets[2] = bcast(f32Cst(builder, 1.0f));

  // Exclusive upper bound of |x| for each interval.
  Value bounds[intervalsCount];
  bounds[0] = bcast(f32Cst(builder, 0.8f));
  bounds[1] = bcast(f32Cst(builder, 2.0f));
  bounds[2] = bcast(f32Cst(builder, 3.75f));

  // x = |operand|; the sign is reapplied at the end.
  Value isNegativeArg = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT,
                                                      op.getOperand(), zero);
  Value negArg = builder.create<arith::NegFOp>(op.getOperand());
  Value x =
      builder.create<arith::SelectOp>(isNegativeArg, negArg, op.getOperand());

  // Start from interval 0's coefficients.
  Value offset = offsets[0];
  Value p[polyDegree + 1];
  Value q[polyDegree + 1];
  for (int i = 0; i <= polyDegree; ++i) {
    p[i] = pp[0][i];
    q[i] = qq[0][i];
  }

  // For each bound j, keep the current coefficients if x < bounds[j] and
  // otherwise switch to interval j + 1. Because the bounds increase, this
  // selects the first interval whose bound exceeds x.
  // TODO: maybe use vector stacking to reduce the number of selects.
  Value isLessThanBound[intervalsCount];
  for (int j = 0; j < intervalsCount - 1; ++j) {
    isLessThanBound[j] =
        builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[j]);
    for (int i = 0; i <= polyDegree; ++i) {
      p[i] = builder.create<arith::SelectOp>(isLessThanBound[j], p[i],
                                             pp[j + 1][i]);
      q[i] = builder.create<arith::SelectOp>(isLessThanBound[j], q[i],
                                             qq[j + 1][i]);
    }
    offset = builder.create<arith::SelectOp>(isLessThanBound[j], offset,
                                             offsets[j + 1]);
  }
  // Unordered (ULT) so that a NaN input keeps the polynomial path below
  // instead of being replaced by 1.
  isLessThanBound[intervalsCount - 1] = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);

  Value pPoly = makePolynomialCalculation(builder, p, x);
  Value qPoly = makePolynomialCalculation(builder, q, x);
  Value rationalPoly = builder.create<arith::DivFOp>(pPoly, qPoly);
  Value formula = builder.create<arith::AddFOp>(offset, rationalPoly);
  // Beyond the last bound, erf(|x|) is 1 to f32 precision per this scheme.
  formula = builder.create<arith::SelectOp>(isLessThanBound[intervalsCount - 1],
                                            formula, one);

  // erf is odd function: erf(x) = -erf(-x).
  Value negFormula = builder.create<arith::NegFOp>(formula);
  Value res =
      builder.create<arith::SelectOp>(isNegativeArg, negFormula, formula);

  rewriter.replaceOp(op, res);

  return success();
}
880
881 //----------------------------------------------------------------------------//
882 // Exp approximation.
883 //----------------------------------------------------------------------------//
884
885 namespace {
886
887 struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
888 public:
889 using OpRewritePattern::OpRewritePattern;
890
891 LogicalResult matchAndRewrite(math::ExpOp op,
892 PatternRewriter &rewriter) const final;
893 };
894 } // namespace
895
896 // Approximate exp(x) using its reduced range exp(y) where y is in the range
897 // [0, ln(2)], let y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2), exp(x)
898 // = exp(y) * 2^k. exp(y).
899 LogicalResult
matchAndRewrite(math::ExpOp op,PatternRewriter & rewriter) const900 ExpApproximation::matchAndRewrite(math::ExpOp op,
901 PatternRewriter &rewriter) const {
902 if (!getElementTypeOrSelf(op.getOperand()).isF32())
903 return rewriter.notifyMatchFailure(op, "unsupported operand type");
904
905 ArrayRef<int64_t> shape = vectorShape(op.getOperand());
906
907 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
908
909 // TODO: Consider a common pattern rewriter with all methods below to
910 // write the approximations.
911 auto bcast = [&](Value value) -> Value {
912 return broadcast(builder, value, shape);
913 };
914 auto fmla = [&](Value a, Value b, Value c) {
915 return builder.create<math::FmaOp>(a, b, c);
916 };
917 auto mul = [&](Value a, Value b) -> Value {
918 return builder.create<arith::MulFOp>(a, b);
919 };
920 auto sub = [&](Value a, Value b) -> Value {
921 return builder.create<arith::SubFOp>(a, b);
922 };
923 auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
924
925 Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
926 Value cstLog2E = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
927
928 // Polynomial coefficients.
929 Value cstCephesExpP0 = bcast(f32Cst(builder, 1.0));
930 Value cstCephesExpP1 = bcast(f32Cst(builder, 1.0));
931 Value cstCephesExpP2 = bcast(f32Cst(builder, 0.49970514590562437052f));
932 Value cstCephesExpP3 = bcast(f32Cst(builder, 0.16873890085469545053f));
933 Value cstCephesExpP4 = bcast(f32Cst(builder, 0.03668965196652099192f));
934 Value cstCephesExpP5 = bcast(f32Cst(builder, 0.01314350012789660196f));
935
936 Value x = op.getOperand();
937
938 Value isNan = builder.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, x, x);
939
940 // Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2)
941 Value xL2Inv = mul(x, cstLog2E);
942 Value kF32 = floor(xL2Inv);
943 Value kLn2 = mul(kF32, cstLn2);
944 Value y = sub(x, kLn2);
945
946 // Use Estrin's evaluation scheme with 3 independent parts:
947 // P(y)^y : (c0 + c1 y) + (c2 + c3 y) y^2 + (c4 + c5 y) y^4
948 Value y2 = mul(y, y);
949 Value y4 = mul(y2, y2);
950
951 Value q0 = fmla(cstCephesExpP1, y, cstCephesExpP0);
952 Value q1 = fmla(cstCephesExpP3, y, cstCephesExpP2);
953 Value q2 = fmla(cstCephesExpP5, y, cstCephesExpP4);
954 Value expY = fmla(q1, y2, q0);
955 expY = fmla(q2, y4, expY);
956
957 auto i32Vec = broadcast(builder.getI32Type(), shape);
958
959 // exp2(k)
960 Value k = builder.create<arith::FPToSIOp>(i32Vec, kF32);
961 Value exp2KValue = exp2I32(builder, k);
962
963 // exp(x) = exp(y) * exp2(k)
964 expY = mul(expY, exp2KValue);
965
966 // Handle overflow, inf and underflow of exp(x). exp(x) range is [0, inf], its
967 // partitioned as the following:
968 // exp(x) = 0, x <= -inf
969 // exp(x) = underflow (min_float), x <= -88
970 // exp(x) = inf (min_float), x >= 88
971 // Note: |k| = 127 is the value where the 8-bits exponent saturates.
972 Value zerof32Const = bcast(f32Cst(builder, 0));
973 auto constPosInfinity =
974 bcast(f32Cst(builder, std::numeric_limits<float>::infinity()));
975 auto constNegIfinity =
976 bcast(f32Cst(builder, -std::numeric_limits<float>::infinity()));
977 auto underflow = bcast(f32Cst(builder, std::numeric_limits<float>::min()));
978
979 Value kMaxConst = bcast(i32Cst(builder, 127));
980 Value kMaxNegConst = bcast(i32Cst(builder, -127));
981 Value rightBound =
982 builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, k, kMaxConst);
983 Value leftBound =
984 builder.create<arith::CmpIOp>(arith::CmpIPredicate::sge, k, kMaxNegConst);
985
986 Value isNegInfinityX = builder.create<arith::CmpFOp>(
987 arith::CmpFPredicate::OEQ, x, constNegIfinity);
988 Value isPosInfinityX = builder.create<arith::CmpFOp>(
989 arith::CmpFPredicate::OEQ, x, constPosInfinity);
990 Value isPostiveX =
991 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zerof32Const);
992 Value isComputable = builder.create<arith::AndIOp>(rightBound, leftBound);
993
994 expY = builder.create<arith::SelectOp>(
995 isNan, x,
996 builder.create<arith::SelectOp>(
997 isNegInfinityX, zerof32Const,
998 builder.create<arith::SelectOp>(
999 isPosInfinityX, constPosInfinity,
1000 builder.create<arith::SelectOp>(
1001 isComputable, expY,
1002 builder.create<arith::SelectOp>(isPostiveX, constPosInfinity,
1003 underflow)))));
1004
1005 rewriter.replaceOp(op, expY);
1006
1007 return success();
1008 }
1009
1010 //----------------------------------------------------------------------------//
1011 // ExpM1 approximation.
1012 //----------------------------------------------------------------------------//
1013
1014 namespace {
1015
1016 struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
1017 public:
1018 using OpRewritePattern::OpRewritePattern;
1019
1020 LogicalResult matchAndRewrite(math::ExpM1Op op,
1021 PatternRewriter &rewriter) const final;
1022 };
1023 } // namespace
1024
// Approximates expm1(x) = exp(x) - 1 for f32 operands (scalar or vector).
//
// Forms u = exp(x). It compensates for the cancellation in u - 1 by
// multiplying with x / log(u), where log(u) ~= x. Special cases, in priority
// order:
//   u == 1 or u is NaN  -> x    (tiny x, or NaN propagation)
//   u - 1 == -1         -> -1   (exp(x) underflowed, e.g. x near -inf)
//   u == +inf           -> +inf
LogicalResult
ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
                                    PatternRewriter &rewriter) const {
  // Only f32 element types are handled; anything else is left untouched.
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  // Empty for scalar operands; used to splat scalar constants to the operand
  // shape.
  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // expm1(x) = exp(x) - 1 = u - 1.
  // We have to handle it carefully when x is near 0, i.e. u ~= 1,
  // and when the input is ~= -inf, i.e. u - 1 ~= -1.
  Value cstOne = bcast(f32Cst(builder, 1.0f));
  Value cstNegOne = bcast(f32Cst(builder, -1.0f));
  Value x = op.getOperand();
  // Emitted as math.exp; it may in turn be expanded by another pattern.
  Value u = builder.create<math::ExpOp>(x);
  // UEQ: true when u == 1 or when either side is NaN.
  Value uEqOneOrNaN =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
  Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne);
  Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
  // logU = log(u) ~= x
  Value logU = builder.create<math::LogOp>(u);

  // Detect exp(x) = +inf; written this way to avoid having to form +inf.
  // (log(u) == u has no finite solution; log(+inf) == +inf.)
  Value isInf =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u);

  // (u - 1) * (x / ~x)
  Value expm1 = builder.create<arith::MulFOp>(
      uMinusOne, builder.create<arith::DivFOp>(x, logU));
  expm1 = builder.create<arith::SelectOp>(isInf, u, expm1);
  Value approximation = builder.create<arith::SelectOp>(
      uEqOneOrNaN, x,
      builder.create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
  rewriter.replaceOp(op, approximation);
  return success();
}
1067
1068 //----------------------------------------------------------------------------//
1069 // Sin and Cos approximation.
1070 //----------------------------------------------------------------------------//
1071
1072 namespace {
1073
1074 template <bool isSine, typename OpTy>
1075 struct SinAndCosApproximation : public OpRewritePattern<OpTy> {
1076 public:
1077 using OpRewritePattern<OpTy>::OpRewritePattern;
1078
1079 LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
1080 };
1081 } // namespace
1082
// 2/pi and pi/2 in long double precision. They are typed, scoped constants
// rather than object-like macros; users narrow them to the precision they
// need (e.g. float for IR constants).
static constexpr long double TWO_OVER_PI =
    0.6366197723675813430755350534900574481378385829618257949906693762L;
static constexpr long double PI_OVER_2 =
    1.5707963267948966192313216916397514420985846996875529104874722961L;
1087
// Approximates sin(x) or cos(x) for f32 operands (scalar or vector).
//
// Range reduction: k = floor(x * 2/pi) and y = x - k * pi/2, so y is
// (approximately) in [0, pi/2). Depending on k mod 4, the result is one of
// sin(y), cos(y), -sin(y) or -cos(y):
//   k mod 4 :   0        1        2        3
//   sin(x)  :  sin(y)   cos(y)  -sin(y)  -cos(y)
//   cos(x)  :  cos(y)  -sin(y)  -cos(y)   sin(y)
// sin(y) and cos(y) share one polynomial in y^2; only the coefficients and the
// leading factor (y for sin, 1 for cos) are selected.
//
// NOTE(review): k is converted with fptosi. For |x * 2/pi| beyond the i32
// range, or x = +/-inf/NaN, that conversion is poison under LLVM semantics.
// It feeds the selects below, so the result may be poison rather than NaN.
// Consider clamping k, or reducing it modulo 4 in floating point first.
template <bool isSine, typename OpTy>
LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
    OpTy op, PatternRewriter &rewriter) const {
  static_assert(
      llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
      "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");

  // Only f32 element types are handled; anything else is left untouched.
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  // Empty for scalar operands; used to splat scalar constants to the operand
  // shape.
  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };
  auto mul = [&](Value a, Value b) -> Value {
    return builder.create<arith::MulFOp>(a, b);
  };
  auto sub = [&](Value a, Value b) -> Value {
    return builder.create<arith::SubFOp>(a, b);
  };
  auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };

  auto i32Vec = broadcast(builder.getI32Type(), shape);
  // Float -> signed i32 conversion of the operand-shaped value.
  auto fPToSingedInteger = [&](Value a) -> Value {
    return builder.create<arith::FPToSIOp>(i32Vec, a);
  };

  // a & 3: a non-negative remainder mod 4, also for negative two's-complement
  // integers.
  auto modulo4 = [&](Value a) -> Value {
    return builder.create<arith::AndIOp>(a, bcast(i32Cst(builder, 3)));
  };

  auto isEqualTo = [&](Value a, Value b) -> Value {
    return builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b);
  };

  auto isGreaterThan = [&](Value a, Value b) -> Value {
    return builder.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b);
  };

  auto select = [&](Value cond, Value t, Value f) -> Value {
    return builder.create<arith::SelectOp>(cond, t, f);
  };

  // fmla(a, b, c) = a * b + c
  auto fmla = [&](Value a, Value b, Value c) {
    return builder.create<math::FmaOp>(a, b, c);
  };

  auto bitwiseOr = [&](Value a, Value b) {
    return builder.create<arith::OrIOp>(a, b);
  };

  Value twoOverPi = bcast(f32Cst(builder, (float)TWO_OVER_PI));
  Value piOverTwo = bcast(f32Cst(builder, (float)PI_OVER_2));

  Value x = op.getOperand();

  // Number of whole quarter periods below x.
  Value k = floor(mul(x, twoOverPi));

  // Remainder within the current quarter period.
  Value y = sub(x, mul(k, piOverTwo));

  Value cstOne = bcast(f32Cst(builder, 1.0));
  Value cstNegativeOne = bcast(f32Cst(builder, -1.0));

  // sin(y) ~= y * (1 + SC2 y^2 + SC4 y^4 + ... + SC10 y^10)
  Value cstSC2 = bcast(f32Cst(builder, -0.16666667163372039794921875f));
  Value cstSC4 = bcast(f32Cst(builder, 8.333347737789154052734375e-3f));
  Value cstSC6 = bcast(f32Cst(builder, -1.9842604524455964565277099609375e-4f));
  Value cstSC8 =
      bcast(f32Cst(builder, 2.760012648650445044040679931640625e-6f));
  Value cstSC10 =
      bcast(f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f));

  // cos(y) ~= 1 + CC2 y^2 + CC4 y^4 + ... + CC10 y^10
  Value cstCC2 = bcast(f32Cst(builder, -0.5f));
  Value cstCC4 = bcast(f32Cst(builder, 4.166664183139801025390625e-2f));
  Value cstCC6 = bcast(f32Cst(builder, -1.388833043165504932403564453125e-3f));
  Value cstCC8 = bcast(f32Cst(builder, 2.47562347794882953166961669921875e-5f));
  Value cstCC10 =
      bcast(f32Cst(builder, -2.59630184018533327616751194000244140625e-7f));

  Value kMod4 = modulo4(fPToSingedInteger(k));

  Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, 0)));
  Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, 1)));
  Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, 2)));
  Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, 3)));

  // Whether to evaluate the cosine polynomial (otherwise the sine one), and
  // whether to negate the result; see the k mod 4 table above.
  Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
  Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(i32Cst(builder, 1)))
                               : bitwiseOr(kR1, kR2);

  Value y2 = mul(y, y);

  // Leading factor: 1 for the cosine polynomial, y for the sine polynomial.
  Value base = select(sinuseCos, cstOne, y);
  Value cstC2 = select(sinuseCos, cstCC2, cstSC2);
  Value cstC4 = select(sinuseCos, cstCC4, cstSC4);
  Value cstC6 = select(sinuseCos, cstCC6, cstSC6);
  Value cstC8 = select(sinuseCos, cstCC8, cstSC8);
  Value cstC10 = select(sinuseCos, cstCC10, cstSC10);

  // Horner evaluation in y^2:
  //   1 + y2*(C2 + y2*(C4 + y2*(C6 + y2*(C8 + y2*C10))))
  Value v1 = fmla(y2, cstC10, cstC8);
  Value v2 = fmla(y2, v1, cstC6);
  Value v3 = fmla(y2, v2, cstC4);
  Value v4 = fmla(y2, v3, cstC2);
  Value v5 = fmla(y2, v4, cstOne);
  Value v6 = mul(base, v5);

  Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);

  rewriter.replaceOp(op, approximation);

  return success();
}
1204
1205 //----------------------------------------------------------------------------//
1206 // Rsqrt approximation.
1207 //----------------------------------------------------------------------------//
1208
1209 namespace {
1210 struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> {
1211 using OpRewritePattern::OpRewritePattern;
1212
1213 LogicalResult matchAndRewrite(math::RsqrtOp op,
1214 PatternRewriter &rewriter) const final;
1215 };
1216 } // namespace
1217
// Approximates rsqrt(x) = 1 / sqrt(x) for vectors of f32 whose innermost
// dimension is a multiple of 8.
//
// Takes the hardware estimate from x86vector.rsqrt and refines it with one
// Newton-Raphson step for positive normal finite inputs. All other inputs
// (zero, denormals, negatives, +inf) keep the raw estimate.
LogicalResult
RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
                                    PatternRewriter &rewriter) const {
  // Only f32 element types are handled; anything else is left untouched.
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  // Only support already-vectorized rsqrt's.
  // The innermost dimension must split evenly into 8-wide chunks for the
  // x86vector op (presumably the 8 x f32 AVX width).
  if (shape.empty() || shape.back() % 8 != 0)
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // 0x7f800000 is the f32 bit pattern of +inf; 0x00800000 is the smallest
  // positive normal f32 (FLT_MIN).
  Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
  Value cstOnePointFive = bcast(f32Cst(builder, 1.5f));
  Value cstNegHalf = bcast(f32Cst(builder, -0.5f));
  Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));

  // -0.5 * x, reused by the Newton-Raphson step below.
  Value negHalf = builder.create<arith::MulFOp>(op.getOperand(), cstNegHalf);

  // Select only the inverse sqrt of positive normals (denormals are
  // flushed to zero).
  Value ltMinMask = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::OLT, op.getOperand(), cstMinNormPos);
  Value infMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
                                                op.getOperand(), cstPosInf);
  Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask);

  // Compute an approximate result.
  // handleMultidimensionalVectors is assumed to apply the callback to
  // 8-element slices of the operand and reassemble the result.
  Value yApprox = handleMultidimensionalVectors(
      builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value {
        return builder.create<x86vector::RsqrtOp>(operands);
      });

  // Do a single step of Newton-Raphson iteration to improve the approximation.
  // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
  // It is essential to evaluate the inner term like this because forming
  // y_n^2 may over- or underflow.
  Value inner = builder.create<arith::MulFOp>(negHalf, yApprox);
  // fma(y, -0.5 * x * y, 1.5) = 1.5 - 0.5 * x * y^2
  Value fma = builder.create<math::FmaOp>(yApprox, inner, cstOnePointFive);
  Value yNewton = builder.create<arith::MulFOp>(yApprox, fma);

  // Select the result of the Newton-Raphson step for positive normal arguments.
  // For other arguments, choose the output of the intrinsic. This will
  // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if
  // x is zero or a positive denormalized float (equivalent to flushing positive
  // denormalized inputs to zero).
  Value res =
      builder.create<arith::SelectOp>(notNormalFiniteMask, yApprox, yNewton);
  rewriter.replaceOp(op, res);

  return success();
}
1275
1276 //----------------------------------------------------------------------------//
1277
populateMathPolynomialApproximationPatterns(RewritePatternSet & patterns,const MathPolynomialApproximationOptions & options)1278 void mlir::populateMathPolynomialApproximationPatterns(
1279 RewritePatternSet &patterns,
1280 const MathPolynomialApproximationOptions &options) {
1281 patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
1282 LogApproximation, Log2Approximation, Log1pApproximation,
1283 ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
1284 ReuseF32Expansion<math::Atan2Op>,
1285 SinAndCosApproximation<true, math::SinOp>,
1286 SinAndCosApproximation<false, math::CosOp>>(
1287 patterns.getContext());
1288 if (options.enableAvx2)
1289 patterns.add<RsqrtApproximation>(patterns.getContext());
1290 }
1291