1 //===- PolynomialApproximation.cpp - Approximate math operations ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements expansion of math operations to fast approximations 10 // that do not rely on any of the library functions. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include <climits> 15 #include <cstddef> 16 17 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 18 #include "mlir/Dialect/Math/IR/Math.h" 19 #include "mlir/Dialect/Math/Transforms/Approximation.h" 20 #include "mlir/Dialect/Math/Transforms/Passes.h" 21 #include "mlir/Dialect/Vector/VectorOps.h" 22 #include "mlir/Dialect/Vector/VectorUtils.h" 23 #include "mlir/Dialect/X86Vector/X86VectorDialect.h" 24 #include "mlir/IR/Builders.h" 25 #include "mlir/IR/ImplicitLocOpBuilder.h" 26 #include "mlir/IR/TypeUtilities.h" 27 #include "mlir/Transforms/DialectConversion.h" 28 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 29 #include "llvm/ADT/ArrayRef.h" 30 31 using namespace mlir; 32 using namespace mlir::math; 33 using namespace mlir::vector; 34 35 // Returns vector shape if the type is a vector. Returns an empty shape if it is 36 // not a vector. 37 static ArrayRef<int64_t> vectorShape(Type type) { 38 auto vectorType = type.dyn_cast<VectorType>(); 39 return vectorType ? vectorType.getShape() : ArrayRef<int64_t>(); 40 } 41 42 static ArrayRef<int64_t> vectorShape(Value value) { 43 return vectorShape(value.getType()); 44 } 45 46 //----------------------------------------------------------------------------// 47 // Broadcast scalar types and values into vector types and values. 
48 //----------------------------------------------------------------------------// 49 50 // Broadcasts scalar type into vector type (iff shape is non-scalar). 51 static Type broadcast(Type type, ArrayRef<int64_t> shape) { 52 assert(!type.isa<VectorType>() && "must be scalar type"); 53 return !shape.empty() ? VectorType::get(shape, type) : type; 54 } 55 56 // Broadcasts scalar value into vector (iff shape is non-scalar). 57 static Value broadcast(ImplicitLocOpBuilder &builder, Value value, 58 ArrayRef<int64_t> shape) { 59 assert(!value.getType().isa<VectorType>() && "must be scalar value"); 60 auto type = broadcast(value.getType(), shape); 61 return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value; 62 } 63 64 //----------------------------------------------------------------------------// 65 // Helper function to handle n-D vectors with 1-D operations. 66 //----------------------------------------------------------------------------// 67 68 // Expands and unrolls n-D vector operands into multiple fixed size 1-D vectors 69 // and calls the compute function with 1-D vector operands. Stitches back all 70 // results into the original n-D vector result. 71 // 72 // Examples: vectorWidth = 8 73 // - vector<4x8xf32> unrolled 4 times 74 // - vector<16xf32> expanded to vector<2x8xf32> and unrolled 2 times 75 // - vector<4x16xf32> expanded to vector<4x2x8xf32> and unrolled 4*2 times 76 // 77 // Some math approximations rely on ISA-specific operations that only accept 78 // fixed size 1-D vectors (e.g. AVX expects vectors of width 8). 79 // 80 // It is the caller's responsibility to verify that the inner dimension is 81 // divisible by the vectorWidth, and that all operands have the same vector 82 // shape. 
83 static Value 84 handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, 85 ValueRange operands, int64_t vectorWidth, 86 std::function<Value(ValueRange)> compute) { 87 assert(!operands.empty() && "operands must be not empty"); 88 assert(vectorWidth > 0 && "vector width must be larger than 0"); 89 90 VectorType inputType = operands[0].getType().cast<VectorType>(); 91 ArrayRef<int64_t> inputShape = inputType.getShape(); 92 93 // If input shape matches target vector width, we can just call the 94 // user-provided compute function with the operands. 95 if (inputShape == llvm::makeArrayRef(vectorWidth)) 96 return compute(operands); 97 98 // Check if the inner dimension has to be expanded, or we can directly iterate 99 // over the outer dimensions of the vector. 100 int64_t innerDim = inputShape.back(); 101 int64_t expansionDim = innerDim / vectorWidth; 102 assert((innerDim % vectorWidth == 0) && "invalid inner dimension size"); 103 104 // Maybe expand operands to the higher rank vector shape that we'll use to 105 // iterate over and extract one dimensional vectors. 106 SmallVector<int64_t> expandedShape(inputShape.begin(), inputShape.end()); 107 SmallVector<Value> expandedOperands(operands); 108 109 if (expansionDim > 1) { 110 // Expand shape from [..., innerDim] to [..., expansionDim, vectorWidth]. 111 expandedShape.insert(expandedShape.end() - 1, expansionDim); 112 expandedShape.back() = vectorWidth; 113 114 for (unsigned i = 0; i < operands.size(); ++i) { 115 auto operand = operands[i]; 116 auto eltType = operand.getType().cast<VectorType>().getElementType(); 117 auto expandedType = VectorType::get(expandedShape, eltType); 118 expandedOperands[i] = 119 builder.create<vector::ShapeCastOp>(expandedType, operand); 120 } 121 } 122 123 // Iterate over all outer dimensions of the compute shape vector type. 
124 auto iterationDims = ArrayRef<int64_t>(expandedShape).drop_back(); 125 int64_t maxLinearIndex = computeMaxLinearIndex(iterationDims); 126 127 SmallVector<int64_t> ones(iterationDims.size(), 1); 128 auto strides = computeStrides(iterationDims, ones); 129 130 // Compute results for each one dimensional vector. 131 SmallVector<Value> results(maxLinearIndex); 132 133 for (int64_t i = 0; i < maxLinearIndex; ++i) { 134 auto offsets = delinearize(strides, i); 135 136 SmallVector<Value> extracted(expandedOperands.size()); 137 for (auto tuple : llvm::enumerate(expandedOperands)) 138 extracted[tuple.index()] = 139 builder.create<vector::ExtractOp>(tuple.value(), offsets); 140 141 results[i] = compute(extracted); 142 } 143 144 // Stitch results together into one large vector. 145 Type resultEltType = results[0].getType().cast<VectorType>().getElementType(); 146 Type resultExpandedType = VectorType::get(expandedShape, resultEltType); 147 Value result = builder.create<ConstantOp>( 148 resultExpandedType, builder.getZeroAttr(resultExpandedType)); 149 150 for (int64_t i = 0; i < maxLinearIndex; ++i) 151 result = builder.create<vector::InsertOp>(results[i], result, 152 delinearize(strides, i)); 153 154 // Reshape back to the original vector shape. 155 return builder.create<vector::ShapeCastOp>( 156 VectorType::get(inputShape, resultEltType), result); 157 } 158 159 //----------------------------------------------------------------------------// 160 // Helper functions to create constants. 
161 //----------------------------------------------------------------------------// 162 163 static Value f32Cst(ImplicitLocOpBuilder &builder, float value) { 164 return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value)); 165 } 166 167 static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) { 168 return builder.create<arith::ConstantOp>(builder.getI32IntegerAttr(value)); 169 } 170 171 static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) { 172 Value i32Value = i32Cst(builder, static_cast<int32_t>(bits)); 173 return builder.create<arith::BitcastOp>(builder.getF32Type(), i32Value); 174 } 175 176 //----------------------------------------------------------------------------// 177 // Helper functions to build math functions approximations. 178 //----------------------------------------------------------------------------// 179 180 static Value min(ImplicitLocOpBuilder &builder, Value a, Value b) { 181 return builder.create<SelectOp>( 182 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, a, b), a, b); 183 } 184 185 static Value max(ImplicitLocOpBuilder &builder, Value a, Value b) { 186 return builder.create<SelectOp>( 187 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, a, b), a, b); 188 } 189 190 static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, 191 Value upperBound) { 192 return max(builder, min(builder, value, upperBound), lowerBound); 193 } 194 195 // Decomposes given floating point value `arg` into a normalized fraction and 196 // an integral power of two (see std::frexp). Returned values have float type. 
197 static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg, 198 bool is_positive = false) { 199 assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type"); 200 ArrayRef<int64_t> shape = vectorShape(arg); 201 202 auto bcast = [&](Value value) -> Value { 203 return broadcast(builder, value, shape); 204 }; 205 206 auto i32 = builder.getIntegerType(32); 207 auto i32Vec = broadcast(i32, shape); 208 auto f32Vec = broadcast(builder.getF32Type(), shape); 209 210 Value cst126f = f32Cst(builder, 126.0f); 211 Value cstHalf = f32Cst(builder, 0.5f); 212 Value cstInvMantMask = f32FromBits(builder, ~0x7f800000u); 213 214 // Bitcast to i32 for bitwise operations. 215 Value i32Half = builder.create<arith::BitcastOp>(i32, cstHalf); 216 Value i32InvMantMask = builder.create<arith::BitcastOp>(i32, cstInvMantMask); 217 Value i32Arg = builder.create<arith::BitcastOp>(i32Vec, arg); 218 219 // Compute normalized fraction. 220 Value tmp0 = builder.create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask)); 221 Value tmp1 = builder.create<arith::OrIOp>(tmp0, bcast(i32Half)); 222 Value normalizedFraction = builder.create<arith::BitcastOp>(f32Vec, tmp1); 223 224 // Compute exponent. 225 Value arg0 = is_positive ? arg : builder.create<math::AbsOp>(arg); 226 Value biasedExponentBits = builder.create<arith::ShRUIOp>( 227 builder.create<arith::BitcastOp>(i32Vec, arg0), 228 bcast(i32Cst(builder, 23))); 229 Value biasedExponent = 230 builder.create<arith::SIToFPOp>(f32Vec, biasedExponentBits); 231 Value exponent = 232 builder.create<arith::SubFOp>(biasedExponent, bcast(cst126f)); 233 234 return {normalizedFraction, exponent}; 235 } 236 237 // Computes exp2 for an i32 argument. 
238 static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) { 239 assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type"); 240 ArrayRef<int64_t> shape = vectorShape(arg); 241 242 auto bcast = [&](Value value) -> Value { 243 return broadcast(builder, value, shape); 244 }; 245 246 auto f32Vec = broadcast(builder.getF32Type(), shape); 247 // The exponent of f32 located at 23-bit. 248 auto exponetBitLocation = bcast(i32Cst(builder, 23)); 249 // Set the exponent bias to zero. 250 auto bias = bcast(i32Cst(builder, 127)); 251 252 Value biasedArg = builder.create<arith::AddIOp>(arg, bias); 253 Value exp2ValueInt = 254 builder.create<arith::ShLIOp>(biasedArg, exponetBitLocation); 255 Value exp2ValueF32 = builder.create<arith::BitcastOp>(f32Vec, exp2ValueInt); 256 257 return exp2ValueF32; 258 } 259 260 namespace { 261 Value makePolynomialCalculation(ImplicitLocOpBuilder &builder, 262 llvm::ArrayRef<Value> coeffs, Value x) { 263 assert(getElementTypeOrSelf(x).isF32() && "x must be f32 type"); 264 ArrayRef<int64_t> shape = vectorShape(x); 265 266 if (coeffs.empty()) 267 return broadcast(builder, f32Cst(builder, 0.0f), shape); 268 269 if (coeffs.size() == 1) 270 return coeffs[0]; 271 272 Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1], 273 coeffs[coeffs.size() - 2]); 274 for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) { 275 res = builder.create<math::FmaOp>(x, res, coeffs[i]); 276 } 277 return res; 278 } 279 } // namespace 280 281 //----------------------------------------------------------------------------// 282 // TanhOp approximation. 
//----------------------------------------------------------------------------//

namespace {
// Rewrite pattern that replaces `math.tanh` on f32 scalars/vectors with a
// rational polynomial approximation.
struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::TanhOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Approximates tanh(x) as p(x) / q(x), where p is an odd polynomial of degree
// 13 and q is an even polynomial of degree 6, both evaluated on the clamped
// operand. Operands with |x| < 0.0004 are passed through unchanged.
// Fails to match for any element type other than f32.
LogicalResult
TanhApproximation::matchAndRewrite(math::TanhOp op,
                                   PatternRewriter &rewriter) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  // Broadcasts a scalar constant to the operand shape (no-op for scalars).
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // Clamp operand into [minusClamp, plusClamp] range.
  Value minusClamp = bcast(f32Cst(builder, -7.99881172180175781f));
  Value plusClamp = bcast(f32Cst(builder, 7.99881172180175781f));
  Value x = clamp(builder, op.getOperand(), minusClamp, plusClamp);

  // Mask for tiny values that are approximated with `operand`.
  Value tiny = bcast(f32Cst(builder, 0.0004f));
  Value tinyMask = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::OLT, builder.create<math::AbsOp>(op.getOperand()),
      tiny);

  // The monomial coefficients of the numerator polynomial (odd).
  Value alpha1 = bcast(f32Cst(builder, 4.89352455891786e-03f));
  Value alpha3 = bcast(f32Cst(builder, 6.37261928875436e-04f));
  Value alpha5 = bcast(f32Cst(builder, 1.48572235717979e-05f));
  Value alpha7 = bcast(f32Cst(builder, 5.12229709037114e-08f));
  Value alpha9 = bcast(f32Cst(builder, -8.60467152213735e-11f));
  Value alpha11 = bcast(f32Cst(builder, 2.00018790482477e-13f));
  Value alpha13 = bcast(f32Cst(builder, -2.76076847742355e-16f));

  // The monomial coefficients of the denominator polynomial (even).
  Value beta0 = bcast(f32Cst(builder, 4.89352518554385e-03f));
  Value beta2 = bcast(f32Cst(builder, 2.26843463243900e-03f));
  Value beta4 = bcast(f32Cst(builder, 1.18534705686654e-04f));
  Value beta6 = bcast(f32Cst(builder, 1.19825839466702e-06f));

  // Since the polynomials are odd/even, we need x^2.
  Value x2 = builder.create<arith::MulFOp>(x, x);

  // Evaluate the numerator polynomial p: Horner's scheme in x^2, followed by a
  // final multiplication with x to make the polynomial odd.
  Value p = builder.create<math::FmaOp>(x2, alpha13, alpha11);
  p = builder.create<math::FmaOp>(x2, p, alpha9);
  p = builder.create<math::FmaOp>(x2, p, alpha7);
  p = builder.create<math::FmaOp>(x2, p, alpha5);
  p = builder.create<math::FmaOp>(x2, p, alpha3);
  p = builder.create<math::FmaOp>(x2, p, alpha1);
  p = builder.create<arith::MulFOp>(x, p);

  // Evaluate the denominator polynomial q (Horner's scheme in x^2).
  Value q = builder.create<math::FmaOp>(x2, beta6, beta4);
  q = builder.create<math::FmaOp>(x2, q, beta2);
  q = builder.create<math::FmaOp>(x2, q, beta0);

  // Divide the numerator by the denominator; tiny operands bypass the rational
  // approximation and yield `x` itself.
  Value res = builder.create<SelectOp>(tinyMask, x,
                                       builder.create<arith::DivFOp>(p, q));

  rewriter.replaceOp(op, res);

  return success();
}

// ln(2) and log2(e) as long double literals; they are narrowed to float with
// a static_cast at each use site.
#define LN2_VALUE                                                              \
  0.693147180559945309417232121458176568075500134360255254120680009493393621L
#define LOG2E_VALUE                                                            \
  1.442695040888963407359924681001892137426645954152985934135449406931109219L

//----------------------------------------------------------------------------//
// LogOp and Log2Op approximation.
//----------------------------------------------------------------------------//

namespace {
// Shared implementation of the `math.log` and `math.log2` rewrite patterns;
// `Op` is the op type being rewritten.
template <typename Op>
struct LogApproximationBase : public OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;

  /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise.
  LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter,
                                   bool base2) const;
};
} // namespace

// This approximation comes from Julien Pommier's SSE math library.
// Link: http://gruntthepeon.free.fr/ssemath
//
// The operand is split into a fraction and an exponent with frexp, the
// logarithm of the fraction is approximated with a polynomial, and the
// exponent contribution is added back. Zero, negative/NaN and +inf operands
// are patched up with selects at the end.
// Fails to match for any element type other than f32.
template <typename Op>
LogicalResult
LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
                                             bool base2) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  // Broadcasts a scalar constant to the operand shape (no-op for scalars).
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  Value cstZero = bcast(f32Cst(builder, 0.0f));
  Value cstOne = bcast(f32Cst(builder, 1.0f));
  Value cstNegHalf = bcast(f32Cst(builder, -0.5f));

  // The smallest non denormalized float number, followed by the bit patterns
  // of -inf, +inf and a quiet NaN.
  Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
  Value cstMinusInf = bcast(f32FromBits(builder, 0xff800000u));
  Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
  Value cstNan = bcast(f32FromBits(builder, 0x7fc00000));

  // Polynomial coefficients.
  Value cstCephesSQRTHF = bcast(f32Cst(builder, 0.707106781186547524f));
  Value cstCephesLogP0 = bcast(f32Cst(builder, 7.0376836292E-2f));
  Value cstCephesLogP1 = bcast(f32Cst(builder, -1.1514610310E-1f));
  Value cstCephesLogP2 = bcast(f32Cst(builder, 1.1676998740E-1f));
  Value cstCephesLogP3 = bcast(f32Cst(builder, -1.2420140846E-1f));
  Value cstCephesLogP4 = bcast(f32Cst(builder, +1.4249322787E-1f));
  Value cstCephesLogP5 = bcast(f32Cst(builder, -1.6668057665E-1f));
  Value cstCephesLogP6 = bcast(f32Cst(builder, +2.0000714765E-1f));
  Value cstCephesLogP7 = bcast(f32Cst(builder, -2.4999993993E-1f));
  Value cstCephesLogP8 = bcast(f32Cst(builder, +3.3333331174E-1f));

  Value x = op.getOperand();

  // Clamp input values from below to the minimum positive normal. Smaller
  // (zero, negative) operands are handled by the selects at the end.
  x = max(builder, x, cstMinNormPos);

  // Extract significand in the range [0.5,1) and exponent. `x` is positive
  // after the clamp above, so frexp may skip the absolute value.
  std::pair<Value, Value> pair = frexp(builder, x, /*is_positive=*/true);
  x = pair.first;
  Value e = pair.second;

  // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift
  // by -1.0. The values are then centered around 0, which improves the
  // stability of the polynomial evaluation:
  //
  //   if( x < SQRTHF ) {
  //     e -= 1;
  //     x = x + x - 1.0;
  //   } else { x = x - 1.0; }
  Value mask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x,
                                             cstCephesSQRTHF);
  // Branch-free form of the pseudo code above: `tmp` is x where the mask is
  // set and 0 elsewhere, so adding it after subtracting 1.0 doubles x only in
  // the masked lanes.
  Value tmp = builder.create<SelectOp>(mask, x, cstZero);

  x = builder.create<arith::SubFOp>(x, cstOne);
  e = builder.create<arith::SubFOp>(
      e, builder.create<SelectOp>(mask, cstOne, cstZero));
  x = builder.create<arith::AddFOp>(x, tmp);

  Value x2 = builder.create<arith::MulFOp>(x, x);
  Value x3 = builder.create<arith::MulFOp>(x2, x);

  // Evaluate the polynomial approximant of degree 8 in three parts.
  Value y0, y1, y2;
  y0 = builder.create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1);
  y1 = builder.create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4);
  y2 = builder.create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7);
  y0 = builder.create<math::FmaOp>(y0, x, cstCephesLogP2);
  y1 = builder.create<math::FmaOp>(y1, x, cstCephesLogP5);
  y2 = builder.create<math::FmaOp>(y2, x, cstCephesLogP8);
  // Combine the three quadratic parts: ((y0 * x^3 + y1) * x^3 + y2) * x^3.
  y0 = builder.create<math::FmaOp>(y0, x3, y1);
  y0 = builder.create<math::FmaOp>(y0, x3, y2);
  y0 = builder.create<arith::MulFOp>(y0, x3);

  // Add the -x^2/2 term and then x itself.
  y0 = builder.create<math::FmaOp>(cstNegHalf, x2, y0);
  x = builder.create<arith::AddFOp>(x, y0);

  if (base2) {
    // Base 2: scale the natural logarithm of the fraction by log2(e) and add
    // the exponent.
    Value cstLog2e = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
    x = builder.create<math::FmaOp>(x, cstLog2e, e);
  } else {
    // Base e: add the exponent scaled by ln(2).
    Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
    x = builder.create<math::FmaOp>(e, cstLn2, x);
  }

  // ULT is an unordered comparison, so `invalidMask` is also set for NaN
  // operands.
  Value invalidMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
                                                    op.getOperand(), cstZero);
  Value zeroMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
                                                 op.getOperand(), cstZero);
  Value posInfMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
                                                   op.getOperand(), cstPosInf);

  // Filter out invalid values:
  //  - x == 0     -> -INF
  //  - x < 0      -> NAN
  //  - x == +INF  -> +INF
  Value aproximation = builder.create<SelectOp>(
      zeroMask, cstMinusInf,
      builder.create<SelectOp>(
          invalidMask, cstNan,
          builder.create<SelectOp>(posInfMask, cstPosInf, x)));

  rewriter.replaceOp(op, aproximation);

  return success();
}

namespace {
// Rewrite pattern for `math.log` (natural logarithm).
struct LogApproximation : public LogApproximationBase<math::LogOp> {
  using LogApproximationBase::LogApproximationBase;

  LogicalResult matchAndRewrite(math::LogOp op,
                                PatternRewriter &rewriter) const final {
    return logMatchAndRewrite(op, rewriter, /*base2=*/false);
  }
};
} // namespace

namespace {
// Rewrite pattern for `math.log2` (base 2 logarithm).
struct Log2Approximation : public LogApproximationBase<math::Log2Op> {
  using LogApproximationBase::LogApproximationBase;

  LogicalResult matchAndRewrite(math::Log2Op op,
                                PatternRewriter &rewriter) const final {
    return logMatchAndRewrite(op, rewriter, /*base2=*/true);
  }
};
} // namespace

//----------------------------------------------------------------------------//
// Log1p approximation.
//----------------------------------------------------------------------------//

namespace {
// Rewrite pattern that expands `math.log1p` on f32 scalars/vectors in terms
// of `math.log`.
struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::Log1pOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Approximate log(1+x).
// Fails to match for any element type other than f32. The emitted `math.log`
// op is left in place by this pattern.
LogicalResult
Log1pApproximation::matchAndRewrite(math::Log1pOp op,
                                    PatternRewriter &rewriter) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  // Broadcasts a scalar constant to the operand shape (no-op for scalars).
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // Approximate log(1+x) using the following, due to W. Kahan:
  //   u = x + 1.0;
  //   if (u == 1.0 || u == inf) return x;
  //   return x * log(u) / (u - 1.0);
  //          ^^^^^^^^^^^^^^^^^^^^^^
  //             "logLarge" below.
  Value cstOne = bcast(f32Cst(builder, 1.0f));
  Value x = op.getOperand();
  Value u = builder.create<arith::AddFOp>(x, cstOne);
  Value uSmall =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
  Value logU = builder.create<math::LogOp>(u);
  // NOTE(review): the `u == inf` test of the pseudo code is emitted as
  // `u == log(u)`; presumably this detects +inf without materializing an
  // infinity constant, since log(u) < u for every finite positive u.
  Value uInf =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU);
  Value logLarge = builder.create<arith::MulFOp>(
      x, builder.create<arith::DivFOp>(
             logU, builder.create<arith::SubFOp>(u, cstOne)));
  Value approximation = builder.create<SelectOp>(
      builder.create<arith::OrIOp>(uSmall, uInf), x, logLarge);
  rewriter.replaceOp(op, approximation);
  return success();
}

//----------------------------------------------------------------------------//
// Erf approximation.
//----------------------------------------------------------------------------//

// Approximates erf(x) with
// a + P(x)/Q(x)
// where P and Q are polynomials of degree 4 and `a` is a per-interval offset.
// Different coefficients are chosen based on the value of x.
// The approximation error is ~2.5e-07.
// Boost's minimax tool that utilizes the Remez method was used to find the
// coefficients.
// Emits a piecewise rational approximation `offset + P(|x|) / Q(|x|)` of erf
// for f32 scalars/vectors. The coefficient set is selected per element from
// three intervals of |x|, and the sign is restored at the end.
// Fails to match for any element type other than f32.
LogicalResult
ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
                                            PatternRewriter &rewriter) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  // Broadcasts a scalar constant to the operand shape (no-op for scalars).
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  const int intervalsCount = 3;
  const int polyDegree = 4;

  Value zero = bcast(f32Cst(builder, 0));
  Value one = bcast(f32Cst(builder, 1));
  // Numerator coefficients: pp[interval][power of x].
  Value pp[intervalsCount][polyDegree + 1];
  pp[0][0] = bcast(f32Cst(builder, +0.00000000000000000e+00f));
  pp[0][1] = bcast(f32Cst(builder, +1.12837916222975858e+00f));
  pp[0][2] = bcast(f32Cst(builder, -5.23018562988006470e-01f));
  pp[0][3] = bcast(f32Cst(builder, +2.09741709609267072e-01f));
  pp[0][4] = bcast(f32Cst(builder, +2.58146801602987875e-02f));
  pp[1][0] = bcast(f32Cst(builder, +0.00000000000000000e+00f));
  pp[1][1] = bcast(f32Cst(builder, +1.12750687816789140e+00f));
  pp[1][2] = bcast(f32Cst(builder, -3.64721408487825775e-01f));
  pp[1][3] = bcast(f32Cst(builder, +1.18407396425136952e-01f));
  pp[1][4] = bcast(f32Cst(builder, +3.70645533056476558e-02f));
  pp[2][0] = bcast(f32Cst(builder, -3.30093071049483172e-03f));
  pp[2][1] = bcast(f32Cst(builder, +3.51961938357697011e-03f));
  pp[2][2] = bcast(f32Cst(builder, -1.41373622814988039e-03f));
  pp[2][3] = bcast(f32Cst(builder, +2.53447094961941348e-04f));
  pp[2][4] = bcast(f32Cst(builder, -1.71048029455037401e-05f));

  // Denominator coefficients: qq[interval][power of x].
  Value qq[intervalsCount][polyDegree + 1];
  qq[0][0] = bcast(f32Cst(builder, +1.000000000000000000e+00f));
  qq[0][1] = bcast(f32Cst(builder, -4.635138185962547255e-01f));
  qq[0][2] = bcast(f32Cst(builder, +5.192301327279782447e-01f));
  qq[0][3] = bcast(f32Cst(builder, -1.318089722204810087e-01f));
  qq[0][4] = bcast(f32Cst(builder, +7.397964654672315005e-02f));
  qq[1][0] = bcast(f32Cst(builder, +1.00000000000000000e+00f));
  qq[1][1] = bcast(f32Cst(builder, -3.27607011824493086e-01f));
  qq[1][2] = bcast(f32Cst(builder, +4.48369090658821977e-01f));
  qq[1][3] = bcast(f32Cst(builder, -8.83462621207857930e-02f));
  qq[1][4] = bcast(f32Cst(builder, +5.72442770283176093e-02f));
  qq[2][0] = bcast(f32Cst(builder, +1.00000000000000000e+00f));
  qq[2][1] = bcast(f32Cst(builder, -2.06069165953913769e+00f));
  qq[2][2] = bcast(f32Cst(builder, +1.62705939945477759e+00f));
  qq[2][3] = bcast(f32Cst(builder, -5.83389859211130017e-01f));
  qq[2][4] = bcast(f32Cst(builder, +8.21908939856640930e-02f));

  // Constant added to P/Q in each interval.
  Value offsets[intervalsCount];
  offsets[0] = bcast(f32Cst(builder, 0.0f));
  offsets[1] = bcast(f32Cst(builder, 0.0f));
  offsets[2] = bcast(f32Cst(builder, 1.0f));

  // Exclusive upper bound of each interval of |x|.
  Value bounds[intervalsCount];
  bounds[0] = bcast(f32Cst(builder, 0.8f));
  bounds[1] = bcast(f32Cst(builder, 2.0f));
  bounds[2] = bcast(f32Cst(builder, 3.75f));

  // Work on |x|; the sign is re-applied after the approximation.
  Value isNegativeArg = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT,
                                                      op.getOperand(), zero);
  Value negArg = builder.create<arith::NegFOp>(op.getOperand());
  Value x = builder.create<SelectOp>(isNegativeArg, negArg, op.getOperand());

  // Start with the coefficients of the first interval.
  Value offset = offsets[0];
  Value p[polyDegree + 1];
  Value q[polyDegree + 1];
  for (int i = 0; i <= polyDegree; ++i) {
    p[i] = pp[0][i];
    q[i] = qq[0][i];
  }

  // TODO: maybe use vector stacking to reduce the number of selects.
  // For every bound but the last: keep the current selection where x is below
  // the bound, otherwise switch to the coefficients of the next interval.
  Value isLessThanBound[intervalsCount];
  for (int j = 0; j < intervalsCount - 1; ++j) {
    isLessThanBound[j] =
        builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[j]);
    for (int i = 0; i <= polyDegree; ++i) {
      p[i] = builder.create<SelectOp>(isLessThanBound[j], p[i], pp[j + 1][i]);
      q[i] = builder.create<SelectOp>(isLessThanBound[j], q[i], qq[j + 1][i]);
    }
    offset =
        builder.create<SelectOp>(isLessThanBound[j], offset, offsets[j + 1]);
  }
  // The last bound uses an unordered comparison (ULT), so NaN operands take
  // the formula path below instead of being replaced by 1.
  isLessThanBound[intervalsCount - 1] = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);

  Value pPoly = makePolynomialCalculation(builder, p, x);
  Value qPoly = makePolynomialCalculation(builder, q, x);
  Value rationalPoly = builder.create<arith::DivFOp>(pPoly, qPoly);
  Value formula = builder.create<arith::AddFOp>(offset, rationalPoly);
  // At or beyond the last bound the result is 1.
  formula = builder.create<SelectOp>(isLessThanBound[intervalsCount - 1],
                                     formula, one);

  // erf is odd function: erf(x) = -erf(-x).
  Value negFormula = builder.create<arith::NegFOp>(formula);
  Value res = builder.create<SelectOp>(isNegativeArg, negFormula, formula);

  rewriter.replaceOp(op, res);

  return success();
}

//----------------------------------------------------------------------------//
// Exp approximation.
//----------------------------------------------------------------------------//

namespace {

// Rewrite pattern that replaces `math.exp` on f32 scalars/vectors with a
// range reduction followed by a polynomial approximation.
struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::ExpOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Approximate exp(x) using its reduced range exp(y) where y is in the range
// [0, ln(2)], let y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2), exp(x)
// = exp(y) * 2^k. exp(y) is approximated with a polynomial of degree 5.
// Fails to match for any element type other than f32.
LogicalResult
ExpApproximation::matchAndRewrite(math::ExpOp op,
                                  PatternRewriter &rewriter) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);

  // TODO: Consider a common pattern rewriter with all methods below to
  // write the approximations.
  // Small op-building shorthands; `bcast` broadcasts a scalar constant to the
  // operand shape (no-op for scalars).
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };
  auto fmla = [&](Value a, Value b, Value c) {
    return builder.create<math::FmaOp>(a, b, c);
  };
  auto mul = [&](Value a, Value b) -> Value {
    return builder.create<arith::MulFOp>(a, b);
  };
  auto sub = [&](Value a, Value b) -> Value {
    return builder.create<arith::SubFOp>(a, b);
  };
  auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };

  Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
  Value cstLog2E = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));

  // Polynomial coefficients.
  Value cstCephesExpP0 = bcast(f32Cst(builder, 1.0));
  Value cstCephesExpP1 = bcast(f32Cst(builder, 1.0));
  Value cstCephesExpP2 = bcast(f32Cst(builder, 0.49970514590562437052f));
  Value cstCephesExpP3 = bcast(f32Cst(builder, 0.16873890085469545053f));
  Value cstCephesExpP4 = bcast(f32Cst(builder, 0.03668965196652099192f));
  Value cstCephesExpP5 = bcast(f32Cst(builder, 0.01314350012789660196f));

  Value x = op.getOperand();

  // Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2)
  // The division by ln(2) is emitted as a multiplication by log2(e).
  Value xL2Inv = mul(x, cstLog2E);
  Value kF32 = floor(xL2Inv);
  Value kLn2 = mul(kF32, cstLn2);
  Value y = sub(x, kLn2);

  // Use Estrin's evaluation scheme with 3 independent parts:
  // P(y) = (c0 + c1 y) + (c2 + c3 y) y^2 + (c4 + c5 y) y^4
  Value y2 = mul(y, y);
  Value y4 = mul(y2, y2);

  Value q0 = fmla(cstCephesExpP1, y, cstCephesExpP0);
  Value q1 = fmla(cstCephesExpP3, y, cstCephesExpP2);
  Value q2 = fmla(cstCephesExpP5, y, cstCephesExpP4);
  Value expY = fmla(q1, y2, q0);
  expY = fmla(q2, y4, expY);

  auto i32Vec = broadcast(builder.getI32Type(), shape);

  // exp2(k)
  Value k = builder.create<arith::FPToSIOp>(kF32, i32Vec);
  Value exp2KValue = exp2I32(builder, k);

  // exp(x) = exp(y) * exp2(k)
  expY = mul(expY, exp2KValue);

  // Handle overflow, inf and underflow of exp(x). exp(x) range is [0, inf], it
  // is partitioned as the following:
  // exp(x) = 0, x <= -inf
  // exp(x) = underflow (min_float), x <= -88
  // exp(x) = inf, x >= 88
  // Note: |k| = 127 is the value where the 8-bits exponent saturates.
  Value zerof32Const = bcast(f32Cst(builder, 0));
  auto constPosInfinity =
      bcast(f32Cst(builder, std::numeric_limits<float>::infinity()));
  auto constNegIfinity =
      bcast(f32Cst(builder, -std::numeric_limits<float>::infinity()));
  auto underflow = bcast(f32Cst(builder, std::numeric_limits<float>::min()));

  // The product computed above is only used when -127 <= k <= 127.
  // NOTE(review): for k == -127 exp2I32 shifts a biased exponent of 0 into
  // place, i.e. it produces the bit pattern of +0.0f, so the product is 0
  // rather than `underflow` -- confirm this is intended.
  Value kMaxConst = bcast(i32Cst(builder, 127));
  Value kMaxNegConst = bcast(i32Cst(builder, -127));
  Value rightBound =
      builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, k, kMaxConst);
  Value leftBound =
      builder.create<arith::CmpIOp>(arith::CmpIPredicate::sge, k, kMaxNegConst);

  Value isNegInfinityX = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::OEQ, x, constNegIfinity);
  Value isPosInfinityX = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::OEQ, x, constPosInfinity);
  Value isPostiveX =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zerof32Const);
  Value isComputable = builder.create<arith::AndIOp>(rightBound, leftBound);

  // Selection order: x == -inf -> 0, x == +inf -> +inf, k in range -> the
  // computed value, otherwise +inf for positive x and `underflow` for the
  // rest.
  expY = builder.create<SelectOp>(
      isNegInfinityX, zerof32Const,
      builder.create<SelectOp>(
          isPosInfinityX, constPosInfinity,
          builder.create<SelectOp>(isComputable, expY,
                                   builder.create<SelectOp>(isPostiveX,
                                                            constPosInfinity,
                                                            underflow))));

  rewriter.replaceOp(op, expY);

  return success();
}

//----------------------------------------------------------------------------//
// ExpM1 approximation.
810 //----------------------------------------------------------------------------// 811 812 namespace { 813 814 struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> { 815 public: 816 using OpRewritePattern::OpRewritePattern; 817 818 LogicalResult matchAndRewrite(math::ExpM1Op op, 819 PatternRewriter &rewriter) const final; 820 }; 821 } // namespace 822 823 LogicalResult 824 ExpM1Approximation::matchAndRewrite(math::ExpM1Op op, 825 PatternRewriter &rewriter) const { 826 if (!getElementTypeOrSelf(op.getOperand()).isF32()) 827 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 828 829 ArrayRef<int64_t> shape = vectorShape(op.getOperand()); 830 831 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 832 auto bcast = [&](Value value) -> Value { 833 return broadcast(builder, value, shape); 834 }; 835 836 // expm1(x) = exp(x) - 1 = u - 1. 837 // We have to handle it carefully when x is near 0, i.e. u ~= 1, 838 // and when the input is ~= -inf, i.e. u - 1 ~= -1. 839 Value cstOne = bcast(f32Cst(builder, 1.0f)); 840 Value cstNegOne = bcast(f32Cst(builder, -1.0f)); 841 Value x = op.getOperand(); 842 Value u = builder.create<math::ExpOp>(x); 843 Value uEqOne = 844 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne); 845 Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne); 846 Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>( 847 arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne); 848 // logU = log(u) ~= x 849 Value logU = builder.create<math::LogOp>(u); 850 851 // Detect exp(x) = +inf; written this way to avoid having to form +inf. 
852 Value isInf = 853 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u); 854 855 // (u - 1) * (x / ~x) 856 Value expm1 = builder.create<arith::MulFOp>( 857 uMinusOne, builder.create<arith::DivFOp>(x, logU)); 858 expm1 = builder.create<SelectOp>(isInf, u, expm1); 859 Value approximation = builder.create<SelectOp>( 860 uEqOne, x, builder.create<SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1)); 861 rewriter.replaceOp(op, approximation); 862 return success(); 863 } 864 865 //----------------------------------------------------------------------------// 866 // Sin and Cos approximation. 867 //----------------------------------------------------------------------------// 868 869 namespace { 870 871 template <bool isSine, typename OpTy> 872 struct SinAndCosApproximation : public OpRewritePattern<OpTy> { 873 public: 874 using OpRewritePattern<OpTy>::OpRewritePattern; 875 876 LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final; 877 }; 878 } // namespace 879 880 #define TWO_OVER_PI \ 881 0.6366197723675813430755350534900574481378385829618257949906693762L 882 #define PI_OVER_2 \ 883 1.5707963267948966192313216916397514420985846996875529104874722961L 884 885 // Approximates sin(x) or cos(x) by finding the best approximation polynomial in 886 // the reduced range [0, pi/2] for both sin(x) and cos(x). Then given y in the 887 // reduced range sin(x) will be computed as sin(y), -sin(y), cos(y) or -cos(y). 
888 template <bool isSine, typename OpTy> 889 LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite( 890 OpTy op, PatternRewriter &rewriter) const { 891 static_assert( 892 llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value, 893 "SinAndCosApproximation pattern expects math::SinOp or math::CosOp"); 894 895 if (!getElementTypeOrSelf(op.getOperand()).isF32()) 896 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 897 898 ArrayRef<int64_t> shape = vectorShape(op.getOperand()); 899 900 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 901 auto bcast = [&](Value value) -> Value { 902 return broadcast(builder, value, shape); 903 }; 904 auto mul = [&](Value a, Value b) -> Value { 905 return builder.create<arith::MulFOp>(a, b); 906 }; 907 auto sub = [&](Value a, Value b) -> Value { 908 return builder.create<arith::SubFOp>(a, b); 909 }; 910 auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); }; 911 912 auto i32Vec = broadcast(builder.getI32Type(), shape); 913 auto fPToSingedInteger = [&](Value a) -> Value { 914 return builder.create<arith::FPToSIOp>(a, i32Vec); 915 }; 916 917 auto modulo4 = [&](Value a) -> Value { 918 return builder.create<arith::AndIOp>(a, bcast(i32Cst(builder, 3))); 919 }; 920 921 auto isEqualTo = [&](Value a, Value b) -> Value { 922 return builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b); 923 }; 924 925 auto isGreaterThan = [&](Value a, Value b) -> Value { 926 return builder.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b); 927 }; 928 929 auto select = [&](Value cond, Value t, Value f) -> Value { 930 return builder.create<SelectOp>(cond, t, f); 931 }; 932 933 auto fmla = [&](Value a, Value b, Value c) { 934 return builder.create<math::FmaOp>(a, b, c); 935 }; 936 937 auto bitwiseOr = [&](Value a, Value b) { 938 return builder.create<arith::OrIOp>(a, b); 939 }; 940 941 Value twoOverPi = bcast(f32Cst(builder, TWO_OVER_PI)); 942 Value piOverTwo = bcast(f32Cst(builder, PI_OVER_2)); 943 
944 Value x = op.getOperand(); 945 946 Value k = floor(mul(x, twoOverPi)); 947 948 Value y = sub(x, mul(k, piOverTwo)); 949 950 Value cstOne = bcast(f32Cst(builder, 1.0)); 951 Value cstNegativeOne = bcast(f32Cst(builder, -1.0)); 952 953 Value cstSC2 = bcast(f32Cst(builder, -0.16666667163372039794921875f)); 954 Value cstSC4 = bcast(f32Cst(builder, 8.333347737789154052734375e-3f)); 955 Value cstSC6 = bcast(f32Cst(builder, -1.9842604524455964565277099609375e-4f)); 956 Value cstSC8 = 957 bcast(f32Cst(builder, 2.760012648650445044040679931640625e-6f)); 958 Value cstSC10 = 959 bcast(f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f)); 960 961 Value cstCC2 = bcast(f32Cst(builder, -0.5f)); 962 Value cstCC4 = bcast(f32Cst(builder, 4.166664183139801025390625e-2f)); 963 Value cstCC6 = bcast(f32Cst(builder, -1.388833043165504932403564453125e-3f)); 964 Value cstCC8 = bcast(f32Cst(builder, 2.47562347794882953166961669921875e-5f)); 965 Value cstCC10 = 966 bcast(f32Cst(builder, -2.59630184018533327616751194000244140625e-7f)); 967 968 Value kMod4 = modulo4(fPToSingedInteger(k)); 969 970 Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, 0))); 971 Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, 1))); 972 Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, 2))); 973 Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, 3))); 974 975 Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2); 976 Value negativeRange = isSine ? 
isGreaterThan(kMod4, bcast(i32Cst(builder, 1))) 977 : bitwiseOr(kR1, kR2); 978 979 Value y2 = mul(y, y); 980 981 Value base = select(sinuseCos, cstOne, y); 982 Value cstC2 = select(sinuseCos, cstCC2, cstSC2); 983 Value cstC4 = select(sinuseCos, cstCC4, cstSC4); 984 Value cstC6 = select(sinuseCos, cstCC6, cstSC6); 985 Value cstC8 = select(sinuseCos, cstCC8, cstSC8); 986 Value cstC10 = select(sinuseCos, cstCC10, cstSC10); 987 988 Value v1 = fmla(y2, cstC10, cstC8); 989 Value v2 = fmla(y2, v1, cstC6); 990 Value v3 = fmla(y2, v2, cstC4); 991 Value v4 = fmla(y2, v3, cstC2); 992 Value v5 = fmla(y2, v4, cstOne); 993 Value v6 = mul(base, v5); 994 995 Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6); 996 997 rewriter.replaceOp(op, approximation); 998 999 return success(); 1000 } 1001 1002 //----------------------------------------------------------------------------// 1003 // Rsqrt approximation. 1004 //----------------------------------------------------------------------------// 1005 1006 namespace { 1007 struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> { 1008 using OpRewritePattern::OpRewritePattern; 1009 1010 LogicalResult matchAndRewrite(math::RsqrtOp op, 1011 PatternRewriter &rewriter) const final; 1012 }; 1013 } // namespace 1014 1015 LogicalResult 1016 RsqrtApproximation::matchAndRewrite(math::RsqrtOp op, 1017 PatternRewriter &rewriter) const { 1018 if (!getElementTypeOrSelf(op.getOperand()).isF32()) 1019 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 1020 1021 ArrayRef<int64_t> shape = vectorShape(op.getOperand()); 1022 1023 // Only support already-vectorized rsqrt's. 
1024 if (shape.empty() || shape.back() % 8 != 0) 1025 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 1026 1027 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 1028 auto bcast = [&](Value value) -> Value { 1029 return broadcast(builder, value, shape); 1030 }; 1031 1032 Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u)); 1033 Value cstOnePointFive = bcast(f32Cst(builder, 1.5f)); 1034 Value cstNegHalf = bcast(f32Cst(builder, -0.5f)); 1035 Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u)); 1036 1037 Value negHalf = builder.create<arith::MulFOp>(op.getOperand(), cstNegHalf); 1038 1039 // Select only the inverse sqrt of positive normals (denormals are 1040 // flushed to zero). 1041 Value ltMinMask = builder.create<arith::CmpFOp>( 1042 arith::CmpFPredicate::OLT, op.getOperand(), cstMinNormPos); 1043 Value infMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, 1044 op.getOperand(), cstPosInf); 1045 Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask); 1046 1047 // Compute an approximate result. 1048 Value yApprox = handleMultidimensionalVectors( 1049 builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value { 1050 return builder.create<x86vector::RsqrtOp>(operands); 1051 }); 1052 1053 // Do a single step of Newton-Raphson iteration to improve the approximation. 1054 // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n). 1055 // It is essential to evaluate the inner term like this because forming 1056 // y_n^2 may over- or underflow. 1057 Value inner = builder.create<arith::MulFOp>(negHalf, yApprox); 1058 Value fma = builder.create<math::FmaOp>(yApprox, inner, cstOnePointFive); 1059 Value yNewton = builder.create<arith::MulFOp>(yApprox, fma); 1060 1061 // Select the result of the Newton-Raphson step for positive normal arguments. 1062 // For other arguments, choose the output of the intrinsic. 
This will 1063 // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if 1064 // x is zero or a positive denormalized float (equivalent to flushing positive 1065 // denormalized inputs to zero). 1066 Value res = builder.create<SelectOp>(notNormalFiniteMask, yApprox, yNewton); 1067 rewriter.replaceOp(op, res); 1068 1069 return success(); 1070 } 1071 1072 //----------------------------------------------------------------------------// 1073 1074 void mlir::populateMathPolynomialApproximationPatterns( 1075 RewritePatternSet &patterns, 1076 const MathPolynomialApproximationOptions &options) { 1077 patterns.add<TanhApproximation, LogApproximation, Log2Approximation, 1078 Log1pApproximation, ErfPolynomialApproximation, ExpApproximation, 1079 ExpM1Approximation, SinAndCosApproximation<true, math::SinOp>, 1080 SinAndCosApproximation<false, math::CosOp>>( 1081 patterns.getContext()); 1082 if (options.enableAvx2) 1083 patterns.add<RsqrtApproximation>(patterns.getContext()); 1084 } 1085