1 //===- PolynomialApproximation.cpp - Approximate math operations ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements expansion of math operations to fast approximations 10 // that do not rely on any of the library functions. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include <climits> 15 #include <cstddef> 16 17 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 18 #include "mlir/Dialect/Math/IR/Math.h" 19 #include "mlir/Dialect/Math/Transforms/Approximation.h" 20 #include "mlir/Dialect/Math/Transforms/Passes.h" 21 #include "mlir/Dialect/Vector/VectorOps.h" 22 #include "mlir/Dialect/Vector/VectorUtils.h" 23 #include "mlir/Dialect/X86Vector/X86VectorDialect.h" 24 #include "mlir/IR/Builders.h" 25 #include "mlir/IR/ImplicitLocOpBuilder.h" 26 #include "mlir/IR/TypeUtilities.h" 27 #include "mlir/Transforms/DialectConversion.h" 28 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 29 #include "llvm/ADT/ArrayRef.h" 30 31 using namespace mlir; 32 using namespace mlir::math; 33 using namespace mlir::vector; 34 35 // Returns vector shape if the type is a vector. Returns an empty shape if it is 36 // not a vector. 37 static ArrayRef<int64_t> vectorShape(Type type) { 38 auto vectorType = type.dyn_cast<VectorType>(); 39 return vectorType ? vectorType.getShape() : ArrayRef<int64_t>(); 40 } 41 42 static ArrayRef<int64_t> vectorShape(Value value) { 43 return vectorShape(value.getType()); 44 } 45 46 //----------------------------------------------------------------------------// 47 // Broadcast scalar types and values into vector types and values. 
48 //----------------------------------------------------------------------------// 49 50 // Broadcasts scalar type into vector type (iff shape is non-scalar). 51 static Type broadcast(Type type, ArrayRef<int64_t> shape) { 52 assert(!type.isa<VectorType>() && "must be scalar type"); 53 return !shape.empty() ? VectorType::get(shape, type) : type; 54 } 55 56 // Broadcasts scalar value into vector (iff shape is non-scalar). 57 static Value broadcast(ImplicitLocOpBuilder &builder, Value value, 58 ArrayRef<int64_t> shape) { 59 assert(!value.getType().isa<VectorType>() && "must be scalar value"); 60 auto type = broadcast(value.getType(), shape); 61 return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value; 62 } 63 64 //----------------------------------------------------------------------------// 65 // Helper function to handle n-D vectors with 1-D operations. 66 //----------------------------------------------------------------------------// 67 68 // Expands and unrolls n-D vector operands into multiple fixed size 1-D vectors 69 // and calls the compute function with 1-D vector operands. Stitches back all 70 // results into the original n-D vector result. 71 // 72 // Examples: vectorWidth = 8 73 // - vector<4x8xf32> unrolled 4 times 74 // - vector<16xf32> expanded to vector<2x8xf32> and unrolled 2 times 75 // - vector<4x16xf32> expanded to vector<4x2x8xf32> and unrolled 4*2 times 76 // 77 // Some math approximations rely on ISA-specific operations that only accept 78 // fixed size 1-D vectors (e.g. AVX expects vectors of width 8). 79 // 80 // It is the caller's responsibility to verify that the inner dimension is 81 // divisible by the vectorWidth, and that all operands have the same vector 82 // shape. 
83 static Value 84 handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, 85 ValueRange operands, int64_t vectorWidth, 86 llvm::function_ref<Value(ValueRange)> compute) { 87 assert(!operands.empty() && "operands must be not empty"); 88 assert(vectorWidth > 0 && "vector width must be larger than 0"); 89 90 VectorType inputType = operands[0].getType().cast<VectorType>(); 91 ArrayRef<int64_t> inputShape = inputType.getShape(); 92 93 // If input shape matches target vector width, we can just call the 94 // user-provided compute function with the operands. 95 if (inputShape == llvm::makeArrayRef(vectorWidth)) 96 return compute(operands); 97 98 // Check if the inner dimension has to be expanded, or we can directly iterate 99 // over the outer dimensions of the vector. 100 int64_t innerDim = inputShape.back(); 101 int64_t expansionDim = innerDim / vectorWidth; 102 assert((innerDim % vectorWidth == 0) && "invalid inner dimension size"); 103 104 // Maybe expand operands to the higher rank vector shape that we'll use to 105 // iterate over and extract one dimensional vectors. 106 SmallVector<int64_t> expandedShape(inputShape.begin(), inputShape.end()); 107 SmallVector<Value> expandedOperands(operands); 108 109 if (expansionDim > 1) { 110 // Expand shape from [..., innerDim] to [..., expansionDim, vectorWidth]. 111 expandedShape.insert(expandedShape.end() - 1, expansionDim); 112 expandedShape.back() = vectorWidth; 113 114 for (unsigned i = 0; i < operands.size(); ++i) { 115 auto operand = operands[i]; 116 auto eltType = operand.getType().cast<VectorType>().getElementType(); 117 auto expandedType = VectorType::get(expandedShape, eltType); 118 expandedOperands[i] = 119 builder.create<vector::ShapeCastOp>(expandedType, operand); 120 } 121 } 122 123 // Iterate over all outer dimensions of the compute shape vector type. 
124 auto iterationDims = ArrayRef<int64_t>(expandedShape).drop_back(); 125 int64_t maxLinearIndex = computeMaxLinearIndex(iterationDims); 126 127 SmallVector<int64_t> ones(iterationDims.size(), 1); 128 auto strides = computeStrides(iterationDims, ones); 129 130 // Compute results for each one dimensional vector. 131 SmallVector<Value> results(maxLinearIndex); 132 133 for (int64_t i = 0; i < maxLinearIndex; ++i) { 134 auto offsets = delinearize(strides, i); 135 136 SmallVector<Value> extracted(expandedOperands.size()); 137 for (const auto &tuple : llvm::enumerate(expandedOperands)) 138 extracted[tuple.index()] = 139 builder.create<vector::ExtractOp>(tuple.value(), offsets); 140 141 results[i] = compute(extracted); 142 } 143 144 // Stitch results together into one large vector. 145 Type resultEltType = results[0].getType().cast<VectorType>().getElementType(); 146 Type resultExpandedType = VectorType::get(expandedShape, resultEltType); 147 Value result = builder.create<ConstantOp>( 148 resultExpandedType, builder.getZeroAttr(resultExpandedType)); 149 150 for (int64_t i = 0; i < maxLinearIndex; ++i) 151 result = builder.create<vector::InsertOp>(results[i], result, 152 delinearize(strides, i)); 153 154 // Reshape back to the original vector shape. 155 return builder.create<vector::ShapeCastOp>( 156 VectorType::get(inputShape, resultEltType), result); 157 } 158 159 //----------------------------------------------------------------------------// 160 // Helper functions to create constants. 
161 //----------------------------------------------------------------------------// 162 163 static Value f32Cst(ImplicitLocOpBuilder &builder, float value) { 164 return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value)); 165 } 166 167 static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) { 168 return builder.create<arith::ConstantOp>(builder.getI32IntegerAttr(value)); 169 } 170 171 static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) { 172 Value i32Value = i32Cst(builder, static_cast<int32_t>(bits)); 173 return builder.create<arith::BitcastOp>(builder.getF32Type(), i32Value); 174 } 175 176 //----------------------------------------------------------------------------// 177 // Helper functions to build math functions approximations. 178 //----------------------------------------------------------------------------// 179 180 static Value min(ImplicitLocOpBuilder &builder, Value a, Value b) { 181 return builder.create<SelectOp>( 182 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, a, b), a, b); 183 } 184 185 static Value max(ImplicitLocOpBuilder &builder, Value a, Value b) { 186 return builder.create<SelectOp>( 187 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, a, b), a, b); 188 } 189 190 static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, 191 Value upperBound) { 192 return max(builder, min(builder, value, upperBound), lowerBound); 193 } 194 195 // Decomposes given floating point value `arg` into a normalized fraction and 196 // an integral power of two (see std::frexp). Returned values have float type. 
197 static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg, 198 bool isPositive = false) { 199 assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type"); 200 ArrayRef<int64_t> shape = vectorShape(arg); 201 202 auto bcast = [&](Value value) -> Value { 203 return broadcast(builder, value, shape); 204 }; 205 206 auto i32 = builder.getIntegerType(32); 207 auto i32Vec = broadcast(i32, shape); 208 auto f32Vec = broadcast(builder.getF32Type(), shape); 209 210 Value cst126f = f32Cst(builder, 126.0f); 211 Value cstHalf = f32Cst(builder, 0.5f); 212 Value cstInvMantMask = f32FromBits(builder, ~0x7f800000u); 213 214 // Bitcast to i32 for bitwise operations. 215 Value i32Half = builder.create<arith::BitcastOp>(i32, cstHalf); 216 Value i32InvMantMask = builder.create<arith::BitcastOp>(i32, cstInvMantMask); 217 Value i32Arg = builder.create<arith::BitcastOp>(i32Vec, arg); 218 219 // Compute normalized fraction. 220 Value tmp0 = builder.create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask)); 221 Value tmp1 = builder.create<arith::OrIOp>(tmp0, bcast(i32Half)); 222 Value normalizedFraction = builder.create<arith::BitcastOp>(f32Vec, tmp1); 223 224 // Compute exponent. 225 Value arg0 = isPositive ? arg : builder.create<math::AbsOp>(arg); 226 Value biasedExponentBits = builder.create<arith::ShRUIOp>( 227 builder.create<arith::BitcastOp>(i32Vec, arg0), 228 bcast(i32Cst(builder, 23))); 229 Value biasedExponent = 230 builder.create<arith::SIToFPOp>(f32Vec, biasedExponentBits); 231 Value exponent = 232 builder.create<arith::SubFOp>(biasedExponent, bcast(cst126f)); 233 234 return {normalizedFraction, exponent}; 235 } 236 237 // Computes exp2 for an i32 argument. 
238 static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) { 239 assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type"); 240 ArrayRef<int64_t> shape = vectorShape(arg); 241 242 auto bcast = [&](Value value) -> Value { 243 return broadcast(builder, value, shape); 244 }; 245 246 auto f32Vec = broadcast(builder.getF32Type(), shape); 247 // The exponent of f32 located at 23-bit. 248 auto exponetBitLocation = bcast(i32Cst(builder, 23)); 249 // Set the exponent bias to zero. 250 auto bias = bcast(i32Cst(builder, 127)); 251 252 Value biasedArg = builder.create<arith::AddIOp>(arg, bias); 253 Value exp2ValueInt = 254 builder.create<arith::ShLIOp>(biasedArg, exponetBitLocation); 255 Value exp2ValueF32 = builder.create<arith::BitcastOp>(f32Vec, exp2ValueInt); 256 257 return exp2ValueF32; 258 } 259 260 namespace { 261 Value makePolynomialCalculation(ImplicitLocOpBuilder &builder, 262 llvm::ArrayRef<Value> coeffs, Value x) { 263 assert(getElementTypeOrSelf(x).isF32() && "x must be f32 type"); 264 ArrayRef<int64_t> shape = vectorShape(x); 265 266 if (coeffs.empty()) 267 return broadcast(builder, f32Cst(builder, 0.0f), shape); 268 269 if (coeffs.size() == 1) 270 return coeffs[0]; 271 272 Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1], 273 coeffs[coeffs.size() - 2]); 274 for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) { 275 res = builder.create<math::FmaOp>(x, res, coeffs[i]); 276 } 277 return res; 278 } 279 } // namespace 280 281 //----------------------------------------------------------------------------// 282 // AtanOp approximation. 
283 //----------------------------------------------------------------------------// 284 285 namespace { 286 struct AtanApproximation : public OpRewritePattern<math::AtanOp> { 287 public: 288 using OpRewritePattern::OpRewritePattern; 289 290 LogicalResult matchAndRewrite(math::AtanOp op, 291 PatternRewriter &rewriter) const final; 292 }; 293 } // namespace 294 295 LogicalResult 296 AtanApproximation::matchAndRewrite(math::AtanOp op, 297 PatternRewriter &rewriter) const { 298 auto operand = op.getOperand(); 299 if (!getElementTypeOrSelf(operand).isF32()) 300 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 301 302 ArrayRef<int64_t> shape = vectorShape(op.getOperand()); 303 304 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 305 auto one = broadcast(builder, f32Cst(builder, 1.0f), shape); 306 307 // Remap the problem over [0.0, 1.0] by looking at the absolute value and the 308 // handling symmetry. 309 Value abs = builder.create<math::AbsOp>(operand); 310 Value reciprocal = builder.create<arith::DivFOp>(one, abs); 311 Value compare = 312 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, abs, reciprocal); 313 Value x = builder.create<SelectOp>(compare, abs, reciprocal); 314 315 // Perform the Taylor series approximation for atan over the range 316 // [-1.0, 1.0]. 
317 auto n1 = broadcast(builder, f32Cst(builder, 0.14418283), shape); 318 auto n2 = broadcast(builder, f32Cst(builder, -0.34999234), shape); 319 auto n3 = broadcast(builder, f32Cst(builder, -0.01067831), shape); 320 auto n4 = broadcast(builder, f32Cst(builder, 1.00209986), shape); 321 322 Value p = builder.create<math::FmaOp>(x, n1, n2); 323 p = builder.create<math::FmaOp>(x, p, n3); 324 p = builder.create<math::FmaOp>(x, p, n4); 325 p = builder.create<arith::MulFOp>(x, p); 326 327 // Remap the solution for over [0.0, 1.0] to [0.0, inf] 328 auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape); 329 Value sub = builder.create<arith::SubFOp>(halfPi, p); 330 Value select = builder.create<SelectOp>(compare, p, sub); 331 332 // Correct for signing of the input. 333 rewriter.replaceOpWithNewOp<math::CopySignOp>(op, select, operand); 334 return success(); 335 } 336 337 //----------------------------------------------------------------------------// 338 // AtanOp approximation. 339 //----------------------------------------------------------------------------// 340 341 namespace { 342 struct Atan2Approximation : public OpRewritePattern<math::Atan2Op> { 343 public: 344 using OpRewritePattern::OpRewritePattern; 345 346 LogicalResult matchAndRewrite(math::Atan2Op op, 347 PatternRewriter &rewriter) const final; 348 }; 349 } // namespace 350 351 LogicalResult 352 Atan2Approximation::matchAndRewrite(math::Atan2Op op, 353 PatternRewriter &rewriter) const { 354 auto y = op.getOperand(0); 355 auto x = op.getOperand(1); 356 if (!getElementTypeOrSelf(x).isF32()) 357 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 358 359 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 360 ArrayRef<int64_t> shape = vectorShape(op.getResult()); 361 362 // Compute atan in the valid range. 363 auto div = builder.create<arith::DivFOp>(y, x); 364 auto atan = builder.create<math::AtanOp>(div); 365 366 // Determine what the atan would be for a 180 degree rotation. 
367 auto zero = broadcast(builder, f32Cst(builder, 0.0f), shape); 368 auto pi = broadcast(builder, f32Cst(builder, 3.14159265359f), shape); 369 auto addPi = builder.create<arith::AddFOp>(atan, pi); 370 auto subPi = builder.create<arith::SubFOp>(atan, pi); 371 auto atanGt = 372 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, atan, zero); 373 auto flippedAtan = builder.create<SelectOp>(atanGt, subPi, addPi); 374 375 // Determine whether to directly use atan or use the 180 degree flip 376 auto xGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zero); 377 Value result = builder.create<SelectOp>(xGt, atan, flippedAtan); 378 379 // Handle x = 0, y > 0 380 Value xZero = 381 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, x, zero); 382 Value yGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, y, zero); 383 Value isHalfPi = builder.create<arith::AndIOp>(xZero, yGt); 384 auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape); 385 result = builder.create<SelectOp>(isHalfPi, halfPi, result); 386 387 // Handle x = 0, y < 0 388 Value yLt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, y, zero); 389 Value isNegativeHalfPiPi = builder.create<arith::AndIOp>(xZero, yLt); 390 auto negativeHalfPiPi = 391 broadcast(builder, f32Cst(builder, -1.57079632679), shape); 392 result = 393 builder.create<SelectOp>(isNegativeHalfPiPi, negativeHalfPiPi, result); 394 395 // Handle x = 0, y = 0; 396 Value yZero = 397 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, y, zero); 398 Value isNan = builder.create<arith::AndIOp>(xZero, yZero); 399 Value cstNan = broadcast(builder, f32FromBits(builder, 0x7fc00000), shape); 400 result = builder.create<SelectOp>(isNan, cstNan, result); 401 402 rewriter.replaceOp(op, result); 403 return success(); 404 } 405 406 //----------------------------------------------------------------------------// 407 // TanhOp approximation. 
//----------------------------------------------------------------------------//

namespace {
struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::TanhOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Lowers f32 math.tanh (scalar or vector) to a rational approximation
// p(x) / q(x) evaluated on the clamped operand x. Here p is an odd polynomial
// of degree 13 (alpha1..alpha13) and q is an even polynomial of degree 6
// (beta0..beta6). Inputs with |operand| < 0.0004 bypass the rational function
// and return x directly.
LogicalResult
TanhApproximation::matchAndRewrite(math::TanhOp op,
                                   PatternRewriter &rewriter) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // Clamp operand into the [minusClamp, plusClamp] range; the rational
  // approximation is only evaluated on this interval.
  Value minusClamp = bcast(f32Cst(builder, -7.99881172180175781f));
  Value plusClamp = bcast(f32Cst(builder, 7.99881172180175781f));
  Value x = clamp(builder, op.getOperand(), minusClamp, plusClamp);

  // Mask for tiny values (|operand| < 0.0004) where tanh(x) ~= x; these are
  // returned as x instead of p/q.
  Value tiny = bcast(f32Cst(builder, 0.0004f));
  Value tinyMask = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::OLT, builder.create<math::AbsOp>(op.getOperand()),
      tiny);

  // The monomial coefficients of the numerator polynomial (odd).
  Value alpha1 = bcast(f32Cst(builder, 4.89352455891786e-03f));
  Value alpha3 = bcast(f32Cst(builder, 6.37261928875436e-04f));
  Value alpha5 = bcast(f32Cst(builder, 1.48572235717979e-05f));
  Value alpha7 = bcast(f32Cst(builder, 5.12229709037114e-08f));
  Value alpha9 = bcast(f32Cst(builder, -8.60467152213735e-11f));
  Value alpha11 = bcast(f32Cst(builder, 2.00018790482477e-13f));
  Value alpha13 = bcast(f32Cst(builder, -2.76076847742355e-16f));

  // The monomial coefficients of the denominator polynomial (even).
  Value beta0 = bcast(f32Cst(builder, 4.89352518554385e-03f));
  Value beta2 = bcast(f32Cst(builder, 2.26843463243900e-03f));
  Value beta4 = bcast(f32Cst(builder, 1.18534705686654e-04f));
  Value beta6 = bcast(f32Cst(builder, 1.19825839466702e-06f));

  // Since the polynomials are odd/even, we need x^2.
  Value x2 = builder.create<arith::MulFOp>(x, x);

  // Evaluate the numerator polynomial p with Horner's scheme in x^2, then
  // multiply by x to make it odd.
  Value p = builder.create<math::FmaOp>(x2, alpha13, alpha11);
  p = builder.create<math::FmaOp>(x2, p, alpha9);
  p = builder.create<math::FmaOp>(x2, p, alpha7);
  p = builder.create<math::FmaOp>(x2, p, alpha5);
  p = builder.create<math::FmaOp>(x2, p, alpha3);
  p = builder.create<math::FmaOp>(x2, p, alpha1);
  p = builder.create<arith::MulFOp>(x, p);

  // Evaluate the denominator polynomial q with Horner's scheme in x^2.
  Value q = builder.create<math::FmaOp>(x2, beta6, beta4);
  q = builder.create<math::FmaOp>(x2, q, beta2);
  q = builder.create<math::FmaOp>(x2, q, beta0);

  // Divide the numerator by the denominator, except for tiny inputs, which
  // use x directly.
  Value res = builder.create<SelectOp>(tinyMask, x,
                                       builder.create<arith::DivFOp>(p, q));

  rewriter.replaceOp(op, res);

  return success();
}

// ln(2) and log2(e) as long double literals; use sites cast them to float.
#define LN2_VALUE                                                              \
  0.693147180559945309417232121458176568075500134360255254120680009493393621L
#define LOG2E_VALUE                                                            \
  1.442695040888963407359924681001892137426645954152985934135449406931109219L

//----------------------------------------------------------------------------//
// LogOp and Log2Op approximation.
//----------------------------------------------------------------------------//

namespace {
// Shared implementation for the natural and base-2 logarithm patterns.
template <typename Op>
struct LogApproximationBase : public OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;

  /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise.
  LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter,
                                   bool base2) const;
};
} // namespace

// This approximation comes from Julien Pommier's SSE math library.
// Link: http://gruntthepeon.free.fr/ssemath
//
// Outline: write x = m * 2^e, with m shifted into [sqrt(1/2), sqrt(2)), then
// approximate log(m) with a degree-8 polynomial in (m - 1) and add e * ln(2).
// For base 2, the result is instead scaled by log2(e) and e is added.
// Special inputs (0, negative/NaN, +inf) are patched at the end.
template <typename Op>
LogicalResult
LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
                                             bool base2) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  Value cstZero = bcast(f32Cst(builder, 0.0f));
  Value cstOne = bcast(f32Cst(builder, 1.0f));
  Value cstNegHalf = bcast(f32Cst(builder, -0.5f));

  // The smallest non denormalized float number (0x00800000, ~1.18e-38).
  Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
  // -inf, +inf and a quiet NaN, used to patch special inputs below.
  Value cstMinusInf = bcast(f32FromBits(builder, 0xff800000u));
  Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
  Value cstNan = bcast(f32FromBits(builder, 0x7fc00000));

  // Polynomial coefficients (Cephes logf), plus sqrt(1/2) for range reduction.
  Value cstCephesSQRTHF = bcast(f32Cst(builder, 0.707106781186547524f));
  Value cstCephesLogP0 = bcast(f32Cst(builder, 7.0376836292E-2f));
  Value cstCephesLogP1 = bcast(f32Cst(builder, -1.1514610310E-1f));
  Value cstCephesLogP2 = bcast(f32Cst(builder, 1.1676998740E-1f));
  Value cstCephesLogP3 = bcast(f32Cst(builder, -1.2420140846E-1f));
  Value cstCephesLogP4 = bcast(f32Cst(builder, +1.4249322787E-1f));
  Value cstCephesLogP5 = bcast(f32Cst(builder, -1.6668057665E-1f));
  Value cstCephesLogP6 = bcast(f32Cst(builder, +2.0000714765E-1f));
  Value cstCephesLogP7 = bcast(f32Cst(builder, -2.4999993993E-1f));
  Value cstCephesLogP8 = bcast(f32Cst(builder, +3.3333331174E-1f));

  Value x = op.getOperand();

  // Truncate input values to the minimum positive normal. As a result,
  // denormal inputs are evaluated as if they were the smallest normal.
  x = max(builder, x, cstMinNormPos);

  // Extract the significand in the range [0.5,1) and the exponent.
  std::pair<Value, Value> pair = frexp(builder, x, /*isPositive=*/true);
  x = pair.first;
  Value e = pair.second;

  // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift
  // by -1.0. The values are then centered around 0, which improves the
  // stability of the polynomial evaluation:
  //
  //   if( x < SQRTHF ) {
  //     e -= 1;
  //     x = x + x - 1.0;
  //   } else { x = x - 1.0; }
  Value mask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x,
                                             cstCephesSQRTHF);
  Value tmp = builder.create<SelectOp>(mask, x, cstZero);

  x = builder.create<arith::SubFOp>(x, cstOne);
  e = builder.create<arith::SubFOp>(
      e, builder.create<SelectOp>(mask, cstOne, cstZero));
  x = builder.create<arith::AddFOp>(x, tmp);

  Value x2 = builder.create<arith::MulFOp>(x, x);
  Value x3 = builder.create<arith::MulFOp>(x2, x);

  // Evaluate the polynomial approximant of degree 8 in three parts:
  //   y0 = P0*x^2 + P1*x + P2,  y1 = P3*x^2 + P4*x + P5,
  //   y2 = P6*x^2 + P7*x + P8,
  // combined as ((y0*x^3 + y1)*x^3 + y2)*x^3. The three quadratics are
  // independent, which shortens the dependency chain compared to plain Horner.
  Value y0, y1, y2;
  y0 = builder.create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1);
  y1 = builder.create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4);
  y2 = builder.create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7);
  y0 = builder.create<math::FmaOp>(y0, x, cstCephesLogP2);
  y1 = builder.create<math::FmaOp>(y1, x, cstCephesLogP5);
  y2 = builder.create<math::FmaOp>(y2, x, cstCephesLogP8);
  y0 = builder.create<math::FmaOp>(y0, x3, y1);
  y0 = builder.create<math::FmaOp>(y0, x3, y2);
  y0 = builder.create<arith::MulFOp>(y0, x3);

  // log(1 + x) ~= x - x^2/2 + x^3 * poly(x).
  y0 = builder.create<math::FmaOp>(cstNegHalf, x2, y0);
  x = builder.create<arith::AddFOp>(x, y0);

  // Recombine with the exponent:
  //   log2: log(m) * log2(e) + e
  //   ln:   e * ln(2) + log(m)
  if (base2) {
    Value cstLog2e = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
    x = builder.create<math::FmaOp>(x, cstLog2e, e);
  } else {
    Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
    x = builder.create<math::FmaOp>(e, cstLn2, x);
  }

  // ULT is true for operand < 0 and also for NaN operands.
  Value invalidMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
                                                    op.getOperand(), cstZero);
  Value zeroMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
                                                 op.getOperand(), cstZero);
  Value posInfMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
                                                   op.getOperand(), cstPosInf);

  // Filter out invalid values:
  //  - x == 0     -> -INF
  //  - x < 0      ->  NAN  (also for NaN inputs)
  //  - x == +INF  -> +INF
  Value aproximation = builder.create<SelectOp>(
      zeroMask, cstMinusInf,
      builder.create<SelectOp>(
          invalidMask, cstNan,
          builder.create<SelectOp>(posInfMask, cstPosInf, x)));

  rewriter.replaceOp(op, aproximation);

  return success();
}

namespace {
// math.log (natural logarithm) lowering; see logMatchAndRewrite.
struct LogApproximation : public LogApproximationBase<math::LogOp> {
  using LogApproximationBase::LogApproximationBase;

  LogicalResult matchAndRewrite(math::LogOp op,
                                PatternRewriter &rewriter) const final {
    return logMatchAndRewrite(op, rewriter, /*base2=*/false);
  }
};
} // namespace

namespace {
// math.log2 lowering; see logMatchAndRewrite.
struct Log2Approximation : public LogApproximationBase<math::Log2Op> {
  using LogApproximationBase::LogApproximationBase;

  LogicalResult matchAndRewrite(math::Log2Op op,
                                PatternRewriter &rewriter) const final {
    return logMatchAndRewrite(op, rewriter, /*base2=*/true);
  }
};
} // namespace

//----------------------------------------------------------------------------//
// Log1p approximation.
//----------------------------------------------------------------------------//

namespace {
struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::Log1pOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Approximate log(1+x) for f32 operands by emitting a math.log of (1 + x)
// followed by Kahan's correction (see below); the math.log is expected to be
// lowered by the LogOp pattern above.
LogicalResult
Log1pApproximation::matchAndRewrite(math::Log1pOp op,
                                    PatternRewriter &rewriter) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // Approximate log(1+x) using the following, due to W. Kahan:
  //   u = x + 1.0;
  //   if (u == 1.0 || u == inf) return x;
  //   return x * log(u) / (u - 1.0);
  //          ^^^^^^^^^^^^^^^^^^^^^^
  //             "logLarge" below.
  // The ratio x / (u - 1.0) compensates for the rounding error made when
  // forming u = x + 1.0.
  Value cstOne = bcast(f32Cst(builder, 1.0f));
  Value x = op.getOperand();
  Value u = builder.create<arith::AddFOp>(x, cstOne);
  Value uSmall =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
  Value logU = builder.create<math::LogOp>(u);
  // u == log(u) holds only for u == +inf, so this detects u == inf without
  // materializing an infinity constant.
  Value uInf =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU);
  Value logLarge = builder.create<arith::MulFOp>(
      x, builder.create<arith::DivFOp>(
             logU, builder.create<arith::SubFOp>(u, cstOne)));
  Value approximation = builder.create<SelectOp>(
      builder.create<arith::OrIOp>(uSmall, uInf), x, logLarge);
  rewriter.replaceOp(op, approximation);
  return success();
}

//----------------------------------------------------------------------------//
// Erf approximation.
//----------------------------------------------------------------------------//

// Approximates erf(x) with
//   offset + P(x)/Q(x)
// where P and Q are polynomials of degree 4.
// Different coefficients and offsets are chosen based on the value of |x|;
// the result for negative x follows from erf being odd.
// The approximation error is ~2.5e-07.
// Boost's minimax tool that utilizes the Remez method was used to find the
// coefficients.
// ErfPolynomialApproximation is not declared in this file; presumably it is
// declared in the included mlir/Dialect/Math/Transforms/Approximation.h —
// confirm.
LogicalResult
ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
                                            PatternRewriter &rewriter) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  const int intervalsCount = 3;
  const int polyDegree = 4;

  Value zero = bcast(f32Cst(builder, 0));
  Value one = bcast(f32Cst(builder, 1));
  // Numerator coefficients per interval, constant term first (this is the
  // order makePolynomialCalculation expects).
  Value pp[intervalsCount][polyDegree + 1];
  pp[0][0] = bcast(f32Cst(builder, +0.00000000000000000e+00f));
  pp[0][1] = bcast(f32Cst(builder, +1.12837916222975858e+00f));
  pp[0][2] = bcast(f32Cst(builder, -5.23018562988006470e-01f));
  pp[0][3] = bcast(f32Cst(builder, +2.09741709609267072e-01f));
  pp[0][4] = bcast(f32Cst(builder, +2.58146801602987875e-02f));
  pp[1][0] = bcast(f32Cst(builder, +0.00000000000000000e+00f));
  pp[1][1] = bcast(f32Cst(builder, +1.12750687816789140e+00f));
  pp[1][2] = bcast(f32Cst(builder, -3.64721408487825775e-01f));
  pp[1][3] = bcast(f32Cst(builder, +1.18407396425136952e-01f));
  pp[1][4] = bcast(f32Cst(builder, +3.70645533056476558e-02f));
  pp[2][0] = bcast(f32Cst(builder, -3.30093071049483172e-03f));
  pp[2][1] = bcast(f32Cst(builder, +3.51961938357697011e-03f));
  pp[2][2] = bcast(f32Cst(builder, -1.41373622814988039e-03f));
  pp[2][3] = bcast(f32Cst(builder, +2.53447094961941348e-04f));
  pp[2][4] = bcast(f32Cst(builder, -1.71048029455037401e-05f));

  // Denominator coefficients per interval, constant term first.
  Value qq[intervalsCount][polyDegree + 1];
  qq[0][0] = bcast(f32Cst(builder, +1.000000000000000000e+00f));
  qq[0][1] = bcast(f32Cst(builder, -4.635138185962547255e-01f));
  qq[0][2] = bcast(f32Cst(builder, +5.192301327279782447e-01f));
  qq[0][3] = bcast(f32Cst(builder, -1.318089722204810087e-01f));
  qq[0][4] = bcast(f32Cst(builder, +7.397964654672315005e-02f));
  qq[1][0] = bcast(f32Cst(builder, +1.00000000000000000e+00f));
  qq[1][1] = bcast(f32Cst(builder, -3.27607011824493086e-01f));
  qq[1][2] = bcast(f32Cst(builder, +4.48369090658821977e-01f));
  qq[1][3] = bcast(f32Cst(builder, -8.83462621207857930e-02f));
  qq[1][4] = bcast(f32Cst(builder, +5.72442770283176093e-02f));
  qq[2][0] = bcast(f32Cst(builder, +1.00000000000000000e+00f));
  qq[2][1] = bcast(f32Cst(builder, -2.06069165953913769e+00f));
  qq[2][2] = bcast(f32Cst(builder, +1.62705939945477759e+00f));
  qq[2][3] = bcast(f32Cst(builder, -5.83389859211130017e-01f));
  qq[2][4] = bcast(f32Cst(builder, +8.21908939856640930e-02f));

  // Constant added to P/Q on each interval.
  Value offsets[intervalsCount];
  offsets[0] = bcast(f32Cst(builder, 0.0f));
  offsets[1] = bcast(f32Cst(builder, 0.0f));
  offsets[2] = bcast(f32Cst(builder, 1.0f));

  // Upper bounds of the intervals [0, 0.8), [0.8, 2.0), [2.0, 3.75). For
  // |x| >= 3.75 the result is the constant 1 (with the sign applied at the end).
  Value bounds[intervalsCount];
  bounds[0] = bcast(f32Cst(builder, 0.8f));
  bounds[1] = bcast(f32Cst(builder, 2.0f));
  bounds[2] = bcast(f32Cst(builder, 3.75f));

  // Work on |x|; the sign is restored at the end because erf is odd.
  Value isNegativeArg = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT,
                                                      op.getOperand(), zero);
  Value negArg = builder.create<arith::NegFOp>(op.getOperand());
  Value x = builder.create<SelectOp>(isNegativeArg, negArg, op.getOperand());

  // Start with interval 0's coefficients and offset.
  Value offset = offsets[0];
  Value p[polyDegree + 1];
  Value q[polyDegree + 1];
  for (int i = 0; i <= polyDegree; ++i) {
    p[i] = pp[0][i];
    q[i] = qq[0][i];
  }

  // For each interval boundary j, keep the current coefficients if
  // x < bounds[j]; otherwise switch to interval j + 1's coefficients.
  // TODO: maybe use vector stacking to reduce the number of selects.
  Value isLessThanBound[intervalsCount];
  for (int j = 0; j < intervalsCount - 1; ++j) {
    isLessThanBound[j] =
        builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[j]);
    for (int i = 0; i <= polyDegree; ++i) {
      p[i] = builder.create<SelectOp>(isLessThanBound[j], p[i], pp[j + 1][i]);
      q[i] = builder.create<SelectOp>(isLessThanBound[j], q[i], qq[j + 1][i]);
    }
    offset =
        builder.create<SelectOp>(isLessThanBound[j], offset, offsets[j + 1]);
  }
  // ULT (unordered) is also true for NaN, so NaN inputs take the polynomial
  // path rather than the constant 1.
  isLessThanBound[intervalsCount - 1] = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);

  Value pPoly = makePolynomialCalculation(builder, p, x);
  Value qPoly = makePolynomialCalculation(builder, q, x);
  Value rationalPoly = builder.create<arith::DivFOp>(pPoly, qPoly);
  Value formula = builder.create<arith::AddFOp>(offset, rationalPoly);
  formula = builder.create<SelectOp>(isLessThanBound[intervalsCount - 1],
                                     formula, one);

  // erf is odd function: erf(x) = -erf(-x).
  Value negFormula = builder.create<arith::NegFOp>(formula);
  Value res = builder.create<SelectOp>(isNegativeArg, negFormula, formula);

  rewriter.replaceOp(op, res);

  return success();
}

//----------------------------------------------------------------------------//
// Exp approximation.
//----------------------------------------------------------------------------//

namespace {

struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::ExpOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Approximate exp(x) by range reduction: with k = floor(x / ln(2)) and
// y = x - k * ln(2), y lies in [0, ln(2)) and exp(x) = exp(y) * 2^k, so only
// exp(y) has to be approximated on a small interval.
826 LogicalResult 827 ExpApproximation::matchAndRewrite(math::ExpOp op, 828 PatternRewriter &rewriter) const { 829 if (!getElementTypeOrSelf(op.getOperand()).isF32()) 830 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 831 832 ArrayRef<int64_t> shape = vectorShape(op.getOperand()); 833 834 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 835 836 // TODO: Consider a common pattern rewriter with all methods below to 837 // write the approximations. 838 auto bcast = [&](Value value) -> Value { 839 return broadcast(builder, value, shape); 840 }; 841 auto fmla = [&](Value a, Value b, Value c) { 842 return builder.create<math::FmaOp>(a, b, c); 843 }; 844 auto mul = [&](Value a, Value b) -> Value { 845 return builder.create<arith::MulFOp>(a, b); 846 }; 847 auto sub = [&](Value a, Value b) -> Value { 848 return builder.create<arith::SubFOp>(a, b); 849 }; 850 auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); }; 851 852 Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE))); 853 Value cstLog2E = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE))); 854 855 // Polynomial coefficients. 
856 Value cstCephesExpP0 = bcast(f32Cst(builder, 1.0)); 857 Value cstCephesExpP1 = bcast(f32Cst(builder, 1.0)); 858 Value cstCephesExpP2 = bcast(f32Cst(builder, 0.49970514590562437052f)); 859 Value cstCephesExpP3 = bcast(f32Cst(builder, 0.16873890085469545053f)); 860 Value cstCephesExpP4 = bcast(f32Cst(builder, 0.03668965196652099192f)); 861 Value cstCephesExpP5 = bcast(f32Cst(builder, 0.01314350012789660196f)); 862 863 Value x = op.getOperand(); 864 865 // Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2) 866 Value xL2Inv = mul(x, cstLog2E); 867 Value kF32 = floor(xL2Inv); 868 Value kLn2 = mul(kF32, cstLn2); 869 Value y = sub(x, kLn2); 870 871 // Use Estrin's evaluation scheme with 3 independent parts: 872 // P(y)^y : (c0 + c1 y) + (c2 + c3 y) y^2 + (c4 + c5 y) y^4 873 Value y2 = mul(y, y); 874 Value y4 = mul(y2, y2); 875 876 Value q0 = fmla(cstCephesExpP1, y, cstCephesExpP0); 877 Value q1 = fmla(cstCephesExpP3, y, cstCephesExpP2); 878 Value q2 = fmla(cstCephesExpP5, y, cstCephesExpP4); 879 Value expY = fmla(q1, y2, q0); 880 expY = fmla(q2, y4, expY); 881 882 auto i32Vec = broadcast(builder.getI32Type(), shape); 883 884 // exp2(k) 885 Value k = builder.create<arith::FPToSIOp>(kF32, i32Vec); 886 Value exp2KValue = exp2I32(builder, k); 887 888 // exp(x) = exp(y) * exp2(k) 889 expY = mul(expY, exp2KValue); 890 891 // Handle overflow, inf and underflow of exp(x). exp(x) range is [0, inf], its 892 // partitioned as the following: 893 // exp(x) = 0, x <= -inf 894 // exp(x) = underflow (min_float), x <= -88 895 // exp(x) = inf (min_float), x >= 88 896 // Note: |k| = 127 is the value where the 8-bits exponent saturates. 
897 Value zerof32Const = bcast(f32Cst(builder, 0)); 898 auto constPosInfinity = 899 bcast(f32Cst(builder, std::numeric_limits<float>::infinity())); 900 auto constNegIfinity = 901 bcast(f32Cst(builder, -std::numeric_limits<float>::infinity())); 902 auto underflow = bcast(f32Cst(builder, std::numeric_limits<float>::min())); 903 904 Value kMaxConst = bcast(i32Cst(builder, 127)); 905 Value kMaxNegConst = bcast(i32Cst(builder, -127)); 906 Value rightBound = 907 builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, k, kMaxConst); 908 Value leftBound = 909 builder.create<arith::CmpIOp>(arith::CmpIPredicate::sge, k, kMaxNegConst); 910 911 Value isNegInfinityX = builder.create<arith::CmpFOp>( 912 arith::CmpFPredicate::OEQ, x, constNegIfinity); 913 Value isPosInfinityX = builder.create<arith::CmpFOp>( 914 arith::CmpFPredicate::OEQ, x, constPosInfinity); 915 Value isPostiveX = 916 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zerof32Const); 917 Value isComputable = builder.create<arith::AndIOp>(rightBound, leftBound); 918 919 expY = builder.create<SelectOp>( 920 isNegInfinityX, zerof32Const, 921 builder.create<SelectOp>( 922 isPosInfinityX, constPosInfinity, 923 builder.create<SelectOp>(isComputable, expY, 924 builder.create<SelectOp>(isPostiveX, 925 constPosInfinity, 926 underflow)))); 927 928 rewriter.replaceOp(op, expY); 929 930 return success(); 931 } 932 933 //----------------------------------------------------------------------------// 934 // ExpM1 approximation. 
935 //----------------------------------------------------------------------------// 936 937 namespace { 938 939 struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> { 940 public: 941 using OpRewritePattern::OpRewritePattern; 942 943 LogicalResult matchAndRewrite(math::ExpM1Op op, 944 PatternRewriter &rewriter) const final; 945 }; 946 } // namespace 947 948 LogicalResult 949 ExpM1Approximation::matchAndRewrite(math::ExpM1Op op, 950 PatternRewriter &rewriter) const { 951 if (!getElementTypeOrSelf(op.getOperand()).isF32()) 952 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 953 954 ArrayRef<int64_t> shape = vectorShape(op.getOperand()); 955 956 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 957 auto bcast = [&](Value value) -> Value { 958 return broadcast(builder, value, shape); 959 }; 960 961 // expm1(x) = exp(x) - 1 = u - 1. 962 // We have to handle it carefully when x is near 0, i.e. u ~= 1, 963 // and when the input is ~= -inf, i.e. u - 1 ~= -1. 964 Value cstOne = bcast(f32Cst(builder, 1.0f)); 965 Value cstNegOne = bcast(f32Cst(builder, -1.0f)); 966 Value x = op.getOperand(); 967 Value u = builder.create<math::ExpOp>(x); 968 Value uEqOne = 969 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne); 970 Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne); 971 Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>( 972 arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne); 973 // logU = log(u) ~= x 974 Value logU = builder.create<math::LogOp>(u); 975 976 // Detect exp(x) = +inf; written this way to avoid having to form +inf. 
977 Value isInf = 978 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u); 979 980 // (u - 1) * (x / ~x) 981 Value expm1 = builder.create<arith::MulFOp>( 982 uMinusOne, builder.create<arith::DivFOp>(x, logU)); 983 expm1 = builder.create<SelectOp>(isInf, u, expm1); 984 Value approximation = builder.create<SelectOp>( 985 uEqOne, x, builder.create<SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1)); 986 rewriter.replaceOp(op, approximation); 987 return success(); 988 } 989 990 //----------------------------------------------------------------------------// 991 // Sin and Cos approximation. 992 //----------------------------------------------------------------------------// 993 994 namespace { 995 996 template <bool isSine, typename OpTy> 997 struct SinAndCosApproximation : public OpRewritePattern<OpTy> { 998 public: 999 using OpRewritePattern<OpTy>::OpRewritePattern; 1000 1001 LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final; 1002 }; 1003 } // namespace 1004 1005 #define TWO_OVER_PI \ 1006 0.6366197723675813430755350534900574481378385829618257949906693762L 1007 #define PI_OVER_2 \ 1008 1.5707963267948966192313216916397514420985846996875529104874722961L 1009 1010 // Approximates sin(x) or cos(x) by finding the best approximation polynomial in 1011 // the reduced range [0, pi/2] for both sin(x) and cos(x). Then given y in the 1012 // reduced range sin(x) will be computed as sin(y), -sin(y), cos(y) or -cos(y). 
1013 template <bool isSine, typename OpTy> 1014 LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite( 1015 OpTy op, PatternRewriter &rewriter) const { 1016 static_assert( 1017 llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value, 1018 "SinAndCosApproximation pattern expects math::SinOp or math::CosOp"); 1019 1020 if (!getElementTypeOrSelf(op.getOperand()).isF32()) 1021 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 1022 1023 ArrayRef<int64_t> shape = vectorShape(op.getOperand()); 1024 1025 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 1026 auto bcast = [&](Value value) -> Value { 1027 return broadcast(builder, value, shape); 1028 }; 1029 auto mul = [&](Value a, Value b) -> Value { 1030 return builder.create<arith::MulFOp>(a, b); 1031 }; 1032 auto sub = [&](Value a, Value b) -> Value { 1033 return builder.create<arith::SubFOp>(a, b); 1034 }; 1035 auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); }; 1036 1037 auto i32Vec = broadcast(builder.getI32Type(), shape); 1038 auto fPToSingedInteger = [&](Value a) -> Value { 1039 return builder.create<arith::FPToSIOp>(a, i32Vec); 1040 }; 1041 1042 auto modulo4 = [&](Value a) -> Value { 1043 return builder.create<arith::AndIOp>(a, bcast(i32Cst(builder, 3))); 1044 }; 1045 1046 auto isEqualTo = [&](Value a, Value b) -> Value { 1047 return builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b); 1048 }; 1049 1050 auto isGreaterThan = [&](Value a, Value b) -> Value { 1051 return builder.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b); 1052 }; 1053 1054 auto select = [&](Value cond, Value t, Value f) -> Value { 1055 return builder.create<SelectOp>(cond, t, f); 1056 }; 1057 1058 auto fmla = [&](Value a, Value b, Value c) { 1059 return builder.create<math::FmaOp>(a, b, c); 1060 }; 1061 1062 auto bitwiseOr = [&](Value a, Value b) { 1063 return builder.create<arith::OrIOp>(a, b); 1064 }; 1065 1066 Value twoOverPi = bcast(f32Cst(builder, TWO_OVER_PI)); 1067 
Value piOverTwo = bcast(f32Cst(builder, PI_OVER_2)); 1068 1069 Value x = op.getOperand(); 1070 1071 Value k = floor(mul(x, twoOverPi)); 1072 1073 Value y = sub(x, mul(k, piOverTwo)); 1074 1075 Value cstOne = bcast(f32Cst(builder, 1.0)); 1076 Value cstNegativeOne = bcast(f32Cst(builder, -1.0)); 1077 1078 Value cstSC2 = bcast(f32Cst(builder, -0.16666667163372039794921875f)); 1079 Value cstSC4 = bcast(f32Cst(builder, 8.333347737789154052734375e-3f)); 1080 Value cstSC6 = bcast(f32Cst(builder, -1.9842604524455964565277099609375e-4f)); 1081 Value cstSC8 = 1082 bcast(f32Cst(builder, 2.760012648650445044040679931640625e-6f)); 1083 Value cstSC10 = 1084 bcast(f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f)); 1085 1086 Value cstCC2 = bcast(f32Cst(builder, -0.5f)); 1087 Value cstCC4 = bcast(f32Cst(builder, 4.166664183139801025390625e-2f)); 1088 Value cstCC6 = bcast(f32Cst(builder, -1.388833043165504932403564453125e-3f)); 1089 Value cstCC8 = bcast(f32Cst(builder, 2.47562347794882953166961669921875e-5f)); 1090 Value cstCC10 = 1091 bcast(f32Cst(builder, -2.59630184018533327616751194000244140625e-7f)); 1092 1093 Value kMod4 = modulo4(fPToSingedInteger(k)); 1094 1095 Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, 0))); 1096 Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, 1))); 1097 Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, 2))); 1098 Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, 3))); 1099 1100 Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2); 1101 Value negativeRange = isSine ? 
isGreaterThan(kMod4, bcast(i32Cst(builder, 1))) 1102 : bitwiseOr(kR1, kR2); 1103 1104 Value y2 = mul(y, y); 1105 1106 Value base = select(sinuseCos, cstOne, y); 1107 Value cstC2 = select(sinuseCos, cstCC2, cstSC2); 1108 Value cstC4 = select(sinuseCos, cstCC4, cstSC4); 1109 Value cstC6 = select(sinuseCos, cstCC6, cstSC6); 1110 Value cstC8 = select(sinuseCos, cstCC8, cstSC8); 1111 Value cstC10 = select(sinuseCos, cstCC10, cstSC10); 1112 1113 Value v1 = fmla(y2, cstC10, cstC8); 1114 Value v2 = fmla(y2, v1, cstC6); 1115 Value v3 = fmla(y2, v2, cstC4); 1116 Value v4 = fmla(y2, v3, cstC2); 1117 Value v5 = fmla(y2, v4, cstOne); 1118 Value v6 = mul(base, v5); 1119 1120 Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6); 1121 1122 rewriter.replaceOp(op, approximation); 1123 1124 return success(); 1125 } 1126 1127 //----------------------------------------------------------------------------// 1128 // Rsqrt approximation. 1129 //----------------------------------------------------------------------------// 1130 1131 namespace { 1132 struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> { 1133 using OpRewritePattern::OpRewritePattern; 1134 1135 LogicalResult matchAndRewrite(math::RsqrtOp op, 1136 PatternRewriter &rewriter) const final; 1137 }; 1138 } // namespace 1139 1140 LogicalResult 1141 RsqrtApproximation::matchAndRewrite(math::RsqrtOp op, 1142 PatternRewriter &rewriter) const { 1143 if (!getElementTypeOrSelf(op.getOperand()).isF32()) 1144 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 1145 1146 ArrayRef<int64_t> shape = vectorShape(op.getOperand()); 1147 1148 // Only support already-vectorized rsqrt's. 
1149 if (shape.empty() || shape.back() % 8 != 0) 1150 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 1151 1152 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 1153 auto bcast = [&](Value value) -> Value { 1154 return broadcast(builder, value, shape); 1155 }; 1156 1157 Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u)); 1158 Value cstOnePointFive = bcast(f32Cst(builder, 1.5f)); 1159 Value cstNegHalf = bcast(f32Cst(builder, -0.5f)); 1160 Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u)); 1161 1162 Value negHalf = builder.create<arith::MulFOp>(op.getOperand(), cstNegHalf); 1163 1164 // Select only the inverse sqrt of positive normals (denormals are 1165 // flushed to zero). 1166 Value ltMinMask = builder.create<arith::CmpFOp>( 1167 arith::CmpFPredicate::OLT, op.getOperand(), cstMinNormPos); 1168 Value infMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, 1169 op.getOperand(), cstPosInf); 1170 Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask); 1171 1172 // Compute an approximate result. 1173 Value yApprox = handleMultidimensionalVectors( 1174 builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value { 1175 return builder.create<x86vector::RsqrtOp>(operands); 1176 }); 1177 1178 // Do a single step of Newton-Raphson iteration to improve the approximation. 1179 // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n). 1180 // It is essential to evaluate the inner term like this because forming 1181 // y_n^2 may over- or underflow. 1182 Value inner = builder.create<arith::MulFOp>(negHalf, yApprox); 1183 Value fma = builder.create<math::FmaOp>(yApprox, inner, cstOnePointFive); 1184 Value yNewton = builder.create<arith::MulFOp>(yApprox, fma); 1185 1186 // Select the result of the Newton-Raphson step for positive normal arguments. 1187 // For other arguments, choose the output of the intrinsic. 
This will 1188 // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if 1189 // x is zero or a positive denormalized float (equivalent to flushing positive 1190 // denormalized inputs to zero). 1191 Value res = builder.create<SelectOp>(notNormalFiniteMask, yApprox, yNewton); 1192 rewriter.replaceOp(op, res); 1193 1194 return success(); 1195 } 1196 1197 //----------------------------------------------------------------------------// 1198 1199 void mlir::populateMathPolynomialApproximationPatterns( 1200 RewritePatternSet &patterns, 1201 const MathPolynomialApproximationOptions &options) { 1202 patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation, 1203 LogApproximation, Log2Approximation, Log1pApproximation, 1204 ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation, 1205 SinAndCosApproximation<true, math::SinOp>, 1206 SinAndCosApproximation<false, math::CosOp>>( 1207 patterns.getContext()); 1208 if (options.enableAvx2) 1209 patterns.add<RsqrtApproximation>(patterns.getContext()); 1210 } 1211