1 //===- PolynomialApproximation.cpp - Approximate math operations ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements expansion of math operations to fast approximations 10 // that do not rely on any of the library functions. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/Math/IR/Math.h" 16 #include "mlir/Dialect/Math/Transforms/Approximation.h" 17 #include "mlir/Dialect/Math/Transforms/Passes.h" 18 #include "mlir/Dialect/Vector/VectorOps.h" 19 #include "mlir/Dialect/X86Vector/X86VectorDialect.h" 20 #include "mlir/IR/Builders.h" 21 #include "mlir/IR/ImplicitLocOpBuilder.h" 22 #include "mlir/Transforms/Bufferize.h" 23 #include "mlir/Transforms/DialectConversion.h" 24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 25 #include "llvm/ADT/ArrayRef.h" 26 #include <climits> 27 #include <cstddef> 28 29 using namespace mlir; 30 using namespace mlir::math; 31 using namespace mlir::vector; 32 33 using TypePredicate = llvm::function_ref<bool(Type)>; 34 35 // Returns vector shape if the element type is matching the predicate (scalars 36 // that do match the predicate have shape equal to `{1}`). 37 static Optional<SmallVector<int64_t, 2>> vectorShape(Type type, 38 TypePredicate pred) { 39 // If the type matches the predicate then its shape is `{1}`. 40 if (pred(type)) 41 return SmallVector<int64_t, 2>{1}; 42 43 // Otherwise check if the type is a vector type. 
44 auto vectorType = type.dyn_cast<VectorType>(); 45 if (vectorType && pred(vectorType.getElementType())) { 46 return llvm::to_vector<2>(vectorType.getShape()); 47 } 48 49 return llvm::None; 50 } 51 52 // Returns vector shape of the type. If the type is a scalar returns `1`. 53 static SmallVector<int64_t, 2> vectorShape(Type type) { 54 auto vectorType = type.dyn_cast<VectorType>(); 55 return vectorType ? llvm::to_vector<2>(vectorType.getShape()) 56 : SmallVector<int64_t, 2>{1}; 57 } 58 59 // Returns vector element type. If the type is a scalar returns the argument. 60 LLVM_ATTRIBUTE_UNUSED static Type elementType(Type type) { 61 auto vectorType = type.dyn_cast<VectorType>(); 62 return vectorType ? vectorType.getElementType() : type; 63 } 64 65 LLVM_ATTRIBUTE_UNUSED static bool isF32(Type type) { return type.isF32(); } 66 67 LLVM_ATTRIBUTE_UNUSED static bool isI32(Type type) { 68 return type.isInteger(32); 69 } 70 71 //----------------------------------------------------------------------------// 72 // Broadcast scalar types and values into vector types and values. 73 //----------------------------------------------------------------------------// 74 75 // Returns true if shape != {1}. 76 static bool isNonScalarShape(ArrayRef<int64_t> shape) { 77 return shape.size() > 1 || shape[0] > 1; 78 } 79 80 // Broadcasts scalar type into vector type (iff shape is non-scalar). 81 static Type broadcast(Type type, ArrayRef<int64_t> shape) { 82 assert(!type.isa<VectorType>() && "must be scalar type"); 83 return isNonScalarShape(shape) ? VectorType::get(shape, type) : type; 84 } 85 86 // Broadcasts scalar value into vector (iff shape is non-scalar). 87 static Value broadcast(ImplicitLocOpBuilder &builder, Value value, 88 ArrayRef<int64_t> shape) { 89 assert(!value.getType().isa<VectorType>() && "must be scalar value"); 90 auto type = broadcast(value.getType(), shape); 91 return isNonScalarShape(shape) ? 
builder.create<BroadcastOp>(type, value) 92 : value; 93 } 94 95 //----------------------------------------------------------------------------// 96 // Helper functions to create constants. 97 //----------------------------------------------------------------------------// 98 99 static Value f32Cst(ImplicitLocOpBuilder &builder, float value) { 100 return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value)); 101 } 102 103 static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) { 104 return builder.create<arith::ConstantOp>(builder.getI32IntegerAttr(value)); 105 } 106 107 static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) { 108 Value i32Value = i32Cst(builder, static_cast<int32_t>(bits)); 109 return builder.create<arith::BitcastOp>(builder.getF32Type(), i32Value); 110 } 111 112 //----------------------------------------------------------------------------// 113 // Helper functions to build math functions approximations. 114 //----------------------------------------------------------------------------// 115 116 static Value min(ImplicitLocOpBuilder &builder, Value a, Value b) { 117 return builder.create<SelectOp>( 118 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, a, b), a, b); 119 } 120 121 static Value max(ImplicitLocOpBuilder &builder, Value a, Value b) { 122 return builder.create<SelectOp>( 123 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, a, b), a, b); 124 } 125 126 static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, 127 Value upperBound) { 128 return max(builder, min(builder, value, upperBound), lowerBound); 129 } 130 131 // Decomposes given floating point value `arg` into a normalized fraction and 132 // an integral power of two (see std::frexp). Returned values have float type. 
133 static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg, 134 bool is_positive = false) { 135 assert(isF32(elementType(arg.getType())) && "argument must be f32 type"); 136 137 auto shape = vectorShape(arg.getType()); 138 139 auto bcast = [&](Value value) -> Value { 140 return broadcast(builder, value, shape); 141 }; 142 143 auto i32 = builder.getIntegerType(32); 144 auto i32Vec = broadcast(i32, shape); 145 auto f32Vec = broadcast(builder.getF32Type(), shape); 146 147 Value cst126f = f32Cst(builder, 126.0f); 148 Value cstHalf = f32Cst(builder, 0.5f); 149 Value cstInvMantMask = f32FromBits(builder, ~0x7f800000u); 150 151 // Bitcast to i32 for bitwise operations. 152 Value i32Half = builder.create<arith::BitcastOp>(i32, cstHalf); 153 Value i32InvMantMask = builder.create<arith::BitcastOp>(i32, cstInvMantMask); 154 Value i32Arg = builder.create<arith::BitcastOp>(i32Vec, arg); 155 156 // Compute normalized fraction. 157 Value tmp0 = builder.create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask)); 158 Value tmp1 = builder.create<arith::OrIOp>(tmp0, bcast(i32Half)); 159 Value normalizedFraction = builder.create<arith::BitcastOp>(f32Vec, tmp1); 160 161 // Compute exponent. 162 Value arg0 = is_positive ? arg : builder.create<math::AbsOp>(arg); 163 Value biasedExponentBits = builder.create<arith::ShRUIOp>( 164 builder.create<arith::BitcastOp>(i32Vec, arg0), 165 bcast(i32Cst(builder, 23))); 166 Value biasedExponent = 167 builder.create<arith::SIToFPOp>(f32Vec, biasedExponentBits); 168 Value exponent = 169 builder.create<arith::SubFOp>(biasedExponent, bcast(cst126f)); 170 171 return {normalizedFraction, exponent}; 172 } 173 174 // Computes exp2 for an i32 argument. 
175 static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) { 176 assert(isI32(elementType(arg.getType())) && "argument must be i32 type"); 177 178 auto shape = vectorShape(arg.getType()); 179 180 auto bcast = [&](Value value) -> Value { 181 return broadcast(builder, value, shape); 182 }; 183 184 auto f32Vec = broadcast(builder.getF32Type(), shape); 185 // The exponent of f32 located at 23-bit. 186 auto exponetBitLocation = bcast(i32Cst(builder, 23)); 187 // Set the exponent bias to zero. 188 auto bias = bcast(i32Cst(builder, 127)); 189 190 Value biasedArg = builder.create<arith::AddIOp>(arg, bias); 191 Value exp2ValueInt = 192 builder.create<arith::ShLIOp>(biasedArg, exponetBitLocation); 193 Value exp2ValueF32 = builder.create<arith::BitcastOp>(f32Vec, exp2ValueInt); 194 195 return exp2ValueF32; 196 } 197 198 namespace { 199 Value makePolynomialCalculation(ImplicitLocOpBuilder &builder, 200 llvm::ArrayRef<Value> coeffs, Value x) { 201 auto shape = vectorShape(x.getType(), isF32); 202 if (coeffs.size() == 0) { 203 return broadcast(builder, f32Cst(builder, 0.0f), *shape); 204 } else if (coeffs.size() == 1) { 205 return coeffs[0]; 206 } 207 Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1], 208 coeffs[coeffs.size() - 2]); 209 for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) { 210 res = builder.create<math::FmaOp>(x, res, coeffs[i]); 211 } 212 return res; 213 } 214 } // namespace 215 216 //----------------------------------------------------------------------------// 217 // TanhOp approximation. 
//----------------------------------------------------------------------------//

namespace {
// Rewrites `math.tanh` on f32 scalars/vectors into a rational polynomial
// approximation built from arith/math ops.
struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::TanhOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Approximates tanh(x) as x * P(x^2) / Q(x^2) on the clamped operand, where P
// and Q are the polynomials given by the alpha/beta coefficients below. Fails
// to match when the operand is not an f32 scalar or a vector of f32.
LogicalResult
TanhApproximation::matchAndRewrite(math::TanhOp op,
                                   PatternRewriter &rewriter) const {
  auto shape = vectorShape(op.operand().getType(), isF32);
  if (!shape.hasValue())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, *shape);
  };

  // Clamp operand into [minusClamp, plusClamp] range.
  Value minusClamp = bcast(f32Cst(builder, -7.99881172180175781f));
  Value plusClamp = bcast(f32Cst(builder, 7.99881172180175781f));
  Value x = clamp(builder, op.operand(), minusClamp, plusClamp);

  // Mask for tiny values that are approximated with `operand`.
  Value tiny = bcast(f32Cst(builder, 0.0004f));
  Value tinyMask = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::OLT, builder.create<math::AbsOp>(op.operand()),
      tiny);

  // The monomial coefficients of the numerator polynomial (odd).
  Value alpha1 = bcast(f32Cst(builder, 4.89352455891786e-03f));
  Value alpha3 = bcast(f32Cst(builder, 6.37261928875436e-04f));
  Value alpha5 = bcast(f32Cst(builder, 1.48572235717979e-05f));
  Value alpha7 = bcast(f32Cst(builder, 5.12229709037114e-08f));
  Value alpha9 = bcast(f32Cst(builder, -8.60467152213735e-11f));
  Value alpha11 = bcast(f32Cst(builder, 2.00018790482477e-13f));
  Value alpha13 = bcast(f32Cst(builder, -2.76076847742355e-16f));

  // The monomial coefficients of the denominator polynomial (even).
  Value beta0 = bcast(f32Cst(builder, 4.89352518554385e-03f));
  Value beta2 = bcast(f32Cst(builder, 2.26843463243900e-03f));
  Value beta4 = bcast(f32Cst(builder, 1.18534705686654e-04f));
  Value beta6 = bcast(f32Cst(builder, 1.19825839466702e-06f));

  // Since the polynomials are odd/even, we need x^2.
  Value x2 = builder.create<arith::MulFOp>(x, x);

  // Evaluate the numerator polynomial p (Horner's scheme in x^2, then one
  // final multiplication by x to make it odd).
  Value p = builder.create<math::FmaOp>(x2, alpha13, alpha11);
  p = builder.create<math::FmaOp>(x2, p, alpha9);
  p = builder.create<math::FmaOp>(x2, p, alpha7);
  p = builder.create<math::FmaOp>(x2, p, alpha5);
  p = builder.create<math::FmaOp>(x2, p, alpha3);
  p = builder.create<math::FmaOp>(x2, p, alpha1);
  p = builder.create<arith::MulFOp>(x, p);

  // Evaluate the denominator polynomial q.
  Value q = builder.create<math::FmaOp>(x2, beta6, beta4);
  q = builder.create<math::FmaOp>(x2, q, beta2);
  q = builder.create<math::FmaOp>(x2, q, beta0);

  // Divide the numerator by the denominator; lanes flagged by `tinyMask` use
  // the (clamped) operand itself instead.
  Value res = builder.create<SelectOp>(tinyMask, x,
                                       builder.create<arith::DivFOp>(p, q));

  rewriter.replaceOp(op, res);

  return success();
}

// ln(2) and log2(e) as long double literals; they are narrowed to float at
// the points of use.
#define LN2_VALUE                                                              \
  0.693147180559945309417232121458176568075500134360255254120680009493393621L
#define LOG2E_VALUE                                                            \
  1.442695040888963407359924681001892137426645954152985934135449406931109219L

//----------------------------------------------------------------------------//
// LogOp and Log2Op approximation.
//----------------------------------------------------------------------------//

namespace {
// Shared implementation for the natural and base-2 logarithm patterns.
template <typename Op>
struct LogApproximationBase : public OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;

  /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise.
  LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter,
                                   bool base2) const;
};
} // namespace

// This approximation comes from Julien Pommier's SSE math library.
// Link: http://gruntthepeon.free.fr/ssemath
//
// The operand is split with `frexp` into a fraction and an exponent, the
// logarithm of the fraction is evaluated with a polynomial, and the exponent
// is folded back in (scaled by ln(2) for the natural logarithm). Zero,
// negative/NaN and +inf operands are patched up with selects at the end.
// Fails to match when the operand is not an f32 scalar or a vector of f32.
template <typename Op>
LogicalResult
LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
                                             bool base2) const {
  auto shape = vectorShape(op.operand().getType(), isF32);
  if (!shape.hasValue())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, *shape);
  };

  Value cstZero = bcast(f32Cst(builder, 0.0f));
  Value cstOne = bcast(f32Cst(builder, 1.0f));
  Value cstNegHalf = bcast(f32Cst(builder, -0.5f));

  // The smallest non denormalized float number, followed by the bit patterns
  // used for -inf, +inf and NaN.
  Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
  Value cstMinusInf = bcast(f32FromBits(builder, 0xff800000u));
  Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
  Value cstNan = bcast(f32FromBits(builder, 0x7fc00000));

  // Polynomial coefficients.
  Value cstCephesSQRTHF = bcast(f32Cst(builder, 0.707106781186547524f));
  Value cstCephesLogP0 = bcast(f32Cst(builder, 7.0376836292E-2f));
  Value cstCephesLogP1 = bcast(f32Cst(builder, -1.1514610310E-1f));
  Value cstCephesLogP2 = bcast(f32Cst(builder, 1.1676998740E-1f));
  Value cstCephesLogP3 = bcast(f32Cst(builder, -1.2420140846E-1f));
  Value cstCephesLogP4 = bcast(f32Cst(builder, +1.4249322787E-1f));
  Value cstCephesLogP5 = bcast(f32Cst(builder, -1.6668057665E-1f));
  Value cstCephesLogP6 = bcast(f32Cst(builder, +2.0000714765E-1f));
  Value cstCephesLogP7 = bcast(f32Cst(builder, -2.4999993993E-1f));
  Value cstCephesLogP8 = bcast(f32Cst(builder, +3.3333331174E-1f));

  Value x = op.operand();

  // Clamp input values from below to the minimum positive normal.
  x = max(builder, x, cstMinNormPos);

  // Extract significand in the range [0.5,1) and exponent. The clamp above
  // makes the value positive, hence `is_positive`.
  std::pair<Value, Value> pair = frexp(builder, x, /*is_positive=*/true);
  x = pair.first;
  Value e = pair.second;

  // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift
  // by -1.0. The values are then centered around 0, which improves the
  // stability of the polynomial evaluation:
  //
  //   if( x < SQRTHF ) {
  //     e -= 1;
  //     x = x + x - 1.0;
  //   } else { x = x - 1.0; }
  Value mask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x,
                                             cstCephesSQRTHF);
  // `tmp` is x where the mask holds and 0 elsewhere, so the add below only
  // doubles x in the masked lanes.
  Value tmp = builder.create<SelectOp>(mask, x, cstZero);

  x = builder.create<arith::SubFOp>(x, cstOne);
  e = builder.create<arith::SubFOp>(
      e, builder.create<SelectOp>(mask, cstOne, cstZero));
  x = builder.create<arith::AddFOp>(x, tmp);

  Value x2 = builder.create<arith::MulFOp>(x, x);
  Value x3 = builder.create<arith::MulFOp>(x2, x);

  // Evaluate the polynomial approximant of degree 8 in three parts.
  Value y0, y1, y2;
  y0 = builder.create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1);
  y1 = builder.create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4);
  y2 = builder.create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7);
  y0 = builder.create<math::FmaOp>(y0, x, cstCephesLogP2);
  y1 = builder.create<math::FmaOp>(y1, x, cstCephesLogP5);
  y2 = builder.create<math::FmaOp>(y2, x, cstCephesLogP8);
  y0 = builder.create<math::FmaOp>(y0, x3, y1);
  y0 = builder.create<math::FmaOp>(y0, x3, y2);
  y0 = builder.create<arith::MulFOp>(y0, x3);

  // Add the -0.5 * x^2 term and the linear term x.
  y0 = builder.create<math::FmaOp>(cstNegHalf, x2, y0);
  x = builder.create<arith::AddFOp>(x, y0);

  if (base2) {
    // log2(v) = ln(fraction) * log2(e) + exponent.
    Value cstLog2e = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
    x = builder.create<math::FmaOp>(x, cstLog2e, e);
  } else {
    // ln(v) = exponent * ln(2) + ln(fraction).
    Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
    x = builder.create<math::FmaOp>(e, cstLn2, x);
  }

  // ULT is an unordered predicate, so NaN operands are flagged as invalid too.
  Value invalidMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
                                                    op.operand(), cstZero);
  Value zeroMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
                                                 op.operand(), cstZero);
  Value posInfMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
                                                   op.operand(), cstPosInf);

  // Filter out invalid values:
  //  • x == 0     -> -INF
  //  • x < 0      ->  NAN
  //  • x == +INF  -> +INF
  Value aproximation = builder.create<SelectOp>(
      zeroMask, cstMinusInf,
      builder.create<SelectOp>(
          invalidMask, cstNan,
          builder.create<SelectOp>(posInfMask, cstPosInf, x)));

  rewriter.replaceOp(op, aproximation);

  return success();
}

namespace {
// Natural logarithm pattern: forwards to the shared implementation.
struct LogApproximation : public LogApproximationBase<math::LogOp> {
  using LogApproximationBase::LogApproximationBase;

  LogicalResult matchAndRewrite(math::LogOp op,
                                PatternRewriter &rewriter) const final {
    return logMatchAndRewrite(op, rewriter, /*base2=*/false);
  }
};
} // namespace

namespace {
// Base-2 logarithm pattern: forwards to the shared implementation.
struct Log2Approximation : public LogApproximationBase<math::Log2Op> {
  using LogApproximationBase::LogApproximationBase;

  LogicalResult matchAndRewrite(math::Log2Op op,
                                PatternRewriter &rewriter) const final {
    return logMatchAndRewrite(op, rewriter, /*base2=*/true);
  }
};
} // namespace

//----------------------------------------------------------------------------//
// Log1p approximation.
//----------------------------------------------------------------------------//

namespace {
// Rewrites `math.log1p` on f32 scalars/vectors in terms of `math.log`.
struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::Log1pOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Approximate log(1+x).
// Fails to match when the operand is not an f32 scalar or a vector of f32.
LogicalResult
Log1pApproximation::matchAndRewrite(math::Log1pOp op,
                                    PatternRewriter &rewriter) const {
  auto shape = vectorShape(op.operand().getType(), isF32);
  if (!shape.hasValue())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, *shape);
  };

  // Approximate log(1+x) using the following, due to W. Kahan:
  //   u = x + 1.0;
  //   if (u == 1.0 || u == inf) return x;
  //   return x * log(u) / (u - 1.0);
  //          ^^^^^^^^^^^^^^^^^^^^^^
  //             "logLarge" below.
  Value cstOne = bcast(f32Cst(builder, 1.0f));
  Value x = op.operand();
  Value u = builder.create<arith::AddFOp>(x, cstOne);
  Value uSmall =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
  Value logU = builder.create<math::LogOp>(u);
  // The `u == inf` test is expressed as `u == log(u)`, which avoids forming
  // an infinity constant (presumably the two are only equal for u == +inf).
  Value uInf =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU);
  Value logLarge = builder.create<arith::MulFOp>(
      x, builder.create<arith::DivFOp>(
             logU, builder.create<arith::SubFOp>(u, cstOne)));
  Value approximation = builder.create<SelectOp>(
      builder.create<arith::OrIOp>(uSmall, uInf), x, logLarge);
  rewriter.replaceOp(op, approximation);
  return success();
}

//----------------------------------------------------------------------------//
// Erf approximation.
//----------------------------------------------------------------------------//

// Approximates erf(x) with
// a + P(x)/Q(x)
// where P and Q are polynomials of degree 4 and `a` is a per-interval offset.
// Different coefficients are chosen based on the value of x.
// The approximation error is ~2.5e-07.
// Boost's minimax tool that utilizes the Remez method was used to find the
// coefficients.
//
// Implementation notes: the operand is first reflected to its absolute value,
// then the coefficient set and offset for one of three intervals (bounded by
// 0.8, 2.0 and 3.75) are picked per lane with selects. Lanes that are not
// below the last bound produce 1. The sign is restored at the end.
// Fails to match when the operand is not an f32 scalar or a vector of f32.
LogicalResult
ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
                                            PatternRewriter &rewriter) const {
  auto shape = vectorShape(op.operand().getType(), isF32);
  if (!shape.hasValue())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, *shape);
  };

  const int intervalsCount = 3;
  const int polyDegree = 4;

  Value zero = bcast(f32Cst(builder, 0));
  Value one = bcast(f32Cst(builder, 1));
  // Numerator coefficients, indexed as pp[interval][power of x].
  Value pp[intervalsCount][polyDegree + 1];
  pp[0][0] = bcast(f32Cst(builder, +0.00000000000000000e+00f));
  pp[0][1] = bcast(f32Cst(builder, +1.12837916222975858e+00f));
  pp[0][2] = bcast(f32Cst(builder, -5.23018562988006470e-01f));
  pp[0][3] = bcast(f32Cst(builder, +2.09741709609267072e-01f));
  pp[0][4] = bcast(f32Cst(builder, +2.58146801602987875e-02f));
  pp[1][0] = bcast(f32Cst(builder, +0.00000000000000000e+00f));
  pp[1][1] = bcast(f32Cst(builder, +1.12750687816789140e+00f));
  pp[1][2] = bcast(f32Cst(builder, -3.64721408487825775e-01f));
  pp[1][3] = bcast(f32Cst(builder, +1.18407396425136952e-01f));
  pp[1][4] = bcast(f32Cst(builder, +3.70645533056476558e-02f));
  pp[2][0] = bcast(f32Cst(builder, -3.30093071049483172e-03f));
  pp[2][1] = bcast(f32Cst(builder, +3.51961938357697011e-03f));
  pp[2][2] = bcast(f32Cst(builder, -1.41373622814988039e-03f));
  pp[2][3] = bcast(f32Cst(builder, +2.53447094961941348e-04f));
  pp[2][4] = bcast(f32Cst(builder, -1.71048029455037401e-05f));

  // Denominator coefficients, indexed as qq[interval][power of x].
  Value qq[intervalsCount][polyDegree + 1];
  qq[0][0] = bcast(f32Cst(builder, +1.000000000000000000e+00f));
  qq[0][1] = bcast(f32Cst(builder, -4.635138185962547255e-01f));
  qq[0][2] = bcast(f32Cst(builder, +5.192301327279782447e-01f));
  qq[0][3] = bcast(f32Cst(builder, -1.318089722204810087e-01f));
  qq[0][4] = bcast(f32Cst(builder, +7.397964654672315005e-02f));
  qq[1][0] = bcast(f32Cst(builder, +1.00000000000000000e+00f));
  qq[1][1] = bcast(f32Cst(builder, -3.27607011824493086e-01f));
  qq[1][2] = bcast(f32Cst(builder, +4.48369090658821977e-01f));
  qq[1][3] = bcast(f32Cst(builder, -8.83462621207857930e-02f));
  qq[1][4] = bcast(f32Cst(builder, +5.72442770283176093e-02f));
  qq[2][0] = bcast(f32Cst(builder, +1.00000000000000000e+00f));
  qq[2][1] = bcast(f32Cst(builder, -2.06069165953913769e+00f));
  qq[2][2] = bcast(f32Cst(builder, +1.62705939945477759e+00f));
  qq[2][3] = bcast(f32Cst(builder, -5.83389859211130017e-01f));
  qq[2][4] = bcast(f32Cst(builder, +8.21908939856640930e-02f));

  // Per-interval constant added to P(x)/Q(x).
  Value offsets[intervalsCount];
  offsets[0] = bcast(f32Cst(builder, 0.0f));
  offsets[1] = bcast(f32Cst(builder, 0.0f));
  offsets[2] = bcast(f32Cst(builder, 1.0f));

  // Exclusive upper bound of each interval (in terms of |x|).
  Value bounds[intervalsCount];
  bounds[0] = bcast(f32Cst(builder, 0.8f));
  bounds[1] = bcast(f32Cst(builder, 2.0f));
  bounds[2] = bcast(f32Cst(builder, 3.75f));

  // Work on |x|; the sign is re-applied at the end.
  Value isNegativeArg = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT,
                                                      op.operand(), zero);
  Value negArg = builder.create<arith::NegFOp>(op.operand());
  Value x = builder.create<SelectOp>(isNegativeArg, negArg, op.operand());

  // Start from the coefficients of the first interval.
  Value offset = offsets[0];
  Value p[polyDegree + 1];
  Value q[polyDegree + 1];
  for (int i = 0; i <= polyDegree; ++i) {
    p[i] = pp[0][i];
    q[i] = qq[0][i];
  }

  // TODO: maybe use vector stacking to reduce the number of selects.
  // For each bound, keep the current choice where x is below the bound and
  // switch to the next interval's coefficients elsewhere.
  Value isLessThanBound[intervalsCount];
  for (int j = 0; j < intervalsCount - 1; ++j) {
    isLessThanBound[j] =
        builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[j]);
    for (int i = 0; i <= polyDegree; ++i) {
      p[i] = builder.create<SelectOp>(isLessThanBound[j], p[i], pp[j + 1][i]);
      q[i] = builder.create<SelectOp>(isLessThanBound[j], q[i], qq[j + 1][i]);
    }
    offset =
        builder.create<SelectOp>(isLessThanBound[j], offset, offsets[j + 1]);
  }
  // The last comparison uses the unordered predicate ULT, so a NaN input
  // selects the rational formula below rather than the constant 1.
  isLessThanBound[intervalsCount - 1] = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);

  Value pPoly = makePolynomialCalculation(builder, p, x);
  Value qPoly = makePolynomialCalculation(builder, q, x);
  Value rationalPoly = builder.create<arith::DivFOp>(pPoly, qPoly);
  Value formula = builder.create<arith::AddFOp>(offset, rationalPoly);
  formula = builder.create<SelectOp>(isLessThanBound[intervalsCount - 1],
                                     formula, one);

  // erf is odd function: erf(x) = -erf(-x).
  Value negFormula = builder.create<arith::NegFOp>(formula);
  Value res = builder.create<SelectOp>(isNegativeArg, negFormula, formula);

  rewriter.replaceOp(op, res);

  return success();
}

//----------------------------------------------------------------------------//
// Exp approximation.
//----------------------------------------------------------------------------//

namespace {

// Rewrites `math.exp` on f32 scalars/vectors into a range reduction followed
// by a polynomial evaluation.
struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::ExpOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Approximate exp(x) using its reduced range exp(y) where y is in the range
// [0, ln(2)], let y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2), so that
// exp(x) = exp(y) * 2^k. exp(y) is evaluated with a polynomial in y.
//
// Fails to match when the operand is not an f32 scalar or a vector of f32.
LogicalResult
ExpApproximation::matchAndRewrite(math::ExpOp op,
                                  PatternRewriter &rewriter) const {
  auto shape = vectorShape(op.operand().getType(), isF32);
  if (!shape.hasValue())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");
  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);

  // TODO: Consider a common pattern rewriter with all methods below to
  // write the approximations.
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, *shape);
  };
  auto fmla = [&](Value a, Value b, Value c) {
    return builder.create<math::FmaOp>(a, b, c);
  };
  auto mul = [&](Value a, Value b) -> Value {
    return builder.create<arith::MulFOp>(a, b);
  };
  auto sub = [&](Value a, Value b) -> Value {
    return builder.create<arith::SubFOp>(a, b);
  };
  auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };

  Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
  Value cstLog2E = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));

  // Polynomial coefficients.
  Value cstCephesExpP0 = bcast(f32Cst(builder, 1.0));
  Value cstCephesExpP1 = bcast(f32Cst(builder, 1.0));
  Value cstCephesExpP2 = bcast(f32Cst(builder, 0.49970514590562437052f));
  Value cstCephesExpP3 = bcast(f32Cst(builder, 0.16873890085469545053f));
  Value cstCephesExpP4 = bcast(f32Cst(builder, 0.03668965196652099192f));
  Value cstCephesExpP5 = bcast(f32Cst(builder, 0.01314350012789660196f));

  Value x = op.operand();

  // Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2)
  // (the division by ln(2) is done as a multiplication by log2(e)).
  Value xL2Inv = mul(x, cstLog2E);
  Value kF32 = floor(xL2Inv);
  Value kLn2 = mul(kF32, cstLn2);
  Value y = sub(x, kLn2);

  // Use Estrin's evaluation scheme with 3 independent parts:
  // P(y) = (c0 + c1 y) + (c2 + c3 y) y^2 + (c4 + c5 y) y^4
  Value y2 = mul(y, y);
  Value y4 = mul(y2, y2);

  Value q0 = fmla(cstCephesExpP1, y, cstCephesExpP0);
  Value q1 = fmla(cstCephesExpP3, y, cstCephesExpP2);
  Value q2 = fmla(cstCephesExpP5, y, cstCephesExpP4);
  Value expY = fmla(q1, y2, q0);
  expY = fmla(q2, y4, expY);

  auto i32Vec = broadcast(builder.getI32Type(), *shape);

  // exp2(k)
  Value k = builder.create<arith::FPToSIOp>(kF32, i32Vec);
  Value exp2KValue = exp2I32(builder, k);

  // exp(x) = exp(y) * exp2(k)
  expY = mul(expY, exp2KValue);

  // Handle overflow, inf and underflow of exp(x). exp(x) range is [0, inf], it
  // is partitioned as the following:
  // exp(x) = 0, x == -inf
  // exp(x) = underflow (min_float), k < -127 (roughly x <= -88)
  // exp(x) = inf, k > 127 (roughly x >= 88) or x == +inf
  // Note: |k| = 127 is the value where the 8-bits exponent saturates.
  // NOTE(review): for k == -127 the biased exponent formed by exp2I32 is 0,
  // i.e. the bit pattern of 0.0f, so the product above is 0 while smaller k
  // select `underflow` - confirm this boundary is intended.
  Value zerof32Const = bcast(f32Cst(builder, 0));
  auto constPosInfinity =
      bcast(f32Cst(builder, std::numeric_limits<float>::infinity()));
  auto constNegIfinity =
      bcast(f32Cst(builder, -std::numeric_limits<float>::infinity()));
  auto underflow = bcast(f32Cst(builder, std::numeric_limits<float>::min()));

  Value kMaxConst = bcast(i32Cst(builder, 127));
  Value kMaxNegConst = bcast(i32Cst(builder, -127));
  Value rightBound =
      builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, k, kMaxConst);
  Value leftBound =
      builder.create<arith::CmpIOp>(arith::CmpIPredicate::sge, k, kMaxNegConst);

  Value isNegInfinityX = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::OEQ, x, constNegIfinity);
  Value isPosInfinityX = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::OEQ, x, constPosInfinity);
  Value isPostiveX =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zerof32Const);
  // The polynomial result is only used when -127 <= k <= 127.
  Value isComputable = builder.create<arith::AndIOp>(rightBound, leftBound);

  // Priority order: -inf input, +inf input, in-range k, then saturate to +inf
  // for positive x and to `underflow` otherwise.
  expY = builder.create<SelectOp>(
      isNegInfinityX, zerof32Const,
      builder.create<SelectOp>(
          isPosInfinityX, constPosInfinity,
          builder.create<SelectOp>(isComputable, expY,
                                   builder.create<SelectOp>(isPostiveX,
                                                            constPosInfinity,
                                                            underflow))));

  rewriter.replaceOp(op, expY);

  return success();
}

//----------------------------------------------------------------------------//
// ExpM1 approximation.
739 //----------------------------------------------------------------------------// 740 741 namespace { 742 743 struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> { 744 public: 745 using OpRewritePattern::OpRewritePattern; 746 747 LogicalResult matchAndRewrite(math::ExpM1Op op, 748 PatternRewriter &rewriter) const final; 749 }; 750 } // namespace 751 752 LogicalResult 753 ExpM1Approximation::matchAndRewrite(math::ExpM1Op op, 754 PatternRewriter &rewriter) const { 755 auto shape = vectorShape(op.operand().getType(), isF32); 756 if (!shape.hasValue()) 757 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 758 759 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 760 auto bcast = [&](Value value) -> Value { 761 return broadcast(builder, value, *shape); 762 }; 763 764 // expm1(x) = exp(x) - 1 = u - 1. 765 // We have to handle it carefully when x is near 0, i.e. u ~= 1, 766 // and when the input is ~= -inf, i.e. u - 1 ~= -1. 767 Value cstOne = bcast(f32Cst(builder, 1.0f)); 768 Value cstNegOne = bcast(f32Cst(builder, -1.0f)); 769 Value x = op.operand(); 770 Value u = builder.create<math::ExpOp>(x); 771 Value uEqOne = 772 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne); 773 Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne); 774 Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>( 775 arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne); 776 // logU = log(u) ~= x 777 Value logU = builder.create<math::LogOp>(u); 778 779 // Detect exp(x) = +inf; written this way to avoid having to form +inf. 
780 Value isInf = 781 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u); 782 783 // (u - 1) * (x / ~x) 784 Value expm1 = builder.create<arith::MulFOp>( 785 uMinusOne, builder.create<arith::DivFOp>(x, logU)); 786 expm1 = builder.create<SelectOp>(isInf, u, expm1); 787 Value approximation = builder.create<SelectOp>( 788 uEqOne, x, builder.create<SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1)); 789 rewriter.replaceOp(op, approximation); 790 return success(); 791 } 792 793 //----------------------------------------------------------------------------// 794 // Sin and Cos approximation. 795 //----------------------------------------------------------------------------// 796 797 namespace { 798 799 template <bool isSine, typename OpTy> 800 struct SinAndCosApproximation : public OpRewritePattern<OpTy> { 801 public: 802 using OpRewritePattern<OpTy>::OpRewritePattern; 803 804 LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final; 805 }; 806 } // namespace 807 808 #define TWO_OVER_PI \ 809 0.6366197723675813430755350534900574481378385829618257949906693762L 810 #define PI_OVER_2 \ 811 1.5707963267948966192313216916397514420985846996875529104874722961L 812 813 // Approximates sin(x) or cos(x) by finding the best approximation polynomial in 814 // the reduced range [0, pi/2] for both sin(x) and cos(x). Then given y in the 815 // reduced range sin(x) will be computed as sin(y), -sin(y), cos(y) or -cos(y). 
816 template <bool isSine, typename OpTy> 817 LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite( 818 OpTy op, PatternRewriter &rewriter) const { 819 static_assert( 820 llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value, 821 "SinAndCosApproximation pattern expects math::SinOp or math::CosOp"); 822 auto shape = vectorShape(op.operand().getType(), isF32); 823 if (!shape.hasValue()) 824 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 825 826 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 827 auto bcast = [&](Value value) -> Value { 828 return broadcast(builder, value, *shape); 829 }; 830 auto mul = [&](Value a, Value b) -> Value { 831 return builder.create<arith::MulFOp>(a, b); 832 }; 833 auto sub = [&](Value a, Value b) -> Value { 834 return builder.create<arith::SubFOp>(a, b); 835 }; 836 auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); }; 837 838 auto i32Vec = broadcast(builder.getI32Type(), *shape); 839 auto fPToSingedInteger = [&](Value a) -> Value { 840 return builder.create<arith::FPToSIOp>(a, i32Vec); 841 }; 842 843 auto modulo4 = [&](Value a) -> Value { 844 return builder.create<arith::AndIOp>(a, bcast(i32Cst(builder, 3))); 845 }; 846 847 auto isEqualTo = [&](Value a, Value b) -> Value { 848 return builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b); 849 }; 850 851 auto isGreaterThan = [&](Value a, Value b) -> Value { 852 return builder.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b); 853 }; 854 855 auto select = [&](Value cond, Value t, Value f) -> Value { 856 return builder.create<SelectOp>(cond, t, f); 857 }; 858 859 auto fmla = [&](Value a, Value b, Value c) { 860 return builder.create<math::FmaOp>(a, b, c); 861 }; 862 863 auto bitwiseOr = [&](Value a, Value b) { 864 return builder.create<arith::OrIOp>(a, b); 865 }; 866 867 Value twoOverPi = bcast(f32Cst(builder, TWO_OVER_PI)); 868 Value piOverTwo = bcast(f32Cst(builder, PI_OVER_2)); 869 870 Value x = op.operand(); 871 872 
Value k = floor(mul(x, twoOverPi)); 873 874 Value y = sub(x, mul(k, piOverTwo)); 875 876 Value cstOne = bcast(f32Cst(builder, 1.0)); 877 Value cstNegativeOne = bcast(f32Cst(builder, -1.0)); 878 879 Value cstSC2 = bcast(f32Cst(builder, -0.16666667163372039794921875f)); 880 Value cstSC4 = bcast(f32Cst(builder, 8.333347737789154052734375e-3f)); 881 Value cstSC6 = bcast(f32Cst(builder, -1.9842604524455964565277099609375e-4f)); 882 Value cstSC8 = 883 bcast(f32Cst(builder, 2.760012648650445044040679931640625e-6f)); 884 Value cstSC10 = 885 bcast(f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f)); 886 887 Value cstCC2 = bcast(f32Cst(builder, -0.5f)); 888 Value cstCC4 = bcast(f32Cst(builder, 4.166664183139801025390625e-2f)); 889 Value cstCC6 = bcast(f32Cst(builder, -1.388833043165504932403564453125e-3f)); 890 Value cstCC8 = bcast(f32Cst(builder, 2.47562347794882953166961669921875e-5f)); 891 Value cstCC10 = 892 bcast(f32Cst(builder, -2.59630184018533327616751194000244140625e-7f)); 893 894 Value kMod4 = modulo4(fPToSingedInteger(k)); 895 896 Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, 0))); 897 Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, 1))); 898 Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, 2))); 899 Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, 3))); 900 901 Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2); 902 Value negativeRange = isSine ? 
isGreaterThan(kMod4, bcast(i32Cst(builder, 1))) 903 : bitwiseOr(kR1, kR2); 904 905 Value y2 = mul(y, y); 906 907 Value base = select(sinuseCos, cstOne, y); 908 Value cstC2 = select(sinuseCos, cstCC2, cstSC2); 909 Value cstC4 = select(sinuseCos, cstCC4, cstSC4); 910 Value cstC6 = select(sinuseCos, cstCC6, cstSC6); 911 Value cstC8 = select(sinuseCos, cstCC8, cstSC8); 912 Value cstC10 = select(sinuseCos, cstCC10, cstSC10); 913 914 Value v1 = fmla(y2, cstC10, cstC8); 915 Value v2 = fmla(y2, v1, cstC6); 916 Value v3 = fmla(y2, v2, cstC4); 917 Value v4 = fmla(y2, v3, cstC2); 918 Value v5 = fmla(y2, v4, cstOne); 919 Value v6 = mul(base, v5); 920 921 Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6); 922 923 rewriter.replaceOp(op, approximation); 924 925 return success(); 926 } 927 928 //----------------------------------------------------------------------------// 929 // Rsqrt approximation. 930 //----------------------------------------------------------------------------// 931 932 namespace { 933 struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> { 934 using OpRewritePattern::OpRewritePattern; 935 936 LogicalResult matchAndRewrite(math::RsqrtOp op, 937 PatternRewriter &rewriter) const final; 938 }; 939 } // namespace 940 941 LogicalResult 942 RsqrtApproximation::matchAndRewrite(math::RsqrtOp op, 943 PatternRewriter &rewriter) const { 944 auto shape = vectorShape(op.operand().getType(), isF32); 945 // Only support already-vectorized rsqrt's. 
946 if (!shape.hasValue() || (*shape)[0] != 8) 947 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 948 949 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 950 auto bcast = [&](Value value) -> Value { 951 return broadcast(builder, value, *shape); 952 }; 953 954 Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u)); 955 Value cstOnePointFive = bcast(f32Cst(builder, 1.5f)); 956 Value cstNegHalf = bcast(f32Cst(builder, -0.5f)); 957 Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u)); 958 959 Value negHalf = builder.create<arith::MulFOp>(op.operand(), cstNegHalf); 960 961 // Select only the inverse sqrt of positive normals (denormals are 962 // flushed to zero). 963 Value ltMinMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, 964 op.operand(), cstMinNormPos); 965 Value infMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, 966 op.operand(), cstPosInf); 967 Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask); 968 969 // Compute an approximate result. 970 Value yApprox = builder.create<x86vector::RsqrtOp>(op.operand()); 971 972 // Do a single step of Newton-Raphson iteration to improve the approximation. 973 // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n). 974 // It is essential to evaluate the inner term like this because forming 975 // y_n^2 may over- or underflow. 976 Value inner = builder.create<arith::MulFOp>(negHalf, yApprox); 977 Value fma = builder.create<math::FmaOp>(yApprox, inner, cstOnePointFive); 978 Value yNewton = builder.create<arith::MulFOp>(yApprox, fma); 979 980 // Select the result of the Newton-Raphson step for positive normal arguments. 981 // For other arguments, choose the output of the intrinsic. This will 982 // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if 983 // x is zero or a positive denormalized float (equivalent to flushing positive 984 // denormalized inputs to zero). 
985 Value res = builder.create<SelectOp>(notNormalFiniteMask, yApprox, yNewton); 986 rewriter.replaceOp(op, res); 987 988 return success(); 989 } 990 991 //----------------------------------------------------------------------------// 992 993 void mlir::populateMathPolynomialApproximationPatterns( 994 RewritePatternSet &patterns, 995 const MathPolynomialApproximationOptions &options) { 996 patterns.add<TanhApproximation, LogApproximation, Log2Approximation, 997 Log1pApproximation, ErfPolynomialApproximation, ExpApproximation, 998 ExpM1Approximation, SinAndCosApproximation<true, math::SinOp>, 999 SinAndCosApproximation<false, math::CosOp>>( 1000 patterns.getContext()); 1001 if (options.enableAvx2) 1002 patterns.add<RsqrtApproximation>(patterns.getContext()); 1003 } 1004