1 //===- PolynomialApproximation.cpp - Approximate math operations ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements expansion of math operations to fast approximations 10 // that do not rely on any of the library functions. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/Math/IR/Math.h" 16 #include "mlir/Dialect/Math/Transforms/Approximation.h" 17 #include "mlir/Dialect/Math/Transforms/Passes.h" 18 #include "mlir/Dialect/Vector/VectorOps.h" 19 #include "mlir/Dialect/X86Vector/X86VectorDialect.h" 20 #include "mlir/IR/Builders.h" 21 #include "mlir/IR/ImplicitLocOpBuilder.h" 22 #include "mlir/Transforms/Bufferize.h" 23 #include "mlir/Transforms/DialectConversion.h" 24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 25 #include "llvm/ADT/ArrayRef.h" 26 #include <climits> 27 #include <cstddef> 28 29 using namespace mlir; 30 using namespace mlir::math; 31 using namespace mlir::vector; 32 33 using TypePredicate = llvm::function_ref<bool(Type)>; 34 35 // Returns vector width if the element type is matching the predicate (scalars 36 // that do match the predicate have width equal to `1`). 37 static Optional<int> vectorWidth(Type type, TypePredicate pred) { 38 // If the type matches the predicate then its width is `1`. 39 if (pred(type)) 40 return 1; 41 42 // Otherwise check if the type is a vector type. 
43 auto vectorType = type.dyn_cast<VectorType>(); 44 if (vectorType && pred(vectorType.getElementType())) { 45 assert(vectorType.getRank() == 1 && "only 1d vectors are supported"); 46 return vectorType.getDimSize(0); 47 } 48 49 return llvm::None; 50 } 51 52 // Returns vector width of the type. If the type is a scalar returns `1`. 53 static int vectorWidth(Type type) { 54 auto vectorType = type.dyn_cast<VectorType>(); 55 return vectorType ? vectorType.getDimSize(0) : 1; 56 } 57 58 // Returns vector element type. If the type is a scalar returns the argument. 59 LLVM_ATTRIBUTE_UNUSED static Type elementType(Type type) { 60 auto vectorType = type.dyn_cast<VectorType>(); 61 return vectorType ? vectorType.getElementType() : type; 62 } 63 64 LLVM_ATTRIBUTE_UNUSED static bool isF32(Type type) { return type.isF32(); } 65 66 LLVM_ATTRIBUTE_UNUSED static bool isI32(Type type) { 67 return type.isInteger(32); 68 } 69 70 //----------------------------------------------------------------------------// 71 // Broadcast scalar types and values into vector types and values. 72 //----------------------------------------------------------------------------// 73 74 // Broadcasts scalar type into vector type (iff width is greater then 1). 75 static Type broadcast(Type type, int width) { 76 assert(!type.isa<VectorType>() && "must be scalar type"); 77 return width > 1 ? VectorType::get({width}, type) : type; 78 } 79 80 // Broadcasts scalar value into vector (iff width is greater then 1). 81 static Value broadcast(ImplicitLocOpBuilder &builder, Value value, int width) { 82 assert(!value.getType().isa<VectorType>() && "must be scalar value"); 83 auto type = broadcast(value.getType(), width); 84 return width > 1 ? builder.create<BroadcastOp>(type, value) : value; 85 } 86 87 //----------------------------------------------------------------------------// 88 // Helper functions to create constants. 
89 //----------------------------------------------------------------------------// 90 91 static Value f32Cst(ImplicitLocOpBuilder &builder, float value) { 92 return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value)); 93 } 94 95 static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) { 96 return builder.create<arith::ConstantOp>(builder.getI32IntegerAttr(value)); 97 } 98 99 static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) { 100 Value i32Value = i32Cst(builder, static_cast<int32_t>(bits)); 101 return builder.create<arith::BitcastOp>(builder.getF32Type(), i32Value); 102 } 103 104 //----------------------------------------------------------------------------// 105 // Helper functions to build math functions approximations. 106 //----------------------------------------------------------------------------// 107 108 static Value min(ImplicitLocOpBuilder &builder, Value a, Value b) { 109 return builder.create<SelectOp>( 110 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, a, b), a, b); 111 } 112 113 static Value max(ImplicitLocOpBuilder &builder, Value a, Value b) { 114 return builder.create<SelectOp>( 115 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, a, b), a, b); 116 } 117 118 static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, 119 Value upperBound) { 120 return max(builder, min(builder, value, upperBound), lowerBound); 121 } 122 123 // Decomposes given floating point value `arg` into a normalized fraction and 124 // an integral power of two (see std::frexp). Returned values have float type. 
125 static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg, 126 bool is_positive = false) { 127 assert(isF32(elementType(arg.getType())) && "argument must be f32 type"); 128 129 int width = vectorWidth(arg.getType()); 130 131 auto bcast = [&](Value value) -> Value { 132 return broadcast(builder, value, width); 133 }; 134 135 auto i32 = builder.getIntegerType(32); 136 auto i32Vec = broadcast(i32, width); 137 auto f32Vec = broadcast(builder.getF32Type(), width); 138 139 Value cst126f = f32Cst(builder, 126.0f); 140 Value cstHalf = f32Cst(builder, 0.5f); 141 Value cstInvMantMask = f32FromBits(builder, ~0x7f800000u); 142 143 // Bitcast to i32 for bitwise operations. 144 Value i32Half = builder.create<arith::BitcastOp>(i32, cstHalf); 145 Value i32InvMantMask = builder.create<arith::BitcastOp>(i32, cstInvMantMask); 146 Value i32Arg = builder.create<arith::BitcastOp>(i32Vec, arg); 147 148 // Compute normalized fraction. 149 Value tmp0 = builder.create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask)); 150 Value tmp1 = builder.create<arith::OrIOp>(tmp0, bcast(i32Half)); 151 Value normalizedFraction = builder.create<arith::BitcastOp>(f32Vec, tmp1); 152 153 // Compute exponent. 154 Value arg0 = is_positive ? arg : builder.create<math::AbsOp>(arg); 155 Value biasedExponentBits = builder.create<arith::ShRUIOp>( 156 builder.create<arith::BitcastOp>(i32Vec, arg0), 157 bcast(i32Cst(builder, 23))); 158 Value biasedExponent = 159 builder.create<arith::SIToFPOp>(f32Vec, biasedExponentBits); 160 Value exponent = 161 builder.create<arith::SubFOp>(biasedExponent, bcast(cst126f)); 162 163 return {normalizedFraction, exponent}; 164 } 165 166 // Computes exp2 for an i32 argument. 
167 static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) { 168 assert(isI32(elementType(arg.getType())) && "argument must be i32 type"); 169 170 int width = vectorWidth(arg.getType()); 171 172 auto bcast = [&](Value value) -> Value { 173 return broadcast(builder, value, width); 174 }; 175 176 auto f32Vec = broadcast(builder.getF32Type(), width); 177 // The exponent of f32 located at 23-bit. 178 auto exponetBitLocation = bcast(i32Cst(builder, 23)); 179 // Set the exponent bias to zero. 180 auto bias = bcast(i32Cst(builder, 127)); 181 182 Value biasedArg = builder.create<arith::AddIOp>(arg, bias); 183 Value exp2ValueInt = 184 builder.create<arith::ShLIOp>(biasedArg, exponetBitLocation); 185 Value exp2ValueF32 = builder.create<arith::BitcastOp>(f32Vec, exp2ValueInt); 186 187 return exp2ValueF32; 188 } 189 190 namespace { 191 Value makePolynomialCalculation(ImplicitLocOpBuilder &builder, 192 llvm::ArrayRef<Value> coeffs, Value x) { 193 auto width = vectorWidth(x.getType(), isF32); 194 if (coeffs.size() == 0) { 195 return broadcast(builder, f32Cst(builder, 0.0f), *width); 196 } else if (coeffs.size() == 1) { 197 return coeffs[0]; 198 } 199 Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1], 200 coeffs[coeffs.size() - 2]); 201 for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) { 202 res = builder.create<math::FmaOp>(x, res, coeffs[i]); 203 } 204 return res; 205 } 206 } // namespace 207 208 //----------------------------------------------------------------------------// 209 // TanhOp approximation. 
//----------------------------------------------------------------------------//

namespace {
// Rewrites math.tanh on f32 scalars / 1-D f32 vectors into a rational
// polynomial approximation.
struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::TanhOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Approximates tanh(x) as x * P(x^2) / Q(x^2):
//   - P uses the odd-power coefficients alpha1..alpha13.
//   - Q uses the even-power coefficients beta0..beta6.
// The operand is first clamped to [-7.99881172180175781, 7.99881172180175781].
// Inputs with |x| < 0.0004 return the clamped input itself (tanh(x) ~= x).
LogicalResult
TanhApproximation::matchAndRewrite(math::TanhOp op,
                                   PatternRewriter &rewriter) const {
  auto width = vectorWidth(op.operand().getType(), isF32);
  if (!width.hasValue())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, *width);
  };

  // Clamp operand into [minusClamp, plusClamp] range.
  Value minusClamp = bcast(f32Cst(builder, -7.99881172180175781f));
  Value plusClamp = bcast(f32Cst(builder, 7.99881172180175781f));
  Value x = clamp(builder, op.operand(), minusClamp, plusClamp);

  // Mask for tiny values that are approximated with `operand` itself.
  Value tiny = bcast(f32Cst(builder, 0.0004f));
  Value tinyMask = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::OLT, builder.create<math::AbsOp>(op.operand()),
      tiny);

  // The monomial coefficients of the numerator polynomial (odd).
  Value alpha1 = bcast(f32Cst(builder, 4.89352455891786e-03f));
  Value alpha3 = bcast(f32Cst(builder, 6.37261928875436e-04f));
  Value alpha5 = bcast(f32Cst(builder, 1.48572235717979e-05f));
  Value alpha7 = bcast(f32Cst(builder, 5.12229709037114e-08f));
  Value alpha9 = bcast(f32Cst(builder, -8.60467152213735e-11f));
  Value alpha11 = bcast(f32Cst(builder, 2.00018790482477e-13f));
  Value alpha13 = bcast(f32Cst(builder, -2.76076847742355e-16f));

  // The monomial coefficients of the denominator polynomial (even).
  Value beta0 = bcast(f32Cst(builder, 4.89352518554385e-03f));
  Value beta2 = bcast(f32Cst(builder, 2.26843463243900e-03f));
  Value beta4 = bcast(f32Cst(builder, 1.18534705686654e-04f));
  Value beta6 = bcast(f32Cst(builder, 1.19825839466702e-06f));

  // Since the polynomials are odd/even, we need x^2.
  Value x2 = builder.create<arith::MulFOp>(x, x);

  // Evaluate the numerator polynomial p (Horner in x^2, then multiply by x).
  Value p = builder.create<math::FmaOp>(x2, alpha13, alpha11);
  p = builder.create<math::FmaOp>(x2, p, alpha9);
  p = builder.create<math::FmaOp>(x2, p, alpha7);
  p = builder.create<math::FmaOp>(x2, p, alpha5);
  p = builder.create<math::FmaOp>(x2, p, alpha3);
  p = builder.create<math::FmaOp>(x2, p, alpha1);
  p = builder.create<arith::MulFOp>(x, p);

  // Evaluate the denominator polynomial q (Horner in x^2).
  Value q = builder.create<math::FmaOp>(x2, beta6, beta4);
  q = builder.create<math::FmaOp>(x2, q, beta2);
  q = builder.create<math::FmaOp>(x2, q, beta0);

  // Divide the numerator by the denominator; tiny inputs bypass the division.
  Value res = builder.create<SelectOp>(tinyMask, x,
                                       builder.create<arith::DivFOp>(p, q));

  rewriter.replaceOp(op, res);

  return success();
}

// ln(2) and log2(e) as long double literals; narrowed to float where used.
#define LN2_VALUE                                                              \
  0.693147180559945309417232121458176568075500134360255254120680009493393621L
#define LOG2E_VALUE                                                            \
  1.442695040888963407359924681001892137426645954152985934135449406931109219L

//----------------------------------------------------------------------------//
// LogOp and Log2Op approximation.
//----------------------------------------------------------------------------//

namespace {
// Shared implementation for the natural and base-2 logarithm patterns.
template <typename Op>
struct LogApproximationBase : public OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;

  /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise.
  LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter,
                                   bool base2) const;
};
} // namespace

// This approximation comes from Julien Pommier's SSE math library.
// Link: http://gruntthepeon.free.fr/ssemath
//
// Outline:
//   1. Split x into a fraction in [0.5, 1) and an exponent e (frexp).
//   2. Shift the fraction into [sqrt(1/2), sqrt(2)) and subtract 1.
//   3. Evaluate the Cephes polynomial for log(1 + f).
//   4. Add e * ln(2), or scale by log2(e) and add e for base 2.
//   5. Patch the special inputs: 0, negatives/NaN and +inf.
template <typename Op>
LogicalResult
LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
                                             bool base2) const {
  auto width = vectorWidth(op.operand().getType(), isF32);
  if (!width.hasValue())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, *width);
  };

  Value cstZero = bcast(f32Cst(builder, 0.0f));
  Value cstOne = bcast(f32Cst(builder, 1.0f));
  Value cstNegHalf = bcast(f32Cst(builder, -0.5f));

  // The smallest non denormalized float number.
  Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
  Value cstMinusInf = bcast(f32FromBits(builder, 0xff800000u));
  Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
  Value cstNan = bcast(f32FromBits(builder, 0x7fc00000));

  // Polynomial coefficients.
  Value cstCephesSQRTHF = bcast(f32Cst(builder, 0.707106781186547524f));
  Value cstCephesLogP0 = bcast(f32Cst(builder, 7.0376836292E-2f));
  Value cstCephesLogP1 = bcast(f32Cst(builder, -1.1514610310E-1f));
  Value cstCephesLogP2 = bcast(f32Cst(builder, 1.1676998740E-1f));
  Value cstCephesLogP3 = bcast(f32Cst(builder, -1.2420140846E-1f));
  Value cstCephesLogP4 = bcast(f32Cst(builder, +1.4249322787E-1f));
  Value cstCephesLogP5 = bcast(f32Cst(builder, -1.6668057665E-1f));
  Value cstCephesLogP6 = bcast(f32Cst(builder, +2.0000714765E-1f));
  Value cstCephesLogP7 = bcast(f32Cst(builder, -2.4999993993E-1f));
  Value cstCephesLogP8 = bcast(f32Cst(builder, +3.3333331174E-1f));

  Value x = op.operand();

  // Truncate input values to the minimum positive normal. Zero and negative
  // inputs are patched by the selects at the end. Positive denormals are not
  // patched, so they evaluate as log(FLT_MIN).
  x = max(builder, x, cstMinNormPos);

  // Extract the significand in the range [0.5,1) and the exponent. After the
  // truncation above x is positive, so frexp can skip its abs.
  std::pair<Value, Value> pair = frexp(builder, x, /*is_positive=*/true);
  x = pair.first;
  Value e = pair.second;

  // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift
  // by -1.0. The values are then centered around 0, which improves the
  // stability of the polynomial evaluation:
  //
  //   if( x < SQRTHF ) {
  //     e -= 1;
  //     x = x + x - 1.0;
  //   } else { x = x - 1.0; }
  Value mask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x,
                                             cstCephesSQRTHF);
  Value tmp = builder.create<SelectOp>(mask, x, cstZero);

  x = builder.create<arith::SubFOp>(x, cstOne);
  e = builder.create<arith::SubFOp>(
      e, builder.create<SelectOp>(mask, cstOne, cstZero));
  x = builder.create<arith::AddFOp>(x, tmp);

  Value x2 = builder.create<arith::MulFOp>(x, x);
  Value x3 = builder.create<arith::MulFOp>(x2, x);

  // Evaluate the polynomial approximant of degree 8 in three parts: three
  // independent quadratics in x combined with powers of x^3.
  Value y0, y1, y2;
  y0 = builder.create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1);
  y1 = builder.create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4);
  y2 = builder.create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7);
  y0 = builder.create<math::FmaOp>(y0, x, cstCephesLogP2);
  y1 = builder.create<math::FmaOp>(y1, x, cstCephesLogP5);
  y2 = builder.create<math::FmaOp>(y2, x, cstCephesLogP8);
  y0 = builder.create<math::FmaOp>(y0, x3, y1);
  y0 = builder.create<math::FmaOp>(y0, x3, y2);
  y0 = builder.create<arith::MulFOp>(y0, x3);

  // log(1 + x) ~= x - x^2 / 2 + x^3 * poly(x).
  y0 = builder.create<math::FmaOp>(cstNegHalf, x2, y0);
  x = builder.create<arith::AddFOp>(x, y0);

  // Recombine with the exponent:
  //   log2: log(m) * log2(e) + e
  //   ln:   e * ln(2) + log(m)
  if (base2) {
    Value cstLog2e = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
    x = builder.create<math::FmaOp>(x, cstLog2e, e);
  } else {
    Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
    x = builder.create<math::FmaOp>(e, cstLn2, x);
  }

  // ULT is also true for NaN operands, so NaN inputs end up as NaN.
  Value invalidMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
                                                    op.operand(), cstZero);
  Value zeroMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
                                                 op.operand(), cstZero);
  Value posInfMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
                                                   op.operand(), cstPosInf);

  // Filter out invalid values:
  //  - x == 0     -> -INF
  //  - x < 0      ->  NAN (also for NaN inputs)
  //  - x == +INF  -> +INF
  Value aproximation = builder.create<SelectOp>(
      zeroMask, cstMinusInf,
      builder.create<SelectOp>(
          invalidMask, cstNan,
          builder.create<SelectOp>(posInfMask, cstPosInf, x)));

  rewriter.replaceOp(op, aproximation);

  return success();
}

namespace {
// Rewrites math.log (natural logarithm) using the shared implementation.
struct LogApproximation : public LogApproximationBase<math::LogOp> {
  using LogApproximationBase::LogApproximationBase;

  LogicalResult matchAndRewrite(math::LogOp op,
                                PatternRewriter &rewriter) const final {
    return logMatchAndRewrite(op, rewriter, /*base2=*/false);
  }
};
} // namespace

namespace {
// Rewrites math.log2 using the shared implementation.
struct Log2Approximation : public LogApproximationBase<math::Log2Op> {
  using LogApproximationBase::LogApproximationBase;

  LogicalResult matchAndRewrite(math::Log2Op op,
                                PatternRewriter &rewriter) const final {
    return logMatchAndRewrite(op, rewriter, /*base2=*/true);
  }
};
} // namespace

//----------------------------------------------------------------------------//
// Log1p approximation.
//----------------------------------------------------------------------------//

namespace {
// Rewrites math.log1p in terms of math.log (which may itself be approximated).
struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::Log1pOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Approximate log(1+x).
LogicalResult
Log1pApproximation::matchAndRewrite(math::Log1pOp op,
                                    PatternRewriter &rewriter) const {
  auto width = vectorWidth(op.operand().getType(), isF32);
  if (!width.hasValue())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, *width);
  };

  // Approximate log(1+x) using the following, due to W. Kahan:
  //   u = x + 1.0;
  //   if (u == 1.0 || u == inf) return x;
  //   return x * log(u) / (u - 1.0);
  //          ^^^^^^^^^^^^^^^^^^^^^^
  //             "logLarge" below.
  Value cstOne = bcast(f32Cst(builder, 1.0f));
  Value x = op.operand();
  Value u = builder.create<arith::AddFOp>(x, cstOne);
  Value uSmall =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
  Value logU = builder.create<math::LogOp>(u);
  // Detects u == +inf without materializing an infinity constant. For finite
  // positive u, log(u) < u, so the equality can only hold when both are +inf.
  // This relies on log(+inf) being +inf.
  Value uInf =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU);
  Value logLarge = builder.create<arith::MulFOp>(
      x, builder.create<arith::DivFOp>(
             logU, builder.create<arith::SubFOp>(u, cstOne)));
  Value approximation = builder.create<SelectOp>(
      builder.create<arith::OrIOp>(uSmall, uInf), x, logLarge);
  rewriter.replaceOp(op, approximation);
  return success();
}

//----------------------------------------------------------------------------//
// Erf approximation.
//----------------------------------------------------------------------------//

// Approximates erf(x) by evaluating, on |x|,
//   offset + P(|x|)/Q(|x|)
// where P and Q are polynomials of degree 4.
// The coefficients and the offset are chosen per interval of |x|:
// [0, 0.8), [0.8, 2) and [2, 3.75). For |x| >= 3.75 the result is 1.
// Since erf is odd, the sign of x is applied at the end.
// The approximation error is ~2.5e-07.
// Boost's minimax tool that utilizes the Remez method was used to find the
// coefficients.
501 LogicalResult 502 ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op, 503 PatternRewriter &rewriter) const { 504 auto width = vectorWidth(op.operand().getType(), isF32); 505 if (!width.hasValue()) 506 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 507 508 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 509 auto bcast = [&](Value value) -> Value { 510 return broadcast(builder, value, *width); 511 }; 512 513 const int intervalsCount = 3; 514 const int polyDegree = 4; 515 516 Value zero = bcast(f32Cst(builder, 0)); 517 Value one = bcast(f32Cst(builder, 1)); 518 Value pp[intervalsCount][polyDegree + 1]; 519 pp[0][0] = bcast(f32Cst(builder, +0.00000000000000000e+00)); 520 pp[0][1] = bcast(f32Cst(builder, +1.12837916222975858e+00)); 521 pp[0][2] = bcast(f32Cst(builder, -5.23018562988006470e-01)); 522 pp[0][3] = bcast(f32Cst(builder, +2.09741709609267072e-01)); 523 pp[0][4] = bcast(f32Cst(builder, +2.58146801602987875e-02)); 524 pp[1][0] = bcast(f32Cst(builder, +0.00000000000000000e+00)); 525 pp[1][1] = bcast(f32Cst(builder, +1.12750687816789140e+00)); 526 pp[1][2] = bcast(f32Cst(builder, -3.64721408487825775e-01)); 527 pp[1][3] = bcast(f32Cst(builder, +1.18407396425136952e-01)); 528 pp[1][4] = bcast(f32Cst(builder, +3.70645533056476558e-02)); 529 pp[2][0] = bcast(f32Cst(builder, -3.30093071049483172e-03)); 530 pp[2][1] = bcast(f32Cst(builder, +3.51961938357697011e-03)); 531 pp[2][2] = bcast(f32Cst(builder, -1.41373622814988039e-03)); 532 pp[2][3] = bcast(f32Cst(builder, +2.53447094961941348e-04)); 533 pp[2][4] = bcast(f32Cst(builder, -1.71048029455037401e-05)); 534 535 Value qq[intervalsCount][polyDegree + 1]; 536 qq[0][0] = bcast(f32Cst(builder, +1.000000000000000000e+00)); 537 qq[0][1] = bcast(f32Cst(builder, -4.635138185962547255e-01)); 538 qq[0][2] = bcast(f32Cst(builder, +5.192301327279782447e-01)); 539 qq[0][3] = bcast(f32Cst(builder, -1.318089722204810087e-01)); 540 qq[0][4] = bcast(f32Cst(builder, 
+7.397964654672315005e-02)); 541 qq[1][0] = bcast(f32Cst(builder, +1.00000000000000000e+00)); 542 qq[1][1] = bcast(f32Cst(builder, -3.27607011824493086e-01)); 543 qq[1][2] = bcast(f32Cst(builder, +4.48369090658821977e-01)); 544 qq[1][3] = bcast(f32Cst(builder, -8.83462621207857930e-02)); 545 qq[1][4] = bcast(f32Cst(builder, +5.72442770283176093e-02)); 546 qq[2][0] = bcast(f32Cst(builder, +1.00000000000000000e+00)); 547 qq[2][1] = bcast(f32Cst(builder, -2.06069165953913769e+00)); 548 qq[2][2] = bcast(f32Cst(builder, +1.62705939945477759e+00)); 549 qq[2][3] = bcast(f32Cst(builder, -5.83389859211130017e-01)); 550 qq[2][4] = bcast(f32Cst(builder, +8.21908939856640930e-02)); 551 552 Value offsets[intervalsCount]; 553 offsets[0] = bcast(f32Cst(builder, 0)); 554 offsets[1] = bcast(f32Cst(builder, 0)); 555 offsets[2] = bcast(f32Cst(builder, 1)); 556 557 Value bounds[intervalsCount]; 558 bounds[0] = bcast(f32Cst(builder, 0.8)); 559 bounds[1] = bcast(f32Cst(builder, 2)); 560 bounds[2] = bcast(f32Cst(builder, 3.75)); 561 562 Value isNegativeArg = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, 563 op.operand(), zero); 564 Value negArg = builder.create<arith::NegFOp>(op.operand()); 565 Value x = builder.create<SelectOp>(isNegativeArg, negArg, op.operand()); 566 567 Value offset = offsets[0]; 568 Value p[polyDegree + 1]; 569 Value q[polyDegree + 1]; 570 for (int i = 0; i <= polyDegree; ++i) { 571 p[i] = pp[0][i]; 572 q[i] = qq[0][i]; 573 } 574 575 // TODO: maybe use vector stacking to reduce the number of selects. 
576 Value isLessThanBound[intervalsCount]; 577 for (int j = 0; j < intervalsCount - 1; ++j) { 578 isLessThanBound[j] = 579 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[j]); 580 for (int i = 0; i <= polyDegree; ++i) { 581 p[i] = builder.create<SelectOp>(isLessThanBound[j], p[i], pp[j + 1][i]); 582 q[i] = builder.create<SelectOp>(isLessThanBound[j], q[i], qq[j + 1][i]); 583 } 584 offset = 585 builder.create<SelectOp>(isLessThanBound[j], offset, offsets[j + 1]); 586 } 587 isLessThanBound[intervalsCount - 1] = builder.create<arith::CmpFOp>( 588 arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]); 589 590 Value pPoly = makePolynomialCalculation(builder, p, x); 591 Value qPoly = makePolynomialCalculation(builder, q, x); 592 Value rationalPoly = builder.create<arith::DivFOp>(pPoly, qPoly); 593 Value formula = builder.create<arith::AddFOp>(offset, rationalPoly); 594 formula = builder.create<SelectOp>(isLessThanBound[intervalsCount - 1], 595 formula, one); 596 597 // erf is odd function: erf(x) = -erf(-x). 598 Value negFormula = builder.create<arith::NegFOp>(formula); 599 Value res = builder.create<SelectOp>(isNegativeArg, negFormula, formula); 600 601 rewriter.replaceOp(op, res); 602 603 return success(); 604 } 605 606 //----------------------------------------------------------------------------// 607 // Exp approximation. 608 //----------------------------------------------------------------------------// 609 610 namespace { 611 612 struct ExpApproximation : public OpRewritePattern<math::ExpOp> { 613 public: 614 using OpRewritePattern::OpRewritePattern; 615 616 LogicalResult matchAndRewrite(math::ExpOp op, 617 PatternRewriter &rewriter) const final; 618 }; 619 } // namespace 620 621 // Approximate exp(x) using its reduced range exp(y) where y is in the range 622 // [0, ln(2)], let y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2), exp(x) 623 // = exp(y) * 2^k. exp(y). 
624 LogicalResult 625 ExpApproximation::matchAndRewrite(math::ExpOp op, 626 PatternRewriter &rewriter) const { 627 auto width = vectorWidth(op.operand().getType(), isF32); 628 if (!width.hasValue()) 629 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 630 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 631 632 // TODO: Consider a common pattern rewriter with all methods below to 633 // write the approximations. 634 auto bcast = [&](Value value) -> Value { 635 return broadcast(builder, value, *width); 636 }; 637 auto fmla = [&](Value a, Value b, Value c) { 638 return builder.create<math::FmaOp>(a, b, c); 639 }; 640 auto mul = [&](Value a, Value b) -> Value { 641 return builder.create<arith::MulFOp>(a, b); 642 }; 643 auto sub = [&](Value a, Value b) -> Value { 644 return builder.create<arith::SubFOp>(a, b); 645 }; 646 auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); }; 647 648 Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE))); 649 Value cstLog2E = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE))); 650 651 // Polynomial coefficients. 
652 Value cstCephesExpP0 = bcast(f32Cst(builder, 1.0)); 653 Value cstCephesExpP1 = bcast(f32Cst(builder, 1.0)); 654 Value cstCephesExpP2 = bcast(f32Cst(builder, 0.49970514590562437052f)); 655 Value cstCephesExpP3 = bcast(f32Cst(builder, 0.16873890085469545053f)); 656 Value cstCephesExpP4 = bcast(f32Cst(builder, 0.03668965196652099192f)); 657 Value cstCephesExpP5 = bcast(f32Cst(builder, 0.01314350012789660196f)); 658 659 Value x = op.operand(); 660 661 // Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2) 662 Value xL2Inv = mul(x, cstLog2E); 663 Value kF32 = floor(xL2Inv); 664 Value kLn2 = mul(kF32, cstLn2); 665 Value y = sub(x, kLn2); 666 667 // Use Estrin's evaluation scheme with 3 independent parts: 668 // P(y)^y : (c0 + c1 y) + (c2 + c3 y) y^2 + (c4 + c5 y) y^4 669 Value y2 = mul(y, y); 670 Value y4 = mul(y2, y2); 671 672 Value q0 = fmla(cstCephesExpP1, y, cstCephesExpP0); 673 Value q1 = fmla(cstCephesExpP3, y, cstCephesExpP2); 674 Value q2 = fmla(cstCephesExpP5, y, cstCephesExpP4); 675 Value expY = fmla(q1, y2, q0); 676 expY = fmla(q2, y4, expY); 677 678 auto i32Vec = broadcast(builder.getI32Type(), *width); 679 680 // exp2(k) 681 Value k = builder.create<arith::FPToSIOp>(kF32, i32Vec); 682 Value exp2KValue = exp2I32(builder, k); 683 684 // exp(x) = exp(y) * exp2(k) 685 expY = mul(expY, exp2KValue); 686 687 // Handle overflow, inf and underflow of exp(x). exp(x) range is [0, inf], its 688 // partitioned as the following: 689 // exp(x) = 0, x <= -inf 690 // exp(x) = underflow (min_float), x <= -88 691 // exp(x) = inf (min_float), x >= 88 692 // Note: |k| = 127 is the value where the 8-bits exponent saturates. 
693 Value zerof32Const = bcast(f32Cst(builder, 0)); 694 auto constPosInfinity = 695 bcast(f32Cst(builder, std::numeric_limits<float>::infinity())); 696 auto constNegIfinity = 697 bcast(f32Cst(builder, -std::numeric_limits<float>::infinity())); 698 auto underflow = bcast(f32Cst(builder, std::numeric_limits<float>::min())); 699 700 Value kMaxConst = bcast(i32Cst(builder, 127)); 701 Value kMaxNegConst = bcast(i32Cst(builder, -127)); 702 Value rightBound = 703 builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, k, kMaxConst); 704 Value leftBound = 705 builder.create<arith::CmpIOp>(arith::CmpIPredicate::sge, k, kMaxNegConst); 706 707 Value isNegInfinityX = builder.create<arith::CmpFOp>( 708 arith::CmpFPredicate::OEQ, x, constNegIfinity); 709 Value isPosInfinityX = builder.create<arith::CmpFOp>( 710 arith::CmpFPredicate::OEQ, x, constPosInfinity); 711 Value isPostiveX = 712 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zerof32Const); 713 Value isComputable = builder.create<arith::AndIOp>(rightBound, leftBound); 714 715 expY = builder.create<SelectOp>( 716 isNegInfinityX, zerof32Const, 717 builder.create<SelectOp>( 718 isPosInfinityX, constPosInfinity, 719 builder.create<SelectOp>(isComputable, expY, 720 builder.create<SelectOp>(isPostiveX, 721 constPosInfinity, 722 underflow)))); 723 724 rewriter.replaceOp(op, expY); 725 726 return success(); 727 } 728 729 //----------------------------------------------------------------------------// 730 // ExpM1 approximation. 
731 //----------------------------------------------------------------------------// 732 733 namespace { 734 735 struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> { 736 public: 737 using OpRewritePattern::OpRewritePattern; 738 739 LogicalResult matchAndRewrite(math::ExpM1Op op, 740 PatternRewriter &rewriter) const final; 741 }; 742 } // namespace 743 744 LogicalResult 745 ExpM1Approximation::matchAndRewrite(math::ExpM1Op op, 746 PatternRewriter &rewriter) const { 747 auto width = vectorWidth(op.operand().getType(), isF32); 748 if (!width.hasValue()) 749 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 750 751 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 752 auto bcast = [&](Value value) -> Value { 753 return broadcast(builder, value, *width); 754 }; 755 756 // expm1(x) = exp(x) - 1 = u - 1. 757 // We have to handle it carefully when x is near 0, i.e. u ~= 1, 758 // and when the input is ~= -inf, i.e. u - 1 ~= -1. 759 Value cstOne = bcast(f32Cst(builder, 1.0f)); 760 Value cstNegOne = bcast(f32Cst(builder, -1.0f)); 761 Value x = op.operand(); 762 Value u = builder.create<math::ExpOp>(x); 763 Value uEqOne = 764 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne); 765 Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne); 766 Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>( 767 arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne); 768 // logU = log(u) ~= x 769 Value logU = builder.create<math::LogOp>(u); 770 771 // Detect exp(x) = +inf; written this way to avoid having to form +inf. 
772 Value isInf = 773 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u); 774 775 // (u - 1) * (x / ~x) 776 Value expm1 = builder.create<arith::MulFOp>( 777 uMinusOne, builder.create<arith::DivFOp>(x, logU)); 778 expm1 = builder.create<SelectOp>(isInf, u, expm1); 779 Value approximation = builder.create<SelectOp>( 780 uEqOne, x, builder.create<SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1)); 781 rewriter.replaceOp(op, approximation); 782 return success(); 783 } 784 785 //----------------------------------------------------------------------------// 786 // Sin and Cos approximation. 787 //----------------------------------------------------------------------------// 788 789 namespace { 790 791 template <bool isSine, typename OpTy> 792 struct SinAndCosApproximation : public OpRewritePattern<OpTy> { 793 public: 794 using OpRewritePattern<OpTy>::OpRewritePattern; 795 796 LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final; 797 }; 798 } // namespace 799 800 #define TWO_OVER_PI \ 801 0.6366197723675813430755350534900574481378385829618257949906693762L 802 #define PI_OVER_2 \ 803 1.5707963267948966192313216916397514420985846996875529104874722961L 804 805 // Approximates sin(x) or cos(x) by finding the best approximation polynomial in 806 // the reduced range [0, pi/2] for both sin(x) and cos(x). Then given y in the 807 // reduced range sin(x) will be computed as sin(y), -sin(y), cos(y) or -cos(y). 
808 template <bool isSine, typename OpTy> 809 LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite( 810 OpTy op, PatternRewriter &rewriter) const { 811 static_assert( 812 llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value, 813 "SinAndCosApproximation pattern expects math::SinOp or math::CosOp"); 814 auto width = vectorWidth(op.operand().getType(), isF32); 815 if (!width.hasValue()) 816 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 817 818 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 819 auto bcast = [&](Value value) -> Value { 820 return broadcast(builder, value, *width); 821 }; 822 auto mul = [&](Value a, Value b) -> Value { 823 return builder.create<arith::MulFOp>(a, b); 824 }; 825 auto sub = [&](Value a, Value b) -> Value { 826 return builder.create<arith::SubFOp>(a, b); 827 }; 828 auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); }; 829 830 auto i32Vec = broadcast(builder.getI32Type(), *width); 831 auto fPToSingedInteger = [&](Value a) -> Value { 832 return builder.create<arith::FPToSIOp>(a, i32Vec); 833 }; 834 835 auto modulo4 = [&](Value a) -> Value { 836 return builder.create<arith::AndIOp>(a, bcast(i32Cst(builder, 3))); 837 }; 838 839 auto isEqualTo = [&](Value a, Value b) -> Value { 840 return builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b); 841 }; 842 843 auto isGreaterThan = [&](Value a, Value b) -> Value { 844 return builder.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b); 845 }; 846 847 auto select = [&](Value cond, Value t, Value f) -> Value { 848 return builder.create<SelectOp>(cond, t, f); 849 }; 850 851 auto fmla = [&](Value a, Value b, Value c) { 852 return builder.create<math::FmaOp>(a, b, c); 853 }; 854 855 auto bitwiseOr = [&](Value a, Value b) { 856 return builder.create<arith::OrIOp>(a, b); 857 }; 858 859 Value twoOverPi = bcast(f32Cst(builder, TWO_OVER_PI)); 860 Value piOverTwo = bcast(f32Cst(builder, PI_OVER_2)); 861 862 Value x = op.operand(); 863 864 
Value k = floor(mul(x, twoOverPi)); 865 866 Value y = sub(x, mul(k, piOverTwo)); 867 868 Value cstOne = bcast(f32Cst(builder, 1.0)); 869 Value cstNegativeOne = bcast(f32Cst(builder, -1.0)); 870 871 Value cstSC2 = bcast(f32Cst(builder, -0.16666667163372039794921875f)); 872 Value cstSC4 = bcast(f32Cst(builder, 8.333347737789154052734375e-3f)); 873 Value cstSC6 = bcast(f32Cst(builder, -1.9842604524455964565277099609375e-4f)); 874 Value cstSC8 = 875 bcast(f32Cst(builder, 2.760012648650445044040679931640625e-6f)); 876 Value cstSC10 = 877 bcast(f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f)); 878 879 Value cstCC2 = bcast(f32Cst(builder, -0.5f)); 880 Value cstCC4 = bcast(f32Cst(builder, 4.166664183139801025390625e-2f)); 881 Value cstCC6 = bcast(f32Cst(builder, -1.388833043165504932403564453125e-3f)); 882 Value cstCC8 = bcast(f32Cst(builder, 2.47562347794882953166961669921875e-5f)); 883 Value cstCC10 = 884 bcast(f32Cst(builder, -2.59630184018533327616751194000244140625e-7f)); 885 886 Value kMod4 = modulo4(fPToSingedInteger(k)); 887 888 Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, 0))); 889 Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, 1))); 890 Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, 2))); 891 Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, 3))); 892 893 Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2); 894 Value negativeRange = isSine ? 
isGreaterThan(kMod4, bcast(i32Cst(builder, 1))) 895 : bitwiseOr(kR1, kR2); 896 897 Value y2 = mul(y, y); 898 899 Value base = select(sinuseCos, cstOne, y); 900 Value cstC2 = select(sinuseCos, cstCC2, cstSC2); 901 Value cstC4 = select(sinuseCos, cstCC4, cstSC4); 902 Value cstC6 = select(sinuseCos, cstCC6, cstSC6); 903 Value cstC8 = select(sinuseCos, cstCC8, cstSC8); 904 Value cstC10 = select(sinuseCos, cstCC10, cstSC10); 905 906 Value v1 = fmla(y2, cstC10, cstC8); 907 Value v2 = fmla(y2, v1, cstC6); 908 Value v3 = fmla(y2, v2, cstC4); 909 Value v4 = fmla(y2, v3, cstC2); 910 Value v5 = fmla(y2, v4, cstOne); 911 Value v6 = mul(base, v5); 912 913 Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6); 914 915 rewriter.replaceOp(op, approximation); 916 917 return success(); 918 } 919 920 //----------------------------------------------------------------------------// 921 // Rsqrt approximation. 922 //----------------------------------------------------------------------------// 923 924 namespace { 925 struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> { 926 using OpRewritePattern::OpRewritePattern; 927 928 LogicalResult matchAndRewrite(math::RsqrtOp op, 929 PatternRewriter &rewriter) const final; 930 }; 931 } // namespace 932 933 LogicalResult 934 RsqrtApproximation::matchAndRewrite(math::RsqrtOp op, 935 PatternRewriter &rewriter) const { 936 auto width = vectorWidth(op.operand().getType(), isF32); 937 // Only support already-vectorized rsqrt's. 
938 if (!width.hasValue() || *width != 8) 939 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 940 941 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 942 auto bcast = [&](Value value) -> Value { 943 return broadcast(builder, value, *width); 944 }; 945 946 Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u)); 947 Value cstOnePointFive = bcast(f32Cst(builder, 1.5f)); 948 Value cstNegHalf = bcast(f32Cst(builder, -0.5f)); 949 Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u)); 950 951 Value negHalf = builder.create<arith::MulFOp>(op.operand(), cstNegHalf); 952 953 // Select only the inverse sqrt of positive normals (denormals are 954 // flushed to zero). 955 Value ltMinMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, 956 op.operand(), cstMinNormPos); 957 Value infMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, 958 op.operand(), cstPosInf); 959 Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask); 960 961 // Compute an approximate result. 962 Value yApprox = builder.create<x86vector::RsqrtOp>(op.operand()); 963 964 // Do a single step of Newton-Raphson iteration to improve the approximation. 965 // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n). 966 // It is essential to evaluate the inner term like this because forming 967 // y_n^2 may over- or underflow. 968 Value inner = builder.create<arith::MulFOp>(negHalf, yApprox); 969 Value fma = builder.create<math::FmaOp>(yApprox, inner, cstOnePointFive); 970 Value yNewton = builder.create<arith::MulFOp>(yApprox, fma); 971 972 // Select the result of the Newton-Raphson step for positive normal arguments. 973 // For other arguments, choose the output of the intrinsic. This will 974 // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if 975 // x is zero or a positive denormalized float (equivalent to flushing positive 976 // denormalized inputs to zero). 
977 Value res = builder.create<SelectOp>(notNormalFiniteMask, yApprox, yNewton); 978 rewriter.replaceOp(op, res); 979 980 return success(); 981 } 982 983 //----------------------------------------------------------------------------// 984 985 void mlir::populateMathPolynomialApproximationPatterns( 986 RewritePatternSet &patterns, 987 const MathPolynomialApproximationOptions &options) { 988 patterns.add<TanhApproximation, LogApproximation, Log2Approximation, 989 Log1pApproximation, ErfPolynomialApproximation, ExpApproximation, 990 ExpM1Approximation, SinAndCosApproximation<true, math::SinOp>, 991 SinAndCosApproximation<false, math::CosOp>>( 992 patterns.getContext()); 993 if (options.enableAvx2) 994 patterns.add<RsqrtApproximation>(patterns.getContext()); 995 } 996