1 //===- PolynomialApproximation.cpp - Approximate math operations ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements expansion of math operations to fast approximations 10 // that do not rely on any of the library functions. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include <climits> 15 #include <cstddef> 16 17 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 18 #include "mlir/Dialect/Math/IR/Math.h" 19 #include "mlir/Dialect/Math/Transforms/Approximation.h" 20 #include "mlir/Dialect/Math/Transforms/Passes.h" 21 #include "mlir/Dialect/Vector/VectorOps.h" 22 #include "mlir/Dialect/Vector/VectorUtils.h" 23 #include "mlir/Dialect/X86Vector/X86VectorDialect.h" 24 #include "mlir/IR/Builders.h" 25 #include "mlir/IR/ImplicitLocOpBuilder.h" 26 #include "mlir/IR/TypeUtilities.h" 27 #include "mlir/Transforms/Bufferize.h" 28 #include "mlir/Transforms/DialectConversion.h" 29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 30 #include "llvm/ADT/ArrayRef.h" 31 32 using namespace mlir; 33 using namespace mlir::math; 34 using namespace mlir::vector; 35 36 // Returns vector shape if the type is a vector. Returns an empty shape if it is 37 // not a vector. 38 static ArrayRef<int64_t> vectorShape(Type type) { 39 auto vectorType = type.dyn_cast<VectorType>(); 40 return vectorType ? vectorType.getShape() : ArrayRef<int64_t>(); 41 } 42 43 static ArrayRef<int64_t> vectorShape(Value value) { 44 return vectorShape(value.getType()); 45 } 46 47 //----------------------------------------------------------------------------// 48 // Broadcast scalar types and values into vector types and values. 
49 //----------------------------------------------------------------------------// 50 51 // Broadcasts scalar type into vector type (iff shape is non-scalar). 52 static Type broadcast(Type type, ArrayRef<int64_t> shape) { 53 assert(!type.isa<VectorType>() && "must be scalar type"); 54 return !shape.empty() ? VectorType::get(shape, type) : type; 55 } 56 57 // Broadcasts scalar value into vector (iff shape is non-scalar). 58 static Value broadcast(ImplicitLocOpBuilder &builder, Value value, 59 ArrayRef<int64_t> shape) { 60 assert(!value.getType().isa<VectorType>() && "must be scalar value"); 61 auto type = broadcast(value.getType(), shape); 62 return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value; 63 } 64 65 //----------------------------------------------------------------------------// 66 // Helper function to handle n-D vectors with 1-D operations. 67 //----------------------------------------------------------------------------// 68 69 // Expands and unrolls n-D vector operands into multiple fixed size 1-D vectors 70 // and calls the compute function with 1-D vector operands. Stitches back all 71 // results into the original n-D vector result. 72 // 73 // Examples: vectorWidth = 8 74 // - vector<4x8xf32> unrolled 4 times 75 // - vector<16xf32> expanded to vector<2x8xf32> and unrolled 2 times 76 // - vector<4x16xf32> expanded to vector<4x2x8xf32> and unrolled 4*2 times 77 // 78 // Some math approximations rely on ISA-specific operations that only accept 79 // fixed size 1-D vectors (e.g. AVX expects vectors of width 8). 80 // 81 // It is the caller's responsibility to verify that the inner dimension is 82 // divisible by the vectorWidth, and that all operands have the same vector 83 // shape. 
84 static Value 85 handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, 86 ValueRange operands, int64_t vectorWidth, 87 std::function<Value(ValueRange)> compute) { 88 assert(!operands.empty() && "operands must be not empty"); 89 assert(vectorWidth > 0 && "vector width must be larger than 0"); 90 91 VectorType inputType = operands[0].getType().cast<VectorType>(); 92 ArrayRef<int64_t> inputShape = inputType.getShape(); 93 94 // If input shape matches target vector width, we can just call the 95 // user-provided compute function with the operands. 96 if (inputShape == llvm::makeArrayRef(vectorWidth)) 97 return compute(operands); 98 99 // Check if the inner dimension has to be expanded, or we can directly iterate 100 // over the outer dimensions of the vector. 101 int64_t innerDim = inputShape.back(); 102 int64_t expansionDim = innerDim / vectorWidth; 103 assert((innerDim % vectorWidth == 0) && "invalid inner dimension size"); 104 105 // Maybe expand operands to the higher rank vector shape that we'll use to 106 // iterate over and extract one dimensional vectors. 107 SmallVector<int64_t> expandedShape(inputShape.begin(), inputShape.end()); 108 SmallVector<Value> expandedOperands(operands); 109 110 if (expansionDim > 1) { 111 // Expand shape from [..., innerDim] to [..., expansionDim, vectorWidth]. 112 expandedShape.insert(expandedShape.end() - 1, expansionDim); 113 expandedShape.back() = vectorWidth; 114 115 for (unsigned i = 0; i < operands.size(); ++i) { 116 auto operand = operands[i]; 117 auto eltType = operand.getType().cast<VectorType>().getElementType(); 118 auto expandedType = VectorType::get(expandedShape, eltType); 119 expandedOperands[i] = 120 builder.create<vector::ShapeCastOp>(expandedType, operand); 121 } 122 } 123 124 // Iterate over all outer dimensions of the compute shape vector type. 
125 auto iterationDims = ArrayRef<int64_t>(expandedShape).drop_back(); 126 int64_t maxLinearIndex = computeMaxLinearIndex(iterationDims); 127 128 SmallVector<int64_t> ones(iterationDims.size(), 1); 129 auto strides = computeStrides(iterationDims, ones); 130 131 // Compute results for each one dimensional vector. 132 SmallVector<Value> results(maxLinearIndex); 133 134 for (int64_t i = 0; i < maxLinearIndex; ++i) { 135 auto offsets = delinearize(strides, i); 136 137 SmallVector<Value> extracted(expandedOperands.size()); 138 for (auto tuple : llvm::enumerate(expandedOperands)) 139 extracted[tuple.index()] = 140 builder.create<vector::ExtractOp>(tuple.value(), offsets); 141 142 results[i] = compute(extracted); 143 } 144 145 // Stitch results together into one large vector. 146 Type resultEltType = results[0].getType().cast<VectorType>().getElementType(); 147 Type resultExpandedType = VectorType::get(expandedShape, resultEltType); 148 Value result = builder.create<ConstantOp>( 149 resultExpandedType, builder.getZeroAttr(resultExpandedType)); 150 151 for (int64_t i = 0; i < maxLinearIndex; ++i) 152 result = builder.create<vector::InsertOp>(results[i], result, 153 delinearize(strides, i)); 154 155 // Reshape back to the original vector shape. 156 return builder.create<vector::ShapeCastOp>( 157 VectorType::get(inputShape, resultEltType), result); 158 } 159 160 //----------------------------------------------------------------------------// 161 // Helper functions to create constants. 
162 //----------------------------------------------------------------------------// 163 164 static Value f32Cst(ImplicitLocOpBuilder &builder, float value) { 165 return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value)); 166 } 167 168 static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) { 169 return builder.create<arith::ConstantOp>(builder.getI32IntegerAttr(value)); 170 } 171 172 static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) { 173 Value i32Value = i32Cst(builder, static_cast<int32_t>(bits)); 174 return builder.create<arith::BitcastOp>(builder.getF32Type(), i32Value); 175 } 176 177 //----------------------------------------------------------------------------// 178 // Helper functions to build math functions approximations. 179 //----------------------------------------------------------------------------// 180 181 static Value min(ImplicitLocOpBuilder &builder, Value a, Value b) { 182 return builder.create<SelectOp>( 183 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, a, b), a, b); 184 } 185 186 static Value max(ImplicitLocOpBuilder &builder, Value a, Value b) { 187 return builder.create<SelectOp>( 188 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, a, b), a, b); 189 } 190 191 static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, 192 Value upperBound) { 193 return max(builder, min(builder, value, upperBound), lowerBound); 194 } 195 196 // Decomposes given floating point value `arg` into a normalized fraction and 197 // an integral power of two (see std::frexp). Returned values have float type. 
198 static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg, 199 bool is_positive = false) { 200 assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type"); 201 ArrayRef<int64_t> shape = vectorShape(arg); 202 203 auto bcast = [&](Value value) -> Value { 204 return broadcast(builder, value, shape); 205 }; 206 207 auto i32 = builder.getIntegerType(32); 208 auto i32Vec = broadcast(i32, shape); 209 auto f32Vec = broadcast(builder.getF32Type(), shape); 210 211 Value cst126f = f32Cst(builder, 126.0f); 212 Value cstHalf = f32Cst(builder, 0.5f); 213 Value cstInvMantMask = f32FromBits(builder, ~0x7f800000u); 214 215 // Bitcast to i32 for bitwise operations. 216 Value i32Half = builder.create<arith::BitcastOp>(i32, cstHalf); 217 Value i32InvMantMask = builder.create<arith::BitcastOp>(i32, cstInvMantMask); 218 Value i32Arg = builder.create<arith::BitcastOp>(i32Vec, arg); 219 220 // Compute normalized fraction. 221 Value tmp0 = builder.create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask)); 222 Value tmp1 = builder.create<arith::OrIOp>(tmp0, bcast(i32Half)); 223 Value normalizedFraction = builder.create<arith::BitcastOp>(f32Vec, tmp1); 224 225 // Compute exponent. 226 Value arg0 = is_positive ? arg : builder.create<math::AbsOp>(arg); 227 Value biasedExponentBits = builder.create<arith::ShRUIOp>( 228 builder.create<arith::BitcastOp>(i32Vec, arg0), 229 bcast(i32Cst(builder, 23))); 230 Value biasedExponent = 231 builder.create<arith::SIToFPOp>(f32Vec, biasedExponentBits); 232 Value exponent = 233 builder.create<arith::SubFOp>(biasedExponent, bcast(cst126f)); 234 235 return {normalizedFraction, exponent}; 236 } 237 238 // Computes exp2 for an i32 argument. 
239 static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) { 240 assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type"); 241 ArrayRef<int64_t> shape = vectorShape(arg); 242 243 auto bcast = [&](Value value) -> Value { 244 return broadcast(builder, value, shape); 245 }; 246 247 auto f32Vec = broadcast(builder.getF32Type(), shape); 248 // The exponent of f32 located at 23-bit. 249 auto exponetBitLocation = bcast(i32Cst(builder, 23)); 250 // Set the exponent bias to zero. 251 auto bias = bcast(i32Cst(builder, 127)); 252 253 Value biasedArg = builder.create<arith::AddIOp>(arg, bias); 254 Value exp2ValueInt = 255 builder.create<arith::ShLIOp>(biasedArg, exponetBitLocation); 256 Value exp2ValueF32 = builder.create<arith::BitcastOp>(f32Vec, exp2ValueInt); 257 258 return exp2ValueF32; 259 } 260 261 namespace { 262 Value makePolynomialCalculation(ImplicitLocOpBuilder &builder, 263 llvm::ArrayRef<Value> coeffs, Value x) { 264 assert(getElementTypeOrSelf(x).isF32() && "x must be f32 type"); 265 ArrayRef<int64_t> shape = vectorShape(x); 266 267 if (coeffs.empty()) 268 return broadcast(builder, f32Cst(builder, 0.0f), shape); 269 270 if (coeffs.size() == 1) 271 return coeffs[0]; 272 273 Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1], 274 coeffs[coeffs.size() - 2]); 275 for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) { 276 res = builder.create<math::FmaOp>(x, res, coeffs[i]); 277 } 278 return res; 279 } 280 } // namespace 281 282 //----------------------------------------------------------------------------// 283 // TanhOp approximation. 
//----------------------------------------------------------------------------//

namespace {
// Rewrites math.tanh on f32 scalars/vectors into a rational polynomial
// approximation.
struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::TanhOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Approximates tanh(x) as x * P(x^2) / Q(x^2) on a clamped input, and returns
// the (clamped) input itself for inputs of very small magnitude.
LogicalResult
TanhApproximation::matchAndRewrite(math::TanhOp op,
                                   PatternRewriter &rewriter) const {
  if (!getElementTypeOrSelf(op.operand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.operand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // Clamp operand into [minusClamp, plusClamp] range.
  Value minusClamp = bcast(f32Cst(builder, -7.99881172180175781f));
  Value plusClamp = bcast(f32Cst(builder, 7.99881172180175781f));
  Value x = clamp(builder, op.operand(), minusClamp, plusClamp);

  // Mask for tiny values (|operand| < 0.0004) that are approximated with the
  // operand itself.
  Value tiny = bcast(f32Cst(builder, 0.0004f));
  Value tinyMask = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::OLT, builder.create<math::AbsOp>(op.operand()),
      tiny);

  // The monomial coefficients of the numerator polynomial (odd).
  Value alpha1 = bcast(f32Cst(builder, 4.89352455891786e-03f));
  Value alpha3 = bcast(f32Cst(builder, 6.37261928875436e-04f));
  Value alpha5 = bcast(f32Cst(builder, 1.48572235717979e-05f));
  Value alpha7 = bcast(f32Cst(builder, 5.12229709037114e-08f));
  Value alpha9 = bcast(f32Cst(builder, -8.60467152213735e-11f));
  Value alpha11 = bcast(f32Cst(builder, 2.00018790482477e-13f));
  Value alpha13 = bcast(f32Cst(builder, -2.76076847742355e-16f));

  // The monomial coefficients of the denominator polynomial (even).
  Value beta0 = bcast(f32Cst(builder, 4.89352518554385e-03f));
  Value beta2 = bcast(f32Cst(builder, 2.26843463243900e-03f));
  Value beta4 = bcast(f32Cst(builder, 1.18534705686654e-04f));
  Value beta6 = bcast(f32Cst(builder, 1.19825839466702e-06f));

  // Since the polynomials are odd/even, we need x^2.
  Value x2 = builder.create<arith::MulFOp>(x, x);

  // Evaluate the numerator polynomial p: a Horner chain in x^2, followed by a
  // final multiplication by x to make the polynomial odd.
  Value p = builder.create<math::FmaOp>(x2, alpha13, alpha11);
  p = builder.create<math::FmaOp>(x2, p, alpha9);
  p = builder.create<math::FmaOp>(x2, p, alpha7);
  p = builder.create<math::FmaOp>(x2, p, alpha5);
  p = builder.create<math::FmaOp>(x2, p, alpha3);
  p = builder.create<math::FmaOp>(x2, p, alpha1);
  p = builder.create<arith::MulFOp>(x, p);

  // Evaluate the denominator polynomial q (Horner chain in x^2).
  Value q = builder.create<math::FmaOp>(x2, beta6, beta4);
  q = builder.create<math::FmaOp>(x2, q, beta2);
  q = builder.create<math::FmaOp>(x2, q, beta0);

  // Divide the numerator by the denominator, except for tiny inputs where the
  // clamped input is used directly.
  Value res = builder.create<SelectOp>(tinyMask, x,
                                       builder.create<arith::DivFOp>(p, q));

  rewriter.replaceOp(op, res);

  return success();
}

// ln(2) and log2(e) as long double literals; they are narrowed to float with
// a static_cast at each use.
#define LN2_VALUE                                                              \
  0.693147180559945309417232121458176568075500134360255254120680009493393621L
#define LOG2E_VALUE                                                            \
  1.442695040888963407359924681001892137426645954152985934135449406931109219L

//----------------------------------------------------------------------------//
// LogOp and Log2Op approximation.
//----------------------------------------------------------------------------//

namespace {
// Common implementation of the natural and base-2 logarithm approximations.
// `Op` is the operation being rewritten.
template <typename Op>
struct LogApproximationBase : public OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;

  /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise.
  LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter,
                                   bool base2) const;
};
} // namespace

// This approximation comes from Julien Pommier's SSE math library.
// Link: http://gruntthepeon.free.fr/ssemath
template <typename Op>
LogicalResult
LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
                                             bool base2) const {
  if (!getElementTypeOrSelf(op.operand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.operand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  Value cstZero = bcast(f32Cst(builder, 0.0f));
  Value cstOne = bcast(f32Cst(builder, 1.0f));
  Value cstNegHalf = bcast(f32Cst(builder, -0.5f));

  // Special values, given by their bit patterns: the smallest positive
  // normalized float, -inf, +inf and a NaN.
  Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
  Value cstMinusInf = bcast(f32FromBits(builder, 0xff800000u));
  Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
  Value cstNan = bcast(f32FromBits(builder, 0x7fc00000));

  // Polynomial coefficients.
  Value cstCephesSQRTHF = bcast(f32Cst(builder, 0.707106781186547524f));
  Value cstCephesLogP0 = bcast(f32Cst(builder, 7.0376836292E-2f));
  Value cstCephesLogP1 = bcast(f32Cst(builder, -1.1514610310E-1f));
  Value cstCephesLogP2 = bcast(f32Cst(builder, 1.1676998740E-1f));
  Value cstCephesLogP3 = bcast(f32Cst(builder, -1.2420140846E-1f));
  Value cstCephesLogP4 = bcast(f32Cst(builder, +1.4249322787E-1f));
  Value cstCephesLogP5 = bcast(f32Cst(builder, -1.6668057665E-1f));
  Value cstCephesLogP6 = bcast(f32Cst(builder, +2.0000714765E-1f));
  Value cstCephesLogP7 = bcast(f32Cst(builder, -2.4999993993E-1f));
  Value cstCephesLogP8 = bcast(f32Cst(builder, +3.3333331174E-1f));

  Value x = op.operand();

  // Clamp input values from below to the minimum positive normal. Zero,
  // negative and NaN inputs are patched up by the final selects, which look at
  // the original operand.
  x = max(builder, x, cstMinNormPos);

  // Extract significand in the range [0.5,1) and exponent.
  std::pair<Value, Value> pair = frexp(builder, x, /*is_positive=*/true);
  x = pair.first;
  Value e = pair.second;

  // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift
  // by -1.0. The values are then centered around 0, which improves the
  // stability of the polynomial evaluation:
  //
  //   if( x < SQRTHF ) {
  //     e -= 1;
  //     x = x + x - 1.0;
  //   } else { x = x - 1.0; }
  //
  // The branch is emitted as selects: `tmp` is x where the condition holds and
  // 0 elsewhere, so that `(x - 1.0) + tmp` covers both cases.
  Value mask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x,
                                             cstCephesSQRTHF);
  Value tmp = builder.create<SelectOp>(mask, x, cstZero);

  x = builder.create<arith::SubFOp>(x, cstOne);
  e = builder.create<arith::SubFOp>(
      e, builder.create<SelectOp>(mask, cstOne, cstZero));
  x = builder.create<arith::AddFOp>(x, tmp);

  Value x2 = builder.create<arith::MulFOp>(x, x);
  Value x3 = builder.create<arith::MulFOp>(x2, x);

  // Evaluate the polynomial approximant of degree 8 in three parts.
  Value y0, y1, y2;
  y0 = builder.create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1);
  y1 = builder.create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4);
  y2 = builder.create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7);
  y0 = builder.create<math::FmaOp>(y0, x, cstCephesLogP2);
  y1 = builder.create<math::FmaOp>(y1, x, cstCephesLogP5);
  y2 = builder.create<math::FmaOp>(y2, x, cstCephesLogP8);
  // Combine the three parts: ((y0 * x^3 + y1) * x^3 + y2) * x^3.
  y0 = builder.create<math::FmaOp>(y0, x3, y1);
  y0 = builder.create<math::FmaOp>(y0, x3, y2);
  y0 = builder.create<arith::MulFOp>(y0, x3);

  // Add the -0.5 * x^2 term and x itself.
  y0 = builder.create<math::FmaOp>(cstNegHalf, x2, y0);
  x = builder.create<arith::AddFOp>(x, y0);

  // Fold the exponent back in: x * log2(e) + e for base 2, and e * ln(2) + x
  // for the natural logarithm.
  if (base2) {
    Value cstLog2e = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
    x = builder.create<math::FmaOp>(x, cstLog2e, e);
  } else {
    Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
    x = builder.create<math::FmaOp>(e, cstLn2, x);
  }

  // `invalidMask` uses an unordered comparison, so it is also set for NaN
  // operands.
  Value invalidMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
                                                    op.operand(), cstZero);
  Value zeroMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
                                                 op.operand(), cstZero);
  Value posInfMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
                                                   op.operand(), cstPosInf);

  // Filter out invalid values:
  //  • x == 0     -> -INF
  //  • x < 0      -> NAN
  //  • x == +INF  -> +INF
  Value aproximation = builder.create<SelectOp>(
      zeroMask, cstMinusInf,
      builder.create<SelectOp>(
          invalidMask, cstNan,
          builder.create<SelectOp>(posInfMask, cstPosInf, x)));

  rewriter.replaceOp(op, aproximation);

  return success();
}

namespace {
// Natural logarithm: LogApproximationBase with base2 = false.
struct LogApproximation : public LogApproximationBase<math::LogOp> {
  using LogApproximationBase::LogApproximationBase;

  LogicalResult matchAndRewrite(math::LogOp op,
                                PatternRewriter &rewriter) const final {
    return logMatchAndRewrite(op, rewriter, /*base2=*/false);
  }
};
} // namespace

namespace {
// Base-2 logarithm: LogApproximationBase with base2 = true.
struct Log2Approximation : public LogApproximationBase<math::Log2Op> {
  using LogApproximationBase::LogApproximationBase;

  LogicalResult matchAndRewrite(math::Log2Op op,
                                PatternRewriter &rewriter) const final {
    return logMatchAndRewrite(op, rewriter, /*base2=*/true);
  }
};
} // namespace

//----------------------------------------------------------------------------//
// Log1p approximation.
//----------------------------------------------------------------------------//

namespace {
// Rewrites math.log1p on f32 scalars/vectors in terms of math.log.
struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::Log1pOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Approximate log(1+x).
LogicalResult
Log1pApproximation::matchAndRewrite(math::Log1pOp op,
                                    PatternRewriter &rewriter) const {
  if (!getElementTypeOrSelf(op.operand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.operand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // Approximate log(1+x) using the following, due to W. Kahan:
  //   u = x + 1.0;
  //   if (u == 1.0 || u == inf) return x;
  //   return x * log(u) / (u - 1.0);
  //          ^^^^^^^^^^^^^^^^^^^^^^
  //             "logLarge" below.
  Value cstOne = bcast(f32Cst(builder, 1.0f));
  Value x = op.operand();
  Value u = builder.create<arith::AddFOp>(x, cstOne);
  Value uSmall =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
  Value logU = builder.create<math::LogOp>(u);
  // The `u == inf` test of the pseudo-code is emitted as `u == log(u)`, which
  // avoids materializing an infinity constant.
  // NOTE(review): this relies on log(u) == u holding only for u == +inf.
  Value uInf =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU);
  Value logLarge = builder.create<arith::MulFOp>(
      x, builder.create<arith::DivFOp>(
             logU, builder.create<arith::SubFOp>(u, cstOne)));
  Value approximation = builder.create<SelectOp>(
      builder.create<arith::OrIOp>(uSmall, uInf), x, logLarge);
  rewriter.replaceOp(op, approximation);
  return success();
}

//----------------------------------------------------------------------------//
// Erf approximation.
//----------------------------------------------------------------------------//

// Approximates erf(x) with
// a + P(x)/Q(x)
// where `a` is a constant offset and P and Q are polynomials of degree 4.
// Different offsets and coefficients are chosen based on the value of x.
// The approximation error is ~2.5e-07.
// Boost's minimax tool that utilizes the Remez method was used to find the
// coefficients.
578 LogicalResult 579 ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op, 580 PatternRewriter &rewriter) const { 581 if (!getElementTypeOrSelf(op.operand()).isF32()) 582 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 583 584 ArrayRef<int64_t> shape = vectorShape(op.operand()); 585 586 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 587 auto bcast = [&](Value value) -> Value { 588 return broadcast(builder, value, shape); 589 }; 590 591 const int intervalsCount = 3; 592 const int polyDegree = 4; 593 594 Value zero = bcast(f32Cst(builder, 0)); 595 Value one = bcast(f32Cst(builder, 1)); 596 Value pp[intervalsCount][polyDegree + 1]; 597 pp[0][0] = bcast(f32Cst(builder, +0.00000000000000000e+00f)); 598 pp[0][1] = bcast(f32Cst(builder, +1.12837916222975858e+00f)); 599 pp[0][2] = bcast(f32Cst(builder, -5.23018562988006470e-01f)); 600 pp[0][3] = bcast(f32Cst(builder, +2.09741709609267072e-01f)); 601 pp[0][4] = bcast(f32Cst(builder, +2.58146801602987875e-02f)); 602 pp[1][0] = bcast(f32Cst(builder, +0.00000000000000000e+00f)); 603 pp[1][1] = bcast(f32Cst(builder, +1.12750687816789140e+00f)); 604 pp[1][2] = bcast(f32Cst(builder, -3.64721408487825775e-01f)); 605 pp[1][3] = bcast(f32Cst(builder, +1.18407396425136952e-01f)); 606 pp[1][4] = bcast(f32Cst(builder, +3.70645533056476558e-02f)); 607 pp[2][0] = bcast(f32Cst(builder, -3.30093071049483172e-03f)); 608 pp[2][1] = bcast(f32Cst(builder, +3.51961938357697011e-03f)); 609 pp[2][2] = bcast(f32Cst(builder, -1.41373622814988039e-03f)); 610 pp[2][3] = bcast(f32Cst(builder, +2.53447094961941348e-04f)); 611 pp[2][4] = bcast(f32Cst(builder, -1.71048029455037401e-05f)); 612 613 Value qq[intervalsCount][polyDegree + 1]; 614 qq[0][0] = bcast(f32Cst(builder, +1.000000000000000000e+00f)); 615 qq[0][1] = bcast(f32Cst(builder, -4.635138185962547255e-01f)); 616 qq[0][2] = bcast(f32Cst(builder, +5.192301327279782447e-01f)); 617 qq[0][3] = bcast(f32Cst(builder, -1.318089722204810087e-01f)); 618 qq[0][4] = 
bcast(f32Cst(builder, +7.397964654672315005e-02f)); 619 qq[1][0] = bcast(f32Cst(builder, +1.00000000000000000e+00f)); 620 qq[1][1] = bcast(f32Cst(builder, -3.27607011824493086e-01f)); 621 qq[1][2] = bcast(f32Cst(builder, +4.48369090658821977e-01f)); 622 qq[1][3] = bcast(f32Cst(builder, -8.83462621207857930e-02f)); 623 qq[1][4] = bcast(f32Cst(builder, +5.72442770283176093e-02f)); 624 qq[2][0] = bcast(f32Cst(builder, +1.00000000000000000e+00f)); 625 qq[2][1] = bcast(f32Cst(builder, -2.06069165953913769e+00f)); 626 qq[2][2] = bcast(f32Cst(builder, +1.62705939945477759e+00f)); 627 qq[2][3] = bcast(f32Cst(builder, -5.83389859211130017e-01f)); 628 qq[2][4] = bcast(f32Cst(builder, +8.21908939856640930e-02f)); 629 630 Value offsets[intervalsCount]; 631 offsets[0] = bcast(f32Cst(builder, 0.0f)); 632 offsets[1] = bcast(f32Cst(builder, 0.0f)); 633 offsets[2] = bcast(f32Cst(builder, 1.0f)); 634 635 Value bounds[intervalsCount]; 636 bounds[0] = bcast(f32Cst(builder, 0.8f)); 637 bounds[1] = bcast(f32Cst(builder, 2.0f)); 638 bounds[2] = bcast(f32Cst(builder, 3.75f)); 639 640 Value isNegativeArg = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, 641 op.operand(), zero); 642 Value negArg = builder.create<arith::NegFOp>(op.operand()); 643 Value x = builder.create<SelectOp>(isNegativeArg, negArg, op.operand()); 644 645 Value offset = offsets[0]; 646 Value p[polyDegree + 1]; 647 Value q[polyDegree + 1]; 648 for (int i = 0; i <= polyDegree; ++i) { 649 p[i] = pp[0][i]; 650 q[i] = qq[0][i]; 651 } 652 653 // TODO: maybe use vector stacking to reduce the number of selects. 
654 Value isLessThanBound[intervalsCount]; 655 for (int j = 0; j < intervalsCount - 1; ++j) { 656 isLessThanBound[j] = 657 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[j]); 658 for (int i = 0; i <= polyDegree; ++i) { 659 p[i] = builder.create<SelectOp>(isLessThanBound[j], p[i], pp[j + 1][i]); 660 q[i] = builder.create<SelectOp>(isLessThanBound[j], q[i], qq[j + 1][i]); 661 } 662 offset = 663 builder.create<SelectOp>(isLessThanBound[j], offset, offsets[j + 1]); 664 } 665 isLessThanBound[intervalsCount - 1] = builder.create<arith::CmpFOp>( 666 arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]); 667 668 Value pPoly = makePolynomialCalculation(builder, p, x); 669 Value qPoly = makePolynomialCalculation(builder, q, x); 670 Value rationalPoly = builder.create<arith::DivFOp>(pPoly, qPoly); 671 Value formula = builder.create<arith::AddFOp>(offset, rationalPoly); 672 formula = builder.create<SelectOp>(isLessThanBound[intervalsCount - 1], 673 formula, one); 674 675 // erf is odd function: erf(x) = -erf(-x). 676 Value negFormula = builder.create<arith::NegFOp>(formula); 677 Value res = builder.create<SelectOp>(isNegativeArg, negFormula, formula); 678 679 rewriter.replaceOp(op, res); 680 681 return success(); 682 } 683 684 //----------------------------------------------------------------------------// 685 // Exp approximation. 686 //----------------------------------------------------------------------------// 687 688 namespace { 689 690 struct ExpApproximation : public OpRewritePattern<math::ExpOp> { 691 public: 692 using OpRewritePattern::OpRewritePattern; 693 694 LogicalResult matchAndRewrite(math::ExpOp op, 695 PatternRewriter &rewriter) const final; 696 }; 697 } // namespace 698 699 // Approximate exp(x) using its reduced range exp(y) where y is in the range 700 // [0, ln(2)], let y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2), exp(x) 701 // = exp(y) * 2^k. exp(y). 
702 LogicalResult 703 ExpApproximation::matchAndRewrite(math::ExpOp op, 704 PatternRewriter &rewriter) const { 705 if (!getElementTypeOrSelf(op.operand()).isF32()) 706 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 707 708 ArrayRef<int64_t> shape = vectorShape(op.operand()); 709 710 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 711 712 // TODO: Consider a common pattern rewriter with all methods below to 713 // write the approximations. 714 auto bcast = [&](Value value) -> Value { 715 return broadcast(builder, value, shape); 716 }; 717 auto fmla = [&](Value a, Value b, Value c) { 718 return builder.create<math::FmaOp>(a, b, c); 719 }; 720 auto mul = [&](Value a, Value b) -> Value { 721 return builder.create<arith::MulFOp>(a, b); 722 }; 723 auto sub = [&](Value a, Value b) -> Value { 724 return builder.create<arith::SubFOp>(a, b); 725 }; 726 auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); }; 727 728 Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE))); 729 Value cstLog2E = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE))); 730 731 // Polynomial coefficients. 
732 Value cstCephesExpP0 = bcast(f32Cst(builder, 1.0)); 733 Value cstCephesExpP1 = bcast(f32Cst(builder, 1.0)); 734 Value cstCephesExpP2 = bcast(f32Cst(builder, 0.49970514590562437052f)); 735 Value cstCephesExpP3 = bcast(f32Cst(builder, 0.16873890085469545053f)); 736 Value cstCephesExpP4 = bcast(f32Cst(builder, 0.03668965196652099192f)); 737 Value cstCephesExpP5 = bcast(f32Cst(builder, 0.01314350012789660196f)); 738 739 Value x = op.operand(); 740 741 // Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2) 742 Value xL2Inv = mul(x, cstLog2E); 743 Value kF32 = floor(xL2Inv); 744 Value kLn2 = mul(kF32, cstLn2); 745 Value y = sub(x, kLn2); 746 747 // Use Estrin's evaluation scheme with 3 independent parts: 748 // P(y)^y : (c0 + c1 y) + (c2 + c3 y) y^2 + (c4 + c5 y) y^4 749 Value y2 = mul(y, y); 750 Value y4 = mul(y2, y2); 751 752 Value q0 = fmla(cstCephesExpP1, y, cstCephesExpP0); 753 Value q1 = fmla(cstCephesExpP3, y, cstCephesExpP2); 754 Value q2 = fmla(cstCephesExpP5, y, cstCephesExpP4); 755 Value expY = fmla(q1, y2, q0); 756 expY = fmla(q2, y4, expY); 757 758 auto i32Vec = broadcast(builder.getI32Type(), shape); 759 760 // exp2(k) 761 Value k = builder.create<arith::FPToSIOp>(kF32, i32Vec); 762 Value exp2KValue = exp2I32(builder, k); 763 764 // exp(x) = exp(y) * exp2(k) 765 expY = mul(expY, exp2KValue); 766 767 // Handle overflow, inf and underflow of exp(x). exp(x) range is [0, inf], its 768 // partitioned as the following: 769 // exp(x) = 0, x <= -inf 770 // exp(x) = underflow (min_float), x <= -88 771 // exp(x) = inf (min_float), x >= 88 772 // Note: |k| = 127 is the value where the 8-bits exponent saturates. 
773 Value zerof32Const = bcast(f32Cst(builder, 0)); 774 auto constPosInfinity = 775 bcast(f32Cst(builder, std::numeric_limits<float>::infinity())); 776 auto constNegIfinity = 777 bcast(f32Cst(builder, -std::numeric_limits<float>::infinity())); 778 auto underflow = bcast(f32Cst(builder, std::numeric_limits<float>::min())); 779 780 Value kMaxConst = bcast(i32Cst(builder, 127)); 781 Value kMaxNegConst = bcast(i32Cst(builder, -127)); 782 Value rightBound = 783 builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, k, kMaxConst); 784 Value leftBound = 785 builder.create<arith::CmpIOp>(arith::CmpIPredicate::sge, k, kMaxNegConst); 786 787 Value isNegInfinityX = builder.create<arith::CmpFOp>( 788 arith::CmpFPredicate::OEQ, x, constNegIfinity); 789 Value isPosInfinityX = builder.create<arith::CmpFOp>( 790 arith::CmpFPredicate::OEQ, x, constPosInfinity); 791 Value isPostiveX = 792 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zerof32Const); 793 Value isComputable = builder.create<arith::AndIOp>(rightBound, leftBound); 794 795 expY = builder.create<SelectOp>( 796 isNegInfinityX, zerof32Const, 797 builder.create<SelectOp>( 798 isPosInfinityX, constPosInfinity, 799 builder.create<SelectOp>(isComputable, expY, 800 builder.create<SelectOp>(isPostiveX, 801 constPosInfinity, 802 underflow)))); 803 804 rewriter.replaceOp(op, expY); 805 806 return success(); 807 } 808 809 //----------------------------------------------------------------------------// 810 // ExpM1 approximation. 
811 //----------------------------------------------------------------------------// 812 813 namespace { 814 815 struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> { 816 public: 817 using OpRewritePattern::OpRewritePattern; 818 819 LogicalResult matchAndRewrite(math::ExpM1Op op, 820 PatternRewriter &rewriter) const final; 821 }; 822 } // namespace 823 824 LogicalResult 825 ExpM1Approximation::matchAndRewrite(math::ExpM1Op op, 826 PatternRewriter &rewriter) const { 827 if (!getElementTypeOrSelf(op.operand()).isF32()) 828 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 829 830 ArrayRef<int64_t> shape = vectorShape(op.operand()); 831 832 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 833 auto bcast = [&](Value value) -> Value { 834 return broadcast(builder, value, shape); 835 }; 836 837 // expm1(x) = exp(x) - 1 = u - 1. 838 // We have to handle it carefully when x is near 0, i.e. u ~= 1, 839 // and when the input is ~= -inf, i.e. u - 1 ~= -1. 840 Value cstOne = bcast(f32Cst(builder, 1.0f)); 841 Value cstNegOne = bcast(f32Cst(builder, -1.0f)); 842 Value x = op.operand(); 843 Value u = builder.create<math::ExpOp>(x); 844 Value uEqOne = 845 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne); 846 Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne); 847 Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>( 848 arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne); 849 // logU = log(u) ~= x 850 Value logU = builder.create<math::LogOp>(u); 851 852 // Detect exp(x) = +inf; written this way to avoid having to form +inf. 
853 Value isInf = 854 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u); 855 856 // (u - 1) * (x / ~x) 857 Value expm1 = builder.create<arith::MulFOp>( 858 uMinusOne, builder.create<arith::DivFOp>(x, logU)); 859 expm1 = builder.create<SelectOp>(isInf, u, expm1); 860 Value approximation = builder.create<SelectOp>( 861 uEqOne, x, builder.create<SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1)); 862 rewriter.replaceOp(op, approximation); 863 return success(); 864 } 865 866 //----------------------------------------------------------------------------// 867 // Sin and Cos approximation. 868 //----------------------------------------------------------------------------// 869 870 namespace { 871 872 template <bool isSine, typename OpTy> 873 struct SinAndCosApproximation : public OpRewritePattern<OpTy> { 874 public: 875 using OpRewritePattern<OpTy>::OpRewritePattern; 876 877 LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final; 878 }; 879 } // namespace 880 881 #define TWO_OVER_PI \ 882 0.6366197723675813430755350534900574481378385829618257949906693762L 883 #define PI_OVER_2 \ 884 1.5707963267948966192313216916397514420985846996875529104874722961L 885 886 // Approximates sin(x) or cos(x) by finding the best approximation polynomial in 887 // the reduced range [0, pi/2] for both sin(x) and cos(x). Then given y in the 888 // reduced range sin(x) will be computed as sin(y), -sin(y), cos(y) or -cos(y). 
889 template <bool isSine, typename OpTy> 890 LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite( 891 OpTy op, PatternRewriter &rewriter) const { 892 static_assert( 893 llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value, 894 "SinAndCosApproximation pattern expects math::SinOp or math::CosOp"); 895 896 if (!getElementTypeOrSelf(op.operand()).isF32()) 897 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 898 899 ArrayRef<int64_t> shape = vectorShape(op.operand()); 900 901 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 902 auto bcast = [&](Value value) -> Value { 903 return broadcast(builder, value, shape); 904 }; 905 auto mul = [&](Value a, Value b) -> Value { 906 return builder.create<arith::MulFOp>(a, b); 907 }; 908 auto sub = [&](Value a, Value b) -> Value { 909 return builder.create<arith::SubFOp>(a, b); 910 }; 911 auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); }; 912 913 auto i32Vec = broadcast(builder.getI32Type(), shape); 914 auto fPToSingedInteger = [&](Value a) -> Value { 915 return builder.create<arith::FPToSIOp>(a, i32Vec); 916 }; 917 918 auto modulo4 = [&](Value a) -> Value { 919 return builder.create<arith::AndIOp>(a, bcast(i32Cst(builder, 3))); 920 }; 921 922 auto isEqualTo = [&](Value a, Value b) -> Value { 923 return builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b); 924 }; 925 926 auto isGreaterThan = [&](Value a, Value b) -> Value { 927 return builder.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b); 928 }; 929 930 auto select = [&](Value cond, Value t, Value f) -> Value { 931 return builder.create<SelectOp>(cond, t, f); 932 }; 933 934 auto fmla = [&](Value a, Value b, Value c) { 935 return builder.create<math::FmaOp>(a, b, c); 936 }; 937 938 auto bitwiseOr = [&](Value a, Value b) { 939 return builder.create<arith::OrIOp>(a, b); 940 }; 941 942 Value twoOverPi = bcast(f32Cst(builder, TWO_OVER_PI)); 943 Value piOverTwo = bcast(f32Cst(builder, PI_OVER_2)); 944 945 
Value x = op.operand(); 946 947 Value k = floor(mul(x, twoOverPi)); 948 949 Value y = sub(x, mul(k, piOverTwo)); 950 951 Value cstOne = bcast(f32Cst(builder, 1.0)); 952 Value cstNegativeOne = bcast(f32Cst(builder, -1.0)); 953 954 Value cstSC2 = bcast(f32Cst(builder, -0.16666667163372039794921875f)); 955 Value cstSC4 = bcast(f32Cst(builder, 8.333347737789154052734375e-3f)); 956 Value cstSC6 = bcast(f32Cst(builder, -1.9842604524455964565277099609375e-4f)); 957 Value cstSC8 = 958 bcast(f32Cst(builder, 2.760012648650445044040679931640625e-6f)); 959 Value cstSC10 = 960 bcast(f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f)); 961 962 Value cstCC2 = bcast(f32Cst(builder, -0.5f)); 963 Value cstCC4 = bcast(f32Cst(builder, 4.166664183139801025390625e-2f)); 964 Value cstCC6 = bcast(f32Cst(builder, -1.388833043165504932403564453125e-3f)); 965 Value cstCC8 = bcast(f32Cst(builder, 2.47562347794882953166961669921875e-5f)); 966 Value cstCC10 = 967 bcast(f32Cst(builder, -2.59630184018533327616751194000244140625e-7f)); 968 969 Value kMod4 = modulo4(fPToSingedInteger(k)); 970 971 Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, 0))); 972 Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, 1))); 973 Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, 2))); 974 Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, 3))); 975 976 Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2); 977 Value negativeRange = isSine ? 
isGreaterThan(kMod4, bcast(i32Cst(builder, 1))) 978 : bitwiseOr(kR1, kR2); 979 980 Value y2 = mul(y, y); 981 982 Value base = select(sinuseCos, cstOne, y); 983 Value cstC2 = select(sinuseCos, cstCC2, cstSC2); 984 Value cstC4 = select(sinuseCos, cstCC4, cstSC4); 985 Value cstC6 = select(sinuseCos, cstCC6, cstSC6); 986 Value cstC8 = select(sinuseCos, cstCC8, cstSC8); 987 Value cstC10 = select(sinuseCos, cstCC10, cstSC10); 988 989 Value v1 = fmla(y2, cstC10, cstC8); 990 Value v2 = fmla(y2, v1, cstC6); 991 Value v3 = fmla(y2, v2, cstC4); 992 Value v4 = fmla(y2, v3, cstC2); 993 Value v5 = fmla(y2, v4, cstOne); 994 Value v6 = mul(base, v5); 995 996 Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6); 997 998 rewriter.replaceOp(op, approximation); 999 1000 return success(); 1001 } 1002 1003 //----------------------------------------------------------------------------// 1004 // Rsqrt approximation. 1005 //----------------------------------------------------------------------------// 1006 1007 namespace { 1008 struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> { 1009 using OpRewritePattern::OpRewritePattern; 1010 1011 LogicalResult matchAndRewrite(math::RsqrtOp op, 1012 PatternRewriter &rewriter) const final; 1013 }; 1014 } // namespace 1015 1016 LogicalResult 1017 RsqrtApproximation::matchAndRewrite(math::RsqrtOp op, 1018 PatternRewriter &rewriter) const { 1019 if (!getElementTypeOrSelf(op.operand()).isF32()) 1020 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 1021 1022 ArrayRef<int64_t> shape = vectorShape(op.operand()); 1023 1024 // Only support already-vectorized rsqrt's. 
1025 if (shape.empty() || shape.back() % 8 != 0) 1026 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 1027 1028 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 1029 auto bcast = [&](Value value) -> Value { 1030 return broadcast(builder, value, shape); 1031 }; 1032 1033 Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u)); 1034 Value cstOnePointFive = bcast(f32Cst(builder, 1.5f)); 1035 Value cstNegHalf = bcast(f32Cst(builder, -0.5f)); 1036 Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u)); 1037 1038 Value negHalf = builder.create<arith::MulFOp>(op.operand(), cstNegHalf); 1039 1040 // Select only the inverse sqrt of positive normals (denormals are 1041 // flushed to zero). 1042 Value ltMinMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, 1043 op.operand(), cstMinNormPos); 1044 Value infMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, 1045 op.operand(), cstPosInf); 1046 Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask); 1047 1048 // Compute an approximate result. 1049 Value yApprox = handleMultidimensionalVectors( 1050 builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value { 1051 return builder.create<x86vector::RsqrtOp>(operands); 1052 }); 1053 1054 // Do a single step of Newton-Raphson iteration to improve the approximation. 1055 // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n). 1056 // It is essential to evaluate the inner term like this because forming 1057 // y_n^2 may over- or underflow. 1058 Value inner = builder.create<arith::MulFOp>(negHalf, yApprox); 1059 Value fma = builder.create<math::FmaOp>(yApprox, inner, cstOnePointFive); 1060 Value yNewton = builder.create<arith::MulFOp>(yApprox, fma); 1061 1062 // Select the result of the Newton-Raphson step for positive normal arguments. 1063 // For other arguments, choose the output of the intrinsic. 
This will 1064 // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if 1065 // x is zero or a positive denormalized float (equivalent to flushing positive 1066 // denormalized inputs to zero). 1067 Value res = builder.create<SelectOp>(notNormalFiniteMask, yApprox, yNewton); 1068 rewriter.replaceOp(op, res); 1069 1070 return success(); 1071 } 1072 1073 //----------------------------------------------------------------------------// 1074 1075 void mlir::populateMathPolynomialApproximationPatterns( 1076 RewritePatternSet &patterns, 1077 const MathPolynomialApproximationOptions &options) { 1078 patterns.add<TanhApproximation, LogApproximation, Log2Approximation, 1079 Log1pApproximation, ErfPolynomialApproximation, ExpApproximation, 1080 ExpM1Approximation, SinAndCosApproximation<true, math::SinOp>, 1081 SinAndCosApproximation<false, math::CosOp>>( 1082 patterns.getContext()); 1083 if (options.enableAvx2) 1084 patterns.add<RsqrtApproximation>(patterns.getContext()); 1085 } 1086