1 //===- PolynomialApproximation.cpp - Approximate math operations ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements expansion of math operations to fast approximations 10 // that do not rely on any of the library functions. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/Math/IR/Math.h" 16 #include "mlir/Dialect/Math/Transforms/Passes.h" 17 #include "mlir/Dialect/Vector/VectorOps.h" 18 #include "mlir/Dialect/X86Vector/X86VectorDialect.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/ImplicitLocOpBuilder.h" 21 #include "mlir/Transforms/Bufferize.h" 22 #include "mlir/Transforms/DialectConversion.h" 23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 24 #include <climits> 25 26 using namespace mlir; 27 using namespace mlir::vector; 28 29 using TypePredicate = llvm::function_ref<bool(Type)>; 30 31 // Returns vector width if the element type is matching the predicate (scalars 32 // that do match the predicate have width equal to `1`). 33 static Optional<int> vectorWidth(Type type, TypePredicate pred) { 34 // If the type matches the predicate then its width is `1`. 35 if (pred(type)) 36 return 1; 37 38 // Otherwise check if the type is a vector type. 39 auto vectorType = type.dyn_cast<VectorType>(); 40 if (vectorType && pred(vectorType.getElementType())) { 41 assert(vectorType.getRank() == 1 && "only 1d vectors are supported"); 42 return vectorType.getDimSize(0); 43 } 44 45 return llvm::None; 46 } 47 48 // Returns vector width of the type. If the type is a scalar returns `1`. 49 static int vectorWidth(Type type) { 50 auto vectorType = type.dyn_cast<VectorType>(); 51 return vectorType ? vectorType.getDimSize(0) : 1; 52 } 53 54 // Returns vector element type. If the type is a scalar returns the argument. 55 LLVM_ATTRIBUTE_UNUSED static Type elementType(Type type) { 56 auto vectorType = type.dyn_cast<VectorType>(); 57 return vectorType ? vectorType.getElementType() : type; 58 } 59 60 LLVM_ATTRIBUTE_UNUSED static bool isF32(Type type) { return type.isF32(); } 61 62 LLVM_ATTRIBUTE_UNUSED static bool isI32(Type type) { 63 return type.isInteger(32); 64 } 65 66 //----------------------------------------------------------------------------// 67 // Broadcast scalar types and values into vector types and values. 68 //----------------------------------------------------------------------------// 69 70 // Broadcasts scalar type into vector type (iff width is greater then 1). 71 static Type broadcast(Type type, int width) { 72 assert(!type.isa<VectorType>() && "must be scalar type"); 73 return width > 1 ? VectorType::get({width}, type) : type; 74 } 75 76 // Broadcasts scalar value into vector (iff width is greater then 1). 77 static Value broadcast(ImplicitLocOpBuilder &builder, Value value, int width) { 78 assert(!value.getType().isa<VectorType>() && "must be scalar value"); 79 auto type = broadcast(value.getType(), width); 80 return width > 1 ? builder.create<BroadcastOp>(type, value) : value; 81 } 82 83 //----------------------------------------------------------------------------// 84 // Helper functions to create constants. 85 //----------------------------------------------------------------------------// 86 87 static Value f32Cst(ImplicitLocOpBuilder &builder, float value) { 88 return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value)); 89 } 90 91 static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) { 92 return builder.create<arith::ConstantOp>(builder.getI32IntegerAttr(value)); 93 } 94 95 static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) { 96 Value i32Value = i32Cst(builder, static_cast<int32_t>(bits)); 97 return builder.create<arith::BitcastOp>(builder.getF32Type(), i32Value); 98 } 99 100 //----------------------------------------------------------------------------// 101 // Helper functions to build math functions approximations. 102 //----------------------------------------------------------------------------// 103 104 static Value min(ImplicitLocOpBuilder &builder, Value a, Value b) { 105 return builder.create<SelectOp>( 106 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, a, b), a, b); 107 } 108 109 static Value max(ImplicitLocOpBuilder &builder, Value a, Value b) { 110 return builder.create<SelectOp>( 111 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, a, b), a, b); 112 } 113 114 static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, 115 Value upperBound) { 116 return max(builder, min(builder, value, upperBound), lowerBound); 117 } 118 119 // Decomposes given floating point value `arg` into a normalized fraction and 120 // an integral power of two (see std::frexp). Returned values have float type. 121 static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg, 122 bool is_positive = false) { 123 assert(isF32(elementType(arg.getType())) && "argument must be f32 type"); 124 125 int width = vectorWidth(arg.getType()); 126 127 auto bcast = [&](Value value) -> Value { 128 return broadcast(builder, value, width); 129 }; 130 131 auto i32 = builder.getIntegerType(32); 132 auto i32Vec = broadcast(i32, width); 133 auto f32Vec = broadcast(builder.getF32Type(), width); 134 135 Value cst126f = f32Cst(builder, 126.0f); 136 Value cstHalf = f32Cst(builder, 0.5f); 137 Value cstInvMantMask = f32FromBits(builder, ~0x7f800000u); 138 139 // Bitcast to i32 for bitwise operations. 140 Value i32Half = builder.create<arith::BitcastOp>(i32, cstHalf); 141 Value i32InvMantMask = builder.create<arith::BitcastOp>(i32, cstInvMantMask); 142 Value i32Arg = builder.create<arith::BitcastOp>(i32Vec, arg); 143 144 // Compute normalized fraction. 145 Value tmp0 = builder.create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask)); 146 Value tmp1 = builder.create<arith::OrIOp>(tmp0, bcast(i32Half)); 147 Value normalizedFraction = builder.create<arith::BitcastOp>(f32Vec, tmp1); 148 149 // Compute exponent. 150 Value arg0 = is_positive ? arg : builder.create<math::AbsOp>(arg); 151 Value biasedExponentBits = builder.create<arith::ShRUIOp>( 152 builder.create<arith::BitcastOp>(i32Vec, arg0), 153 bcast(i32Cst(builder, 23))); 154 Value biasedExponent = 155 builder.create<arith::SIToFPOp>(f32Vec, biasedExponentBits); 156 Value exponent = 157 builder.create<arith::SubFOp>(biasedExponent, bcast(cst126f)); 158 159 return {normalizedFraction, exponent}; 160 } 161 162 // Computes exp2 for an i32 argument. 163 static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) { 164 assert(isI32(elementType(arg.getType())) && "argument must be i32 type"); 165 166 int width = vectorWidth(arg.getType()); 167 168 auto bcast = [&](Value value) -> Value { 169 return broadcast(builder, value, width); 170 }; 171 172 auto f32Vec = broadcast(builder.getF32Type(), width); 173 // The exponent of f32 located at 23-bit. 174 auto exponetBitLocation = bcast(i32Cst(builder, 23)); 175 // Set the exponent bias to zero. 176 auto bias = bcast(i32Cst(builder, 127)); 177 178 Value biasedArg = builder.create<arith::AddIOp>(arg, bias); 179 Value exp2ValueInt = 180 builder.create<arith::ShLIOp>(biasedArg, exponetBitLocation); 181 Value exp2ValueF32 = builder.create<arith::BitcastOp>(f32Vec, exp2ValueInt); 182 183 return exp2ValueF32; 184 } 185 186 //----------------------------------------------------------------------------// 187 // TanhOp approximation. 188 //----------------------------------------------------------------------------// 189 190 namespace { 191 struct TanhApproximation : public OpRewritePattern<math::TanhOp> { 192 public: 193 using OpRewritePattern::OpRewritePattern; 194 195 LogicalResult matchAndRewrite(math::TanhOp op, 196 PatternRewriter &rewriter) const final; 197 }; 198 } // namespace 199 200 LogicalResult 201 TanhApproximation::matchAndRewrite(math::TanhOp op, 202 PatternRewriter &rewriter) const { 203 auto width = vectorWidth(op.operand().getType(), isF32); 204 if (!width.hasValue()) 205 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 206 207 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 208 auto bcast = [&](Value value) -> Value { 209 return broadcast(builder, value, *width); 210 }; 211 212 // Clamp operand into [plusClamp, minusClamp] range. 213 Value minusClamp = bcast(f32Cst(builder, -7.99881172180175781f)); 214 Value plusClamp = bcast(f32Cst(builder, 7.99881172180175781f)); 215 Value x = clamp(builder, op.operand(), minusClamp, plusClamp); 216 217 // Mask for tiny values that are approximated with `operand`. 218 Value tiny = bcast(f32Cst(builder, 0.0004f)); 219 Value tinyMask = builder.create<arith::CmpFOp>( 220 arith::CmpFPredicate::OLT, builder.create<math::AbsOp>(op.operand()), 221 tiny); 222 223 // The monomial coefficients of the numerator polynomial (odd). 224 Value alpha1 = bcast(f32Cst(builder, 4.89352455891786e-03f)); 225 Value alpha3 = bcast(f32Cst(builder, 6.37261928875436e-04f)); 226 Value alpha5 = bcast(f32Cst(builder, 1.48572235717979e-05f)); 227 Value alpha7 = bcast(f32Cst(builder, 5.12229709037114e-08f)); 228 Value alpha9 = bcast(f32Cst(builder, -8.60467152213735e-11f)); 229 Value alpha11 = bcast(f32Cst(builder, 2.00018790482477e-13f)); 230 Value alpha13 = bcast(f32Cst(builder, -2.76076847742355e-16f)); 231 232 // The monomial coefficients of the denominator polynomial (even). 233 Value beta0 = bcast(f32Cst(builder, 4.89352518554385e-03f)); 234 Value beta2 = bcast(f32Cst(builder, 2.26843463243900e-03f)); 235 Value beta4 = bcast(f32Cst(builder, 1.18534705686654e-04f)); 236 Value beta6 = bcast(f32Cst(builder, 1.19825839466702e-06f)); 237 238 // Since the polynomials are odd/even, we need x^2. 239 Value x2 = builder.create<arith::MulFOp>(x, x); 240 241 // Evaluate the numerator polynomial p. 242 Value p = builder.create<math::FmaOp>(x2, alpha13, alpha11); 243 p = builder.create<math::FmaOp>(x2, p, alpha9); 244 p = builder.create<math::FmaOp>(x2, p, alpha7); 245 p = builder.create<math::FmaOp>(x2, p, alpha5); 246 p = builder.create<math::FmaOp>(x2, p, alpha3); 247 p = builder.create<math::FmaOp>(x2, p, alpha1); 248 p = builder.create<arith::MulFOp>(x, p); 249 250 // Evaluate the denominator polynomial q. 251 Value q = builder.create<math::FmaOp>(x2, beta6, beta4); 252 q = builder.create<math::FmaOp>(x2, q, beta2); 253 q = builder.create<math::FmaOp>(x2, q, beta0); 254 255 // Divide the numerator by the denominator. 256 Value res = builder.create<SelectOp>(tinyMask, x, 257 builder.create<arith::DivFOp>(p, q)); 258 259 rewriter.replaceOp(op, res); 260 261 return success(); 262 } 263 264 #define LN2_VALUE \ 265 0.693147180559945309417232121458176568075500134360255254120680009493393621L 266 #define LOG2E_VALUE \ 267 1.442695040888963407359924681001892137426645954152985934135449406931109219L 268 269 //----------------------------------------------------------------------------// 270 // LogOp and Log2Op approximation. 271 //----------------------------------------------------------------------------// 272 273 namespace { 274 template <typename Op> 275 struct LogApproximationBase : public OpRewritePattern<Op> { 276 using OpRewritePattern<Op>::OpRewritePattern; 277 278 /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise. 279 LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter, 280 bool base2) const; 281 }; 282 } // namespace 283 284 // This approximation comes from Julien Pommier's SSE math library. 285 // Link: http://gruntthepeon.free.fr/ssemath 286 template <typename Op> 287 LogicalResult 288 LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter, 289 bool base2) const { 290 auto width = vectorWidth(op.operand().getType(), isF32); 291 if (!width.hasValue()) 292 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 293 294 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 295 auto bcast = [&](Value value) -> Value { 296 return broadcast(builder, value, *width); 297 }; 298 299 Value cstZero = bcast(f32Cst(builder, 0.0f)); 300 Value cstOne = bcast(f32Cst(builder, 1.0f)); 301 Value cstNegHalf = bcast(f32Cst(builder, -0.5f)); 302 303 // The smallest non denormalized float number. 304 Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u)); 305 Value cstMinusInf = bcast(f32FromBits(builder, 0xff800000u)); 306 Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u)); 307 Value cstNan = bcast(f32FromBits(builder, 0x7fc00000)); 308 309 // Polynomial coefficients. 310 Value cstCephesSQRTHF = bcast(f32Cst(builder, 0.707106781186547524f)); 311 Value cstCephesLogP0 = bcast(f32Cst(builder, 7.0376836292E-2f)); 312 Value cstCephesLogP1 = bcast(f32Cst(builder, -1.1514610310E-1f)); 313 Value cstCephesLogP2 = bcast(f32Cst(builder, 1.1676998740E-1f)); 314 Value cstCephesLogP3 = bcast(f32Cst(builder, -1.2420140846E-1f)); 315 Value cstCephesLogP4 = bcast(f32Cst(builder, +1.4249322787E-1f)); 316 Value cstCephesLogP5 = bcast(f32Cst(builder, -1.6668057665E-1f)); 317 Value cstCephesLogP6 = bcast(f32Cst(builder, +2.0000714765E-1f)); 318 Value cstCephesLogP7 = bcast(f32Cst(builder, -2.4999993993E-1f)); 319 Value cstCephesLogP8 = bcast(f32Cst(builder, +3.3333331174E-1f)); 320 321 Value x = op.operand(); 322 323 // Truncate input values to the minimum positive normal. 324 x = max(builder, x, cstMinNormPos); 325 326 // Extract significant in the range [0.5,1) and exponent. 327 std::pair<Value, Value> pair = frexp(builder, x, /*is_positive=*/true); 328 x = pair.first; 329 Value e = pair.second; 330 331 // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift 332 // by -1.0. The values are then centered around 0, which improves the 333 // stability of the polynomial evaluation: 334 // 335 // if( x < SQRTHF ) { 336 // e -= 1; 337 // x = x + x - 1.0; 338 // } else { x = x - 1.0; } 339 Value mask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, 340 cstCephesSQRTHF); 341 Value tmp = builder.create<SelectOp>(mask, x, cstZero); 342 343 x = builder.create<arith::SubFOp>(x, cstOne); 344 e = builder.create<arith::SubFOp>( 345 e, builder.create<SelectOp>(mask, cstOne, cstZero)); 346 x = builder.create<arith::AddFOp>(x, tmp); 347 348 Value x2 = builder.create<arith::MulFOp>(x, x); 349 Value x3 = builder.create<arith::MulFOp>(x2, x); 350 351 // Evaluate the polynomial approximant of degree 8 in three parts. 352 Value y0, y1, y2; 353 y0 = builder.create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1); 354 y1 = builder.create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4); 355 y2 = builder.create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7); 356 y0 = builder.create<math::FmaOp>(y0, x, cstCephesLogP2); 357 y1 = builder.create<math::FmaOp>(y1, x, cstCephesLogP5); 358 y2 = builder.create<math::FmaOp>(y2, x, cstCephesLogP8); 359 y0 = builder.create<math::FmaOp>(y0, x3, y1); 360 y0 = builder.create<math::FmaOp>(y0, x3, y2); 361 y0 = builder.create<arith::MulFOp>(y0, x3); 362 363 y0 = builder.create<math::FmaOp>(cstNegHalf, x2, y0); 364 x = builder.create<arith::AddFOp>(x, y0); 365 366 if (base2) { 367 Value cstLog2e = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE))); 368 x = builder.create<math::FmaOp>(x, cstLog2e, e); 369 } else { 370 Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE))); 371 x = builder.create<math::FmaOp>(e, cstLn2, x); 372 } 373 374 Value invalidMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, 375 op.operand(), cstZero); 376 Value zeroMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, 377 op.operand(), cstZero); 378 Value posInfMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, 379 op.operand(), cstPosInf); 380 381 // Filter out invalid values: 382 // • x == 0 -> -INF 383 // • x < 0 -> NAN 384 // • x == +INF -> +INF 385 Value aproximation = builder.create<SelectOp>( 386 zeroMask, cstMinusInf, 387 builder.create<SelectOp>( 388 invalidMask, cstNan, 389 builder.create<SelectOp>(posInfMask, cstPosInf, x))); 390 391 rewriter.replaceOp(op, aproximation); 392 393 return success(); 394 } 395 396 namespace { 397 struct LogApproximation : public LogApproximationBase<math::LogOp> { 398 using LogApproximationBase::LogApproximationBase; 399 400 LogicalResult matchAndRewrite(math::LogOp op, 401 PatternRewriter &rewriter) const final { 402 return logMatchAndRewrite(op, rewriter, /*base2=*/false); 403 } 404 }; 405 } // namespace 406 407 namespace { 408 struct Log2Approximation : public LogApproximationBase<math::Log2Op> { 409 using LogApproximationBase::LogApproximationBase; 410 411 LogicalResult matchAndRewrite(math::Log2Op op, 412 PatternRewriter &rewriter) const final { 413 return logMatchAndRewrite(op, rewriter, /*base2=*/true); 414 } 415 }; 416 } // namespace 417 418 //----------------------------------------------------------------------------// 419 // Log1p approximation. 420 //----------------------------------------------------------------------------// 421 422 namespace { 423 struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> { 424 public: 425 using OpRewritePattern::OpRewritePattern; 426 427 LogicalResult matchAndRewrite(math::Log1pOp op, 428 PatternRewriter &rewriter) const final; 429 }; 430 } // namespace 431 432 // Approximate log(1+x). 433 LogicalResult 434 Log1pApproximation::matchAndRewrite(math::Log1pOp op, 435 PatternRewriter &rewriter) const { 436 auto width = vectorWidth(op.operand().getType(), isF32); 437 if (!width.hasValue()) 438 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 439 440 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 441 auto bcast = [&](Value value) -> Value { 442 return broadcast(builder, value, *width); 443 }; 444 445 // Approximate log(1+x) using the following, due to W. Kahan: 446 // u = x + 1.0; 447 // if (u == 1.0 || u == inf) return x; 448 // return x * log(u) / (u - 1.0); 449 // ^^^^^^^^^^^^^^^^^^^^^^ 450 // "logLarge" below. 451 Value cstOne = bcast(f32Cst(builder, 1.0f)); 452 Value x = op.operand(); 453 Value u = builder.create<arith::AddFOp>(x, cstOne); 454 Value uSmall = 455 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne); 456 Value logU = builder.create<math::LogOp>(u); 457 Value uInf = 458 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU); 459 Value logLarge = builder.create<arith::MulFOp>( 460 x, builder.create<arith::DivFOp>( 461 logU, builder.create<arith::SubFOp>(u, cstOne))); 462 Value approximation = builder.create<SelectOp>( 463 builder.create<arith::OrIOp>(uSmall, uInf), x, logLarge); 464 rewriter.replaceOp(op, approximation); 465 return success(); 466 } 467 468 //----------------------------------------------------------------------------// 469 // Exp approximation. 470 //----------------------------------------------------------------------------// 471 472 namespace { 473 474 struct ExpApproximation : public OpRewritePattern<math::ExpOp> { 475 public: 476 using OpRewritePattern::OpRewritePattern; 477 478 LogicalResult matchAndRewrite(math::ExpOp op, 479 PatternRewriter &rewriter) const final; 480 }; 481 } // namespace 482 483 // Approximate exp(x) using its reduced range exp(y) where y is in the range 484 // [0, ln(2)], let y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2), exp(x) 485 // = exp(y) * 2^k. exp(y). 486 LogicalResult 487 ExpApproximation::matchAndRewrite(math::ExpOp op, 488 PatternRewriter &rewriter) const { 489 auto width = vectorWidth(op.operand().getType(), isF32); 490 if (!width.hasValue()) 491 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 492 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 493 494 // TODO: Consider a common pattern rewriter with all methods below to 495 // write the approximations. 496 auto bcast = [&](Value value) -> Value { 497 return broadcast(builder, value, *width); 498 }; 499 auto fmla = [&](Value a, Value b, Value c) { 500 return builder.create<math::FmaOp>(a, b, c); 501 }; 502 auto mul = [&](Value a, Value b) -> Value { 503 return builder.create<arith::MulFOp>(a, b); 504 }; 505 auto sub = [&](Value a, Value b) -> Value { 506 return builder.create<arith::SubFOp>(a, b); 507 }; 508 auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); }; 509 510 Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE))); 511 Value cstLog2E = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE))); 512 513 // Polynomial coefficients. 514 Value cstCephesExpP0 = bcast(f32Cst(builder, 1.0)); 515 Value cstCephesExpP1 = bcast(f32Cst(builder, 1.0)); 516 Value cstCephesExpP2 = bcast(f32Cst(builder, 0.49970514590562437052f)); 517 Value cstCephesExpP3 = bcast(f32Cst(builder, 0.16873890085469545053f)); 518 Value cstCephesExpP4 = bcast(f32Cst(builder, 0.03668965196652099192f)); 519 Value cstCephesExpP5 = bcast(f32Cst(builder, 0.01314350012789660196f)); 520 521 Value x = op.operand(); 522 523 // Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2) 524 Value xL2Inv = mul(x, cstLog2E); 525 Value kF32 = floor(xL2Inv); 526 Value kLn2 = mul(kF32, cstLn2); 527 Value y = sub(x, kLn2); 528 529 // Use Estrin's evaluation scheme with 3 independent parts: 530 // P(y)^y : (c0 + c1 y) + (c2 + c3 y) y^2 + (c4 + c5 y) y^4 531 Value y2 = mul(y, y); 532 Value y4 = mul(y2, y2); 533 534 Value q0 = fmla(cstCephesExpP1, y, cstCephesExpP0); 535 Value q1 = fmla(cstCephesExpP3, y, cstCephesExpP2); 536 Value q2 = fmla(cstCephesExpP5, y, cstCephesExpP4); 537 Value expY = fmla(q1, y2, q0); 538 expY = fmla(q2, y4, expY); 539 540 auto i32Vec = broadcast(builder.getI32Type(), *width); 541 542 // exp2(k) 543 Value k = builder.create<arith::FPToSIOp>(kF32, i32Vec); 544 Value exp2KValue = exp2I32(builder, k); 545 546 // exp(x) = exp(y) * exp2(k) 547 expY = mul(expY, exp2KValue); 548 549 // Handle overflow, inf and underflow of exp(x). exp(x) range is [0, inf], its 550 // partitioned as the following: 551 // exp(x) = 0, x <= -inf 552 // exp(x) = underflow (min_float), x <= -88 553 // exp(x) = inf (min_float), x >= 88 554 // Note: |k| = 127 is the value where the 8-bits exponent saturates. 555 Value zerof32Const = bcast(f32Cst(builder, 0)); 556 auto constPosInfinity = 557 bcast(f32Cst(builder, std::numeric_limits<float>::infinity())); 558 auto constNegIfinity = 559 bcast(f32Cst(builder, -std::numeric_limits<float>::infinity())); 560 auto underflow = bcast(f32Cst(builder, std::numeric_limits<float>::min())); 561 562 Value kMaxConst = bcast(i32Cst(builder, 127)); 563 Value kMaxNegConst = bcast(i32Cst(builder, -127)); 564 Value rightBound = 565 builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, k, kMaxConst); 566 Value leftBound = 567 builder.create<arith::CmpIOp>(arith::CmpIPredicate::sge, k, kMaxNegConst); 568 569 Value isNegInfinityX = builder.create<arith::CmpFOp>( 570 arith::CmpFPredicate::OEQ, x, constNegIfinity); 571 Value isPosInfinityX = builder.create<arith::CmpFOp>( 572 arith::CmpFPredicate::OEQ, x, constPosInfinity); 573 Value isPostiveX = 574 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zerof32Const); 575 Value isComputable = builder.create<arith::AndIOp>(rightBound, leftBound); 576 577 expY = builder.create<SelectOp>( 578 isNegInfinityX, zerof32Const, 579 builder.create<SelectOp>( 580 isPosInfinityX, constPosInfinity, 581 builder.create<SelectOp>(isComputable, expY, 582 builder.create<SelectOp>(isPostiveX, 583 constPosInfinity, 584 underflow)))); 585 586 rewriter.replaceOp(op, expY); 587 588 return success(); 589 } 590 591 //----------------------------------------------------------------------------// 592 // ExpM1 approximation. 593 //----------------------------------------------------------------------------// 594 595 namespace { 596 597 struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> { 598 public: 599 using OpRewritePattern::OpRewritePattern; 600 601 LogicalResult matchAndRewrite(math::ExpM1Op op, 602 PatternRewriter &rewriter) const final; 603 }; 604 } // namespace 605 606 LogicalResult 607 ExpM1Approximation::matchAndRewrite(math::ExpM1Op op, 608 PatternRewriter &rewriter) const { 609 auto width = vectorWidth(op.operand().getType(), isF32); 610 if (!width.hasValue()) 611 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 612 613 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 614 auto bcast = [&](Value value) -> Value { 615 return broadcast(builder, value, *width); 616 }; 617 618 // expm1(x) = exp(x) - 1 = u - 1. 619 // We have to handle it carefully when x is near 0, i.e. u ~= 1, 620 // and when the input is ~= -inf, i.e. u - 1 ~= -1. 621 Value cstOne = bcast(f32Cst(builder, 1.0f)); 622 Value cstNegOne = bcast(f32Cst(builder, -1.0f)); 623 Value x = op.operand(); 624 Value u = builder.create<math::ExpOp>(x); 625 Value uEqOne = 626 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne); 627 Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne); 628 Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>( 629 arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne); 630 // logU = log(u) ~= x 631 Value logU = builder.create<math::LogOp>(u); 632 633 // Detect exp(x) = +inf; written this way to avoid having to form +inf. 634 Value isInf = 635 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u); 636 637 // (u - 1) * (x / ~x) 638 Value expm1 = builder.create<arith::MulFOp>( 639 uMinusOne, builder.create<arith::DivFOp>(x, logU)); 640 expm1 = builder.create<SelectOp>(isInf, u, expm1); 641 Value approximation = builder.create<SelectOp>( 642 uEqOne, x, builder.create<SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1)); 643 rewriter.replaceOp(op, approximation); 644 return success(); 645 } 646 647 //----------------------------------------------------------------------------// 648 // Sin and Cos approximation. 649 //----------------------------------------------------------------------------// 650 651 namespace { 652 653 template <bool isSine, typename OpTy> 654 struct SinAndCosApproximation : public OpRewritePattern<OpTy> { 655 public: 656 using OpRewritePattern<OpTy>::OpRewritePattern; 657 658 LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final; 659 }; 660 } // namespace 661 662 #define TWO_OVER_PI \ 663 0.6366197723675813430755350534900574481378385829618257949906693762L 664 #define PI_OVER_2 \ 665 1.5707963267948966192313216916397514420985846996875529104874722961L 666 667 // Approximates sin(x) or cos(x) by finding the best approximation polynomial in 668 // the reduced range [0, pi/2] for both sin(x) and cos(x). Then given y in the 669 // reduced range sin(x) will be computed as sin(y), -sin(y), cos(y) or -cos(y). 670 template <bool isSine, typename OpTy> 671 LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite( 672 OpTy op, PatternRewriter &rewriter) const { 673 static_assert( 674 llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value, 675 "SinAndCosApproximation pattern expects math::SinOp or math::CosOp"); 676 auto width = vectorWidth(op.operand().getType(), isF32); 677 if (!width.hasValue()) 678 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 679 680 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 681 auto bcast = [&](Value value) -> Value { 682 return broadcast(builder, value, *width); 683 }; 684 auto mul = [&](Value a, Value b) -> Value { 685 return builder.create<arith::MulFOp>(a, b); 686 }; 687 auto sub = [&](Value a, Value b) -> Value { 688 return builder.create<arith::SubFOp>(a, b); 689 }; 690 auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); }; 691 692 auto i32Vec = broadcast(builder.getI32Type(), *width); 693 auto fPToSingedInteger = [&](Value a) -> Value { 694 return builder.create<arith::FPToSIOp>(a, i32Vec); 695 }; 696 697 auto modulo4 = [&](Value a) -> Value { 698 return builder.create<arith::AndIOp>(a, bcast(i32Cst(builder, 3))); 699 }; 700 701 auto isEqualTo = [&](Value a, Value b) -> Value { 702 return builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b); 703 }; 704 705 auto isGreaterThan = [&](Value a, Value b) -> Value { 706 return builder.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b); 707 }; 708 709 auto select = [&](Value cond, Value t, Value f) -> Value { 710 return builder.create<SelectOp>(cond, t, f); 711 }; 712 713 auto fmla = [&](Value a, Value b, Value c) { 714 return builder.create<math::FmaOp>(a, b, c); 715 }; 716 717 auto bitwiseOr = [&](Value a, Value b) { 718 return builder.create<arith::OrIOp>(a, b); 719 }; 720 721 Value twoOverPi = bcast(f32Cst(builder, TWO_OVER_PI)); 722 Value piOverTwo = bcast(f32Cst(builder, PI_OVER_2)); 723 724 Value x = op.operand(); 725 726 Value k = floor(mul(x, twoOverPi)); 727 728 Value y = sub(x, mul(k, piOverTwo)); 729 730 Value cstOne = bcast(f32Cst(builder, 1.0)); 731 Value cstNegativeOne = bcast(f32Cst(builder, -1.0)); 732 733 Value cstSC2 = bcast(f32Cst(builder, -0.16666667163372039794921875f)); 734 Value cstSC4 = bcast(f32Cst(builder, 8.333347737789154052734375e-3f)); 735 Value cstSC6 = bcast(f32Cst(builder, -1.9842604524455964565277099609375e-4f)); 736 Value cstSC8 = 737 bcast(f32Cst(builder, 2.760012648650445044040679931640625e-6f)); 738 Value cstSC10 = 739 bcast(f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f)); 740 741 Value cstCC2 = bcast(f32Cst(builder, -0.5f)); 742 Value cstCC4 = bcast(f32Cst(builder, 4.166664183139801025390625e-2f)); 743 Value cstCC6 = bcast(f32Cst(builder, -1.388833043165504932403564453125e-3f)); 744 Value cstCC8 = bcast(f32Cst(builder, 2.47562347794882953166961669921875e-5f)); 745 Value cstCC10 = 746 bcast(f32Cst(builder, -2.59630184018533327616751194000244140625e-7f)); 747 748 Value kMod4 = modulo4(fPToSingedInteger(k)); 749 750 Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, 0))); 751 Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, 1))); 752 Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, 2))); 753 Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, 3))); 754 755 Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2); 756 Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(i32Cst(builder, 1))) 757 : bitwiseOr(kR1, kR2); 758 759 Value y2 = mul(y, y); 760 761 Value base = select(sinuseCos, cstOne, y); 762 Value cstC2 = select(sinuseCos, cstCC2, cstSC2); 763 Value cstC4 = select(sinuseCos, cstCC4, cstSC4); 764 Value cstC6 = select(sinuseCos, cstCC6, cstSC6); 765 Value cstC8 = select(sinuseCos, cstCC8, cstSC8); 766 Value cstC10 = select(sinuseCos, cstCC10, cstSC10); 767 768 Value v1 = fmla(y2, cstC10, cstC8); 769 Value v2 = fmla(y2, v1, cstC6); 770 Value v3 = fmla(y2, v2, cstC4); 771 Value v4 = fmla(y2, v3, cstC2); 772 Value v5 = fmla(y2, v4, cstOne); 773 Value v6 = mul(base, v5); 774 775 Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6); 776 777 rewriter.replaceOp(op, approximation); 778 779 return success(); 780 } 781 782 //----------------------------------------------------------------------------// 783 // Rsqrt approximation. 784 //----------------------------------------------------------------------------// 785 786 namespace { 787 struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> { 788 using OpRewritePattern::OpRewritePattern; 789 790 LogicalResult matchAndRewrite(math::RsqrtOp op, 791 PatternRewriter &rewriter) const final; 792 }; 793 } // namespace 794 795 LogicalResult 796 RsqrtApproximation::matchAndRewrite(math::RsqrtOp op, 797 PatternRewriter &rewriter) const { 798 auto width = vectorWidth(op.operand().getType(), isF32); 799 // Only support already-vectorized rsqrt's. 800 if (!width.hasValue() || *width != 8) 801 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 802 803 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 804 auto bcast = [&](Value value) -> Value { 805 return broadcast(builder, value, *width); 806 }; 807 808 Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u)); 809 Value cstOnePointFive = bcast(f32Cst(builder, 1.5f)); 810 Value cstNegHalf = bcast(f32Cst(builder, -0.5f)); 811 Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u)); 812 813 Value negHalf = builder.create<arith::MulFOp>(op.operand(), cstNegHalf); 814 815 // Select only the inverse sqrt of positive normals (denormals are 816 // flushed to zero). 817 Value ltMinMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, 818 op.operand(), cstMinNormPos); 819 Value infMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, 820 op.operand(), cstPosInf); 821 Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask); 822 823 // Compute an approximate result. 824 Value yApprox = builder.create<x86vector::RsqrtOp>(op.operand()); 825 826 // Do a single step of Newton-Raphson iteration to improve the approximation. 827 // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n). 828 // It is essential to evaluate the inner term like this because forming 829 // y_n^2 may over- or underflow. 830 Value inner = builder.create<arith::MulFOp>(negHalf, yApprox); 831 Value fma = builder.create<math::FmaOp>(yApprox, inner, cstOnePointFive); 832 Value yNewton = builder.create<arith::MulFOp>(yApprox, fma); 833 834 // Select the result of the Newton-Raphson step for positive normal arguments. 835 // For other arguments, choose the output of the intrinsic. This will 836 // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if 837 // x is zero or a positive denormalized float (equivalent to flushing positive 838 // denormalized inputs to zero). 839 Value res = builder.create<SelectOp>(notNormalFiniteMask, yApprox, yNewton); 840 rewriter.replaceOp(op, res); 841 842 return success(); 843 } 844 845 //----------------------------------------------------------------------------// 846 847 void mlir::populateMathPolynomialApproximationPatterns( 848 RewritePatternSet &patterns, 849 const MathPolynomialApproximationOptions &options) { 850 patterns.add<TanhApproximation, LogApproximation, Log2Approximation, 851 Log1pApproximation, ExpApproximation, ExpM1Approximation, 852 SinAndCosApproximation<true, math::SinOp>, 853 SinAndCosApproximation<false, math::CosOp>>( 854 patterns.getContext()); 855 if (options.enableAvx2) 856 patterns.add<RsqrtApproximation>(patterns.getContext()); 857 } 858