1 //===- PolynomialApproximation.cpp - Approximate math operations ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements expansion of math operations to fast approximations 10 // that do not rely on any of the library functions. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/Math/IR/Math.h" 16 #include "mlir/Dialect/Math/Transforms/Approximation.h" 17 #include "mlir/Dialect/Math/Transforms/Passes.h" 18 #include "mlir/Dialect/Vector/VectorOps.h" 19 #include "mlir/Dialect/Vector/VectorUtils.h" 20 #include "mlir/Dialect/X86Vector/X86VectorDialect.h" 21 #include "mlir/IR/Builders.h" 22 #include "mlir/IR/ImplicitLocOpBuilder.h" 23 #include "mlir/Transforms/Bufferize.h" 24 #include "mlir/Transforms/DialectConversion.h" 25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 26 #include "llvm/ADT/ArrayRef.h" 27 #include <climits> 28 #include <cstddef> 29 30 using namespace mlir; 31 using namespace mlir::math; 32 using namespace mlir::vector; 33 34 using TypePredicate = llvm::function_ref<bool(Type)>; 35 36 // Returns vector shape if the element type is matching the predicate (scalars 37 // that do match the predicate have shape equal to `{1}`). 38 static Optional<SmallVector<int64_t, 2>> vectorShape(Type type, 39 TypePredicate pred) { 40 // If the type matches the predicate then its shape is `{1}`. 41 if (pred(type)) 42 return SmallVector<int64_t, 2>{1}; 43 44 // Otherwise check if the type is a vector type. 45 auto vectorType = type.dyn_cast<VectorType>(); 46 if (vectorType && pred(vectorType.getElementType())) { 47 return llvm::to_vector<2>(vectorType.getShape()); 48 } 49 50 return llvm::None; 51 } 52 53 // Returns vector shape of the type. If the type is a scalar returns `1`. 54 static SmallVector<int64_t, 2> vectorShape(Type type) { 55 auto vectorType = type.dyn_cast<VectorType>(); 56 return vectorType ? llvm::to_vector<2>(vectorType.getShape()) 57 : SmallVector<int64_t, 2>{1}; 58 } 59 60 // Returns vector element type. If the type is a scalar returns the argument. 61 LLVM_ATTRIBUTE_UNUSED static Type elementType(Type type) { 62 auto vectorType = type.dyn_cast<VectorType>(); 63 return vectorType ? vectorType.getElementType() : type; 64 } 65 66 LLVM_ATTRIBUTE_UNUSED static bool isF32(Type type) { return type.isF32(); } 67 68 LLVM_ATTRIBUTE_UNUSED static bool isI32(Type type) { 69 return type.isInteger(32); 70 } 71 72 //----------------------------------------------------------------------------// 73 // Broadcast scalar types and values into vector types and values. 74 //----------------------------------------------------------------------------// 75 76 // Returns true if shape != {1}. 77 static bool isNonScalarShape(ArrayRef<int64_t> shape) { 78 return shape.size() > 1 || shape[0] > 1; 79 } 80 81 // Broadcasts scalar type into vector type (iff shape is non-scalar). 82 static Type broadcast(Type type, ArrayRef<int64_t> shape) { 83 assert(!type.isa<VectorType>() && "must be scalar type"); 84 return isNonScalarShape(shape) ? VectorType::get(shape, type) : type; 85 } 86 87 // Broadcasts scalar value into vector (iff shape is non-scalar). 88 static Value broadcast(ImplicitLocOpBuilder &builder, Value value, 89 ArrayRef<int64_t> shape) { 90 assert(!value.getType().isa<VectorType>() && "must be scalar value"); 91 auto type = broadcast(value.getType(), shape); 92 return isNonScalarShape(shape) ? builder.create<BroadcastOp>(type, value) 93 : value; 94 } 95 96 //----------------------------------------------------------------------------// 97 // Helper function to handle n-D vectors with 1-D operations. 98 //----------------------------------------------------------------------------// 99 100 // Expands and unrolls n-D vector operands into multiple fixed size 1-D vectors 101 // and calls the compute function with 1-D vector operands. Stitches back all 102 // results into the original n-D vector result. 103 // 104 // Examples: vectorWidth = 8 105 // - vector<4x8xf32> unrolled 4 times 106 // - vector<16xf32> expanded to vector<2x8xf32> and unrolled 2 times 107 // - vector<4x16xf32> expanded to vector<4x2x8xf32> and unrolled 4*2 times 108 // 109 // Some math approximations rely on ISA-specific operations that only accept 110 // fixed size 1-D vectors (e.g. AVX expects vectors of width 8). 111 // 112 // It is the caller's responsibility to verify that the inner dimension is 113 // divisible by the vectorWidth, and that all operands have the same vector 114 // shape. 115 static Value 116 handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, 117 ValueRange operands, int64_t vectorWidth, 118 std::function<Value(ValueRange)> compute) { 119 assert(!operands.empty() && "operands must be not empty"); 120 assert(vectorWidth > 0 && "vector width must be larger than 0"); 121 122 VectorType inputType = operands[0].getType().cast<VectorType>(); 123 ArrayRef<int64_t> inputShape = inputType.getShape(); 124 125 // If input shape matches target vector width, we can just call the 126 // user-provided compute function with the operands. 127 if (inputShape == llvm::makeArrayRef(vectorWidth)) 128 return compute(operands); 129 130 // Check if the inner dimension has to be expanded, or we can directly iterate 131 // over the outer dimensions of the vector. 132 int64_t innerDim = inputShape.back(); 133 int64_t expansionDim = innerDim / vectorWidth; 134 assert((innerDim % vectorWidth == 0) && "invalid inner dimension size"); 135 136 // Maybe expand operands to the higher rank vector shape that we'll use to 137 // iterate over and extract one dimensional vectors. 138 SmallVector<int64_t> expandedShape(inputShape.begin(), inputShape.end()); 139 SmallVector<Value> expandedOperands(operands); 140 141 if (expansionDim > 1) { 142 // Expand shape from [..., innerDim] to [..., expansionDim, vectorWidth]. 143 expandedShape.insert(expandedShape.end() - 1, expansionDim); 144 expandedShape.back() = vectorWidth; 145 146 for (unsigned i = 0; i < operands.size(); ++i) { 147 auto operand = operands[i]; 148 auto eltType = operand.getType().cast<VectorType>().getElementType(); 149 auto expandedType = VectorType::get(expandedShape, eltType); 150 expandedOperands[i] = 151 builder.create<vector::ShapeCastOp>(expandedType, operand); 152 } 153 } 154 155 // Iterate over all outer dimensions of the compute shape vector type. 156 auto iterationDims = ArrayRef<int64_t>(expandedShape).drop_back(); 157 int64_t maxLinearIndex = computeMaxLinearIndex(iterationDims); 158 159 SmallVector<int64_t> ones(iterationDims.size(), 1); 160 auto strides = computeStrides(iterationDims, ones); 161 162 // Compute results for each one dimensional vector. 163 SmallVector<Value> results(maxLinearIndex); 164 165 for (int64_t i = 0; i < maxLinearIndex; ++i) { 166 auto offsets = delinearize(strides, i); 167 168 SmallVector<Value> extracted(expandedOperands.size()); 169 for (auto tuple : llvm::enumerate(expandedOperands)) 170 extracted[tuple.index()] = 171 builder.create<vector::ExtractOp>(tuple.value(), offsets); 172 173 results[i] = compute(extracted); 174 } 175 176 // Stitch results together into one large vector. 177 Type resultEltType = results[0].getType().cast<VectorType>().getElementType(); 178 Type resultExpandedType = VectorType::get(expandedShape, resultEltType); 179 Value result = builder.create<ConstantOp>( 180 resultExpandedType, builder.getZeroAttr(resultExpandedType)); 181 182 for (int64_t i = 0; i < maxLinearIndex; ++i) 183 result = builder.create<vector::InsertOp>(results[i], result, 184 delinearize(strides, i)); 185 186 // Reshape back to the original vector shape. 187 return builder.create<vector::ShapeCastOp>( 188 VectorType::get(inputShape, resultEltType), result); 189 } 190 191 //----------------------------------------------------------------------------// 192 // Helper functions to create constants. 193 //----------------------------------------------------------------------------// 194 195 static Value f32Cst(ImplicitLocOpBuilder &builder, float value) { 196 return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value)); 197 } 198 199 static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) { 200 return builder.create<arith::ConstantOp>(builder.getI32IntegerAttr(value)); 201 } 202 203 static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) { 204 Value i32Value = i32Cst(builder, static_cast<int32_t>(bits)); 205 return builder.create<arith::BitcastOp>(builder.getF32Type(), i32Value); 206 } 207 208 //----------------------------------------------------------------------------// 209 // Helper functions to build math functions approximations. 210 //----------------------------------------------------------------------------// 211 212 static Value min(ImplicitLocOpBuilder &builder, Value a, Value b) { 213 return builder.create<SelectOp>( 214 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, a, b), a, b); 215 } 216 217 static Value max(ImplicitLocOpBuilder &builder, Value a, Value b) { 218 return builder.create<SelectOp>( 219 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, a, b), a, b); 220 } 221 222 static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, 223 Value upperBound) { 224 return max(builder, min(builder, value, upperBound), lowerBound); 225 } 226 227 // Decomposes given floating point value `arg` into a normalized fraction and 228 // an integral power of two (see std::frexp). Returned values have float type. 229 static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg, 230 bool is_positive = false) { 231 assert(isF32(elementType(arg.getType())) && "argument must be f32 type"); 232 233 auto shape = vectorShape(arg.getType()); 234 235 auto bcast = [&](Value value) -> Value { 236 return broadcast(builder, value, shape); 237 }; 238 239 auto i32 = builder.getIntegerType(32); 240 auto i32Vec = broadcast(i32, shape); 241 auto f32Vec = broadcast(builder.getF32Type(), shape); 242 243 Value cst126f = f32Cst(builder, 126.0f); 244 Value cstHalf = f32Cst(builder, 0.5f); 245 Value cstInvMantMask = f32FromBits(builder, ~0x7f800000u); 246 247 // Bitcast to i32 for bitwise operations. 248 Value i32Half = builder.create<arith::BitcastOp>(i32, cstHalf); 249 Value i32InvMantMask = builder.create<arith::BitcastOp>(i32, cstInvMantMask); 250 Value i32Arg = builder.create<arith::BitcastOp>(i32Vec, arg); 251 252 // Compute normalized fraction. 253 Value tmp0 = builder.create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask)); 254 Value tmp1 = builder.create<arith::OrIOp>(tmp0, bcast(i32Half)); 255 Value normalizedFraction = builder.create<arith::BitcastOp>(f32Vec, tmp1); 256 257 // Compute exponent. 258 Value arg0 = is_positive ? arg : builder.create<math::AbsOp>(arg); 259 Value biasedExponentBits = builder.create<arith::ShRUIOp>( 260 builder.create<arith::BitcastOp>(i32Vec, arg0), 261 bcast(i32Cst(builder, 23))); 262 Value biasedExponent = 263 builder.create<arith::SIToFPOp>(f32Vec, biasedExponentBits); 264 Value exponent = 265 builder.create<arith::SubFOp>(biasedExponent, bcast(cst126f)); 266 267 return {normalizedFraction, exponent}; 268 } 269 270 // Computes exp2 for an i32 argument. 271 static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) { 272 assert(isI32(elementType(arg.getType())) && "argument must be i32 type"); 273 274 auto shape = vectorShape(arg.getType()); 275 276 auto bcast = [&](Value value) -> Value { 277 return broadcast(builder, value, shape); 278 }; 279 280 auto f32Vec = broadcast(builder.getF32Type(), shape); 281 // The exponent of f32 located at 23-bit. 282 auto exponetBitLocation = bcast(i32Cst(builder, 23)); 283 // Set the exponent bias to zero. 284 auto bias = bcast(i32Cst(builder, 127)); 285 286 Value biasedArg = builder.create<arith::AddIOp>(arg, bias); 287 Value exp2ValueInt = 288 builder.create<arith::ShLIOp>(biasedArg, exponetBitLocation); 289 Value exp2ValueF32 = builder.create<arith::BitcastOp>(f32Vec, exp2ValueInt); 290 291 return exp2ValueF32; 292 } 293 294 namespace { 295 Value makePolynomialCalculation(ImplicitLocOpBuilder &builder, 296 llvm::ArrayRef<Value> coeffs, Value x) { 297 auto shape = vectorShape(x.getType(), isF32); 298 if (coeffs.size() == 0) { 299 return broadcast(builder, f32Cst(builder, 0.0f), *shape); 300 } else if (coeffs.size() == 1) { 301 return coeffs[0]; 302 } 303 Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1], 304 coeffs[coeffs.size() - 2]); 305 for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) { 306 res = builder.create<math::FmaOp>(x, res, coeffs[i]); 307 } 308 return res; 309 } 310 } // namespace 311 312 //----------------------------------------------------------------------------// 313 // TanhOp approximation. 314 //----------------------------------------------------------------------------// 315 316 namespace { 317 struct TanhApproximation : public OpRewritePattern<math::TanhOp> { 318 public: 319 using OpRewritePattern::OpRewritePattern; 320 321 LogicalResult matchAndRewrite(math::TanhOp op, 322 PatternRewriter &rewriter) const final; 323 }; 324 } // namespace 325 326 LogicalResult 327 TanhApproximation::matchAndRewrite(math::TanhOp op, 328 PatternRewriter &rewriter) const { 329 auto shape = vectorShape(op.operand().getType(), isF32); 330 if (!shape.hasValue()) 331 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 332 333 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 334 auto bcast = [&](Value value) -> Value { 335 return broadcast(builder, value, *shape); 336 }; 337 338 // Clamp operand into [plusClamp, minusClamp] range. 339 Value minusClamp = bcast(f32Cst(builder, -7.99881172180175781f)); 340 Value plusClamp = bcast(f32Cst(builder, 7.99881172180175781f)); 341 Value x = clamp(builder, op.operand(), minusClamp, plusClamp); 342 343 // Mask for tiny values that are approximated with `operand`. 344 Value tiny = bcast(f32Cst(builder, 0.0004f)); 345 Value tinyMask = builder.create<arith::CmpFOp>( 346 arith::CmpFPredicate::OLT, builder.create<math::AbsOp>(op.operand()), 347 tiny); 348 349 // The monomial coefficients of the numerator polynomial (odd). 350 Value alpha1 = bcast(f32Cst(builder, 4.89352455891786e-03f)); 351 Value alpha3 = bcast(f32Cst(builder, 6.37261928875436e-04f)); 352 Value alpha5 = bcast(f32Cst(builder, 1.48572235717979e-05f)); 353 Value alpha7 = bcast(f32Cst(builder, 5.12229709037114e-08f)); 354 Value alpha9 = bcast(f32Cst(builder, -8.60467152213735e-11f)); 355 Value alpha11 = bcast(f32Cst(builder, 2.00018790482477e-13f)); 356 Value alpha13 = bcast(f32Cst(builder, -2.76076847742355e-16f)); 357 358 // The monomial coefficients of the denominator polynomial (even). 359 Value beta0 = bcast(f32Cst(builder, 4.89352518554385e-03f)); 360 Value beta2 = bcast(f32Cst(builder, 2.26843463243900e-03f)); 361 Value beta4 = bcast(f32Cst(builder, 1.18534705686654e-04f)); 362 Value beta6 = bcast(f32Cst(builder, 1.19825839466702e-06f)); 363 364 // Since the polynomials are odd/even, we need x^2. 365 Value x2 = builder.create<arith::MulFOp>(x, x); 366 367 // Evaluate the numerator polynomial p. 368 Value p = builder.create<math::FmaOp>(x2, alpha13, alpha11); 369 p = builder.create<math::FmaOp>(x2, p, alpha9); 370 p = builder.create<math::FmaOp>(x2, p, alpha7); 371 p = builder.create<math::FmaOp>(x2, p, alpha5); 372 p = builder.create<math::FmaOp>(x2, p, alpha3); 373 p = builder.create<math::FmaOp>(x2, p, alpha1); 374 p = builder.create<arith::MulFOp>(x, p); 375 376 // Evaluate the denominator polynomial q. 377 Value q = builder.create<math::FmaOp>(x2, beta6, beta4); 378 q = builder.create<math::FmaOp>(x2, q, beta2); 379 q = builder.create<math::FmaOp>(x2, q, beta0); 380 381 // Divide the numerator by the denominator. 382 Value res = builder.create<SelectOp>(tinyMask, x, 383 builder.create<arith::DivFOp>(p, q)); 384 385 rewriter.replaceOp(op, res); 386 387 return success(); 388 } 389 390 #define LN2_VALUE \ 391 0.693147180559945309417232121458176568075500134360255254120680009493393621L 392 #define LOG2E_VALUE \ 393 1.442695040888963407359924681001892137426645954152985934135449406931109219L 394 395 //----------------------------------------------------------------------------// 396 // LogOp and Log2Op approximation. 397 //----------------------------------------------------------------------------// 398 399 namespace { 400 template <typename Op> 401 struct LogApproximationBase : public OpRewritePattern<Op> { 402 using OpRewritePattern<Op>::OpRewritePattern; 403 404 /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise. 405 LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter, 406 bool base2) const; 407 }; 408 } // namespace 409 410 // This approximation comes from Julien Pommier's SSE math library. 411 // Link: http://gruntthepeon.free.fr/ssemath 412 template <typename Op> 413 LogicalResult 414 LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter, 415 bool base2) const { 416 auto shape = vectorShape(op.operand().getType(), isF32); 417 if (!shape.hasValue()) 418 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 419 420 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 421 auto bcast = [&](Value value) -> Value { 422 return broadcast(builder, value, *shape); 423 }; 424 425 Value cstZero = bcast(f32Cst(builder, 0.0f)); 426 Value cstOne = bcast(f32Cst(builder, 1.0f)); 427 Value cstNegHalf = bcast(f32Cst(builder, -0.5f)); 428 429 // The smallest non denormalized float number. 430 Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u)); 431 Value cstMinusInf = bcast(f32FromBits(builder, 0xff800000u)); 432 Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u)); 433 Value cstNan = bcast(f32FromBits(builder, 0x7fc00000)); 434 435 // Polynomial coefficients. 436 Value cstCephesSQRTHF = bcast(f32Cst(builder, 0.707106781186547524f)); 437 Value cstCephesLogP0 = bcast(f32Cst(builder, 7.0376836292E-2f)); 438 Value cstCephesLogP1 = bcast(f32Cst(builder, -1.1514610310E-1f)); 439 Value cstCephesLogP2 = bcast(f32Cst(builder, 1.1676998740E-1f)); 440 Value cstCephesLogP3 = bcast(f32Cst(builder, -1.2420140846E-1f)); 441 Value cstCephesLogP4 = bcast(f32Cst(builder, +1.4249322787E-1f)); 442 Value cstCephesLogP5 = bcast(f32Cst(builder, -1.6668057665E-1f)); 443 Value cstCephesLogP6 = bcast(f32Cst(builder, +2.0000714765E-1f)); 444 Value cstCephesLogP7 = bcast(f32Cst(builder, -2.4999993993E-1f)); 445 Value cstCephesLogP8 = bcast(f32Cst(builder, +3.3333331174E-1f)); 446 447 Value x = op.operand(); 448 449 // Truncate input values to the minimum positive normal. 450 x = max(builder, x, cstMinNormPos); 451 452 // Extract significant in the range [0.5,1) and exponent. 453 std::pair<Value, Value> pair = frexp(builder, x, /*is_positive=*/true); 454 x = pair.first; 455 Value e = pair.second; 456 457 // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift 458 // by -1.0. The values are then centered around 0, which improves the 459 // stability of the polynomial evaluation: 460 // 461 // if( x < SQRTHF ) { 462 // e -= 1; 463 // x = x + x - 1.0; 464 // } else { x = x - 1.0; } 465 Value mask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, 466 cstCephesSQRTHF); 467 Value tmp = builder.create<SelectOp>(mask, x, cstZero); 468 469 x = builder.create<arith::SubFOp>(x, cstOne); 470 e = builder.create<arith::SubFOp>( 471 e, builder.create<SelectOp>(mask, cstOne, cstZero)); 472 x = builder.create<arith::AddFOp>(x, tmp); 473 474 Value x2 = builder.create<arith::MulFOp>(x, x); 475 Value x3 = builder.create<arith::MulFOp>(x2, x); 476 477 // Evaluate the polynomial approximant of degree 8 in three parts. 478 Value y0, y1, y2; 479 y0 = builder.create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1); 480 y1 = builder.create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4); 481 y2 = builder.create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7); 482 y0 = builder.create<math::FmaOp>(y0, x, cstCephesLogP2); 483 y1 = builder.create<math::FmaOp>(y1, x, cstCephesLogP5); 484 y2 = builder.create<math::FmaOp>(y2, x, cstCephesLogP8); 485 y0 = builder.create<math::FmaOp>(y0, x3, y1); 486 y0 = builder.create<math::FmaOp>(y0, x3, y2); 487 y0 = builder.create<arith::MulFOp>(y0, x3); 488 489 y0 = builder.create<math::FmaOp>(cstNegHalf, x2, y0); 490 x = builder.create<arith::AddFOp>(x, y0); 491 492 if (base2) { 493 Value cstLog2e = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE))); 494 x = builder.create<math::FmaOp>(x, cstLog2e, e); 495 } else { 496 Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE))); 497 x = builder.create<math::FmaOp>(e, cstLn2, x); 498 } 499 500 Value invalidMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, 501 op.operand(), cstZero); 502 Value zeroMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, 503 op.operand(), cstZero); 504 Value posInfMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, 505 op.operand(), cstPosInf); 506 507 // Filter out invalid values: 508 // • x == 0 -> -INF 509 // • x < 0 -> NAN 510 // • x == +INF -> +INF 511 Value aproximation = builder.create<SelectOp>( 512 zeroMask, cstMinusInf, 513 builder.create<SelectOp>( 514 invalidMask, cstNan, 515 builder.create<SelectOp>(posInfMask, cstPosInf, x))); 516 517 rewriter.replaceOp(op, aproximation); 518 519 return success(); 520 } 521 522 namespace { 523 struct LogApproximation : public LogApproximationBase<math::LogOp> { 524 using LogApproximationBase::LogApproximationBase; 525 526 LogicalResult matchAndRewrite(math::LogOp op, 527 PatternRewriter &rewriter) const final { 528 return logMatchAndRewrite(op, rewriter, /*base2=*/false); 529 } 530 }; 531 } // namespace 532 533 namespace { 534 struct Log2Approximation : public LogApproximationBase<math::Log2Op> { 535 using LogApproximationBase::LogApproximationBase; 536 537 LogicalResult matchAndRewrite(math::Log2Op op, 538 PatternRewriter &rewriter) const final { 539 return logMatchAndRewrite(op, rewriter, /*base2=*/true); 540 } 541 }; 542 } // namespace 543 544 //----------------------------------------------------------------------------// 545 // Log1p approximation. 546 //----------------------------------------------------------------------------// 547 548 namespace { 549 struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> { 550 public: 551 using OpRewritePattern::OpRewritePattern; 552 553 LogicalResult matchAndRewrite(math::Log1pOp op, 554 PatternRewriter &rewriter) const final; 555 }; 556 } // namespace 557 558 // Approximate log(1+x). 559 LogicalResult 560 Log1pApproximation::matchAndRewrite(math::Log1pOp op, 561 PatternRewriter &rewriter) const { 562 auto shape = vectorShape(op.operand().getType(), isF32); 563 if (!shape.hasValue()) 564 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 565 566 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 567 auto bcast = [&](Value value) -> Value { 568 return broadcast(builder, value, *shape); 569 }; 570 571 // Approximate log(1+x) using the following, due to W. Kahan: 572 // u = x + 1.0; 573 // if (u == 1.0 || u == inf) return x; 574 // return x * log(u) / (u - 1.0); 575 // ^^^^^^^^^^^^^^^^^^^^^^ 576 // "logLarge" below. 577 Value cstOne = bcast(f32Cst(builder, 1.0f)); 578 Value x = op.operand(); 579 Value u = builder.create<arith::AddFOp>(x, cstOne); 580 Value uSmall = 581 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne); 582 Value logU = builder.create<math::LogOp>(u); 583 Value uInf = 584 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU); 585 Value logLarge = builder.create<arith::MulFOp>( 586 x, builder.create<arith::DivFOp>( 587 logU, builder.create<arith::SubFOp>(u, cstOne))); 588 Value approximation = builder.create<SelectOp>( 589 builder.create<arith::OrIOp>(uSmall, uInf), x, logLarge); 590 rewriter.replaceOp(op, approximation); 591 return success(); 592 } 593 594 //----------------------------------------------------------------------------// 595 // Erf approximation. 596 //----------------------------------------------------------------------------// 597 598 // Approximates erf(x) with 599 // a - P(x)/Q(x) 600 // where P and Q are polynomials of degree 4. 601 // Different coefficients are chosen based on the value of x. 602 // The approximation error is ~2.5e-07. 603 // Boost's minimax tool that utilizes the Remez method was used to find the 604 // coefficients. 605 LogicalResult 606 ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op, 607 PatternRewriter &rewriter) const { 608 auto shape = vectorShape(op.operand().getType(), isF32); 609 if (!shape.hasValue()) 610 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 611 612 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 613 auto bcast = [&](Value value) -> Value { 614 return broadcast(builder, value, *shape); 615 }; 616 617 const int intervalsCount = 3; 618 const int polyDegree = 4; 619 620 Value zero = bcast(f32Cst(builder, 0)); 621 Value one = bcast(f32Cst(builder, 1)); 622 Value pp[intervalsCount][polyDegree + 1]; 623 pp[0][0] = bcast(f32Cst(builder, +0.00000000000000000e+00f)); 624 pp[0][1] = bcast(f32Cst(builder, +1.12837916222975858e+00f)); 625 pp[0][2] = bcast(f32Cst(builder, -5.23018562988006470e-01f)); 626 pp[0][3] = bcast(f32Cst(builder, +2.09741709609267072e-01f)); 627 pp[0][4] = bcast(f32Cst(builder, +2.58146801602987875e-02f)); 628 pp[1][0] = bcast(f32Cst(builder, +0.00000000000000000e+00f)); 629 pp[1][1] = bcast(f32Cst(builder, +1.12750687816789140e+00f)); 630 pp[1][2] = bcast(f32Cst(builder, -3.64721408487825775e-01f)); 631 pp[1][3] = bcast(f32Cst(builder, +1.18407396425136952e-01f)); 632 pp[1][4] = bcast(f32Cst(builder, +3.70645533056476558e-02f)); 633 pp[2][0] = bcast(f32Cst(builder, -3.30093071049483172e-03f)); 634 pp[2][1] = bcast(f32Cst(builder, +3.51961938357697011e-03f)); 635 pp[2][2] = bcast(f32Cst(builder, -1.41373622814988039e-03f)); 636 pp[2][3] = bcast(f32Cst(builder, +2.53447094961941348e-04f)); 637 pp[2][4] = bcast(f32Cst(builder, -1.71048029455037401e-05f)); 638 639 Value qq[intervalsCount][polyDegree + 1]; 640 qq[0][0] = bcast(f32Cst(builder, +1.000000000000000000e+00f)); 641 qq[0][1] = bcast(f32Cst(builder, -4.635138185962547255e-01f)); 642 qq[0][2] = bcast(f32Cst(builder, +5.192301327279782447e-01f)); 643 qq[0][3] = bcast(f32Cst(builder, -1.318089722204810087e-01f)); 644 qq[0][4] = bcast(f32Cst(builder, +7.397964654672315005e-02f)); 645 qq[1][0] = bcast(f32Cst(builder, +1.00000000000000000e+00f)); 646 qq[1][1] = bcast(f32Cst(builder, -3.27607011824493086e-01f)); 647 qq[1][2] = bcast(f32Cst(builder, +4.48369090658821977e-01f)); 648 qq[1][3] = bcast(f32Cst(builder, -8.83462621207857930e-02f)); 649 qq[1][4] = bcast(f32Cst(builder, +5.72442770283176093e-02f)); 650 qq[2][0] = bcast(f32Cst(builder, +1.00000000000000000e+00f)); 651 qq[2][1] = bcast(f32Cst(builder, -2.06069165953913769e+00f)); 652 qq[2][2] = bcast(f32Cst(builder, +1.62705939945477759e+00f)); 653 qq[2][3] = bcast(f32Cst(builder, -5.83389859211130017e-01f)); 654 qq[2][4] = bcast(f32Cst(builder, +8.21908939856640930e-02f)); 655 656 Value offsets[intervalsCount]; 657 offsets[0] = bcast(f32Cst(builder, 0.0f)); 658 offsets[1] = bcast(f32Cst(builder, 0.0f)); 659 offsets[2] = bcast(f32Cst(builder, 1.0f)); 660 661 Value bounds[intervalsCount]; 662 bounds[0] = bcast(f32Cst(builder, 0.8f)); 663 bounds[1] = bcast(f32Cst(builder, 2.0f)); 664 bounds[2] = bcast(f32Cst(builder, 3.75f)); 665 666 Value isNegativeArg = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, 667 op.operand(), zero); 668 Value negArg = builder.create<arith::NegFOp>(op.operand()); 669 Value x = builder.create<SelectOp>(isNegativeArg, negArg, op.operand()); 670 671 Value offset = offsets[0]; 672 Value p[polyDegree + 1]; 673 Value q[polyDegree + 1]; 674 for (int i = 0; i <= polyDegree; ++i) { 675 p[i] = pp[0][i]; 676 q[i] = qq[0][i]; 677 } 678 679 // TODO: maybe use vector stacking to reduce the number of selects. 680 Value isLessThanBound[intervalsCount]; 681 for (int j = 0; j < intervalsCount - 1; ++j) { 682 isLessThanBound[j] = 683 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[j]); 684 for (int i = 0; i <= polyDegree; ++i) { 685 p[i] = builder.create<SelectOp>(isLessThanBound[j], p[i], pp[j + 1][i]); 686 q[i] = builder.create<SelectOp>(isLessThanBound[j], q[i], qq[j + 1][i]); 687 } 688 offset = 689 builder.create<SelectOp>(isLessThanBound[j], offset, offsets[j + 1]); 690 } 691 isLessThanBound[intervalsCount - 1] = builder.create<arith::CmpFOp>( 692 arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]); 693 694 Value pPoly = makePolynomialCalculation(builder, p, x); 695 Value qPoly = makePolynomialCalculation(builder, q, x); 696 Value rationalPoly = builder.create<arith::DivFOp>(pPoly, qPoly); 697 Value formula = builder.create<arith::AddFOp>(offset, rationalPoly); 698 formula = builder.create<SelectOp>(isLessThanBound[intervalsCount - 1], 699 formula, one); 700 701 // erf is odd function: erf(x) = -erf(-x). 702 Value negFormula = builder.create<arith::NegFOp>(formula); 703 Value res = builder.create<SelectOp>(isNegativeArg, negFormula, formula); 704 705 rewriter.replaceOp(op, res); 706 707 return success(); 708 } 709 710 //----------------------------------------------------------------------------// 711 // Exp approximation. 712 //----------------------------------------------------------------------------// 713 714 namespace { 715 716 struct ExpApproximation : public OpRewritePattern<math::ExpOp> { 717 public: 718 using OpRewritePattern::OpRewritePattern; 719 720 LogicalResult matchAndRewrite(math::ExpOp op, 721 PatternRewriter &rewriter) const final; 722 }; 723 } // namespace 724 725 // Approximate exp(x) using its reduced range exp(y) where y is in the range 726 // [0, ln(2)], let y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2), exp(x) 727 // = exp(y) * 2^k. exp(y). 728 LogicalResult 729 ExpApproximation::matchAndRewrite(math::ExpOp op, 730 PatternRewriter &rewriter) const { 731 auto shape = vectorShape(op.operand().getType(), isF32); 732 if (!shape.hasValue()) 733 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 734 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 735 736 // TODO: Consider a common pattern rewriter with all methods below to 737 // write the approximations. 738 auto bcast = [&](Value value) -> Value { 739 return broadcast(builder, value, *shape); 740 }; 741 auto fmla = [&](Value a, Value b, Value c) { 742 return builder.create<math::FmaOp>(a, b, c); 743 }; 744 auto mul = [&](Value a, Value b) -> Value { 745 return builder.create<arith::MulFOp>(a, b); 746 }; 747 auto sub = [&](Value a, Value b) -> Value { 748 return builder.create<arith::SubFOp>(a, b); 749 }; 750 auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); }; 751 752 Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE))); 753 Value cstLog2E = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE))); 754 755 // Polynomial coefficients. 756 Value cstCephesExpP0 = bcast(f32Cst(builder, 1.0)); 757 Value cstCephesExpP1 = bcast(f32Cst(builder, 1.0)); 758 Value cstCephesExpP2 = bcast(f32Cst(builder, 0.49970514590562437052f)); 759 Value cstCephesExpP3 = bcast(f32Cst(builder, 0.16873890085469545053f)); 760 Value cstCephesExpP4 = bcast(f32Cst(builder, 0.03668965196652099192f)); 761 Value cstCephesExpP5 = bcast(f32Cst(builder, 0.01314350012789660196f)); 762 763 Value x = op.operand(); 764 765 // Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2) 766 Value xL2Inv = mul(x, cstLog2E); 767 Value kF32 = floor(xL2Inv); 768 Value kLn2 = mul(kF32, cstLn2); 769 Value y = sub(x, kLn2); 770 771 // Use Estrin's evaluation scheme with 3 independent parts: 772 // P(y)^y : (c0 + c1 y) + (c2 + c3 y) y^2 + (c4 + c5 y) y^4 773 Value y2 = mul(y, y); 774 Value y4 = mul(y2, y2); 775 776 Value q0 = fmla(cstCephesExpP1, y, cstCephesExpP0); 777 Value q1 = fmla(cstCephesExpP3, y, cstCephesExpP2); 778 Value q2 = fmla(cstCephesExpP5, y, cstCephesExpP4); 779 Value expY = fmla(q1, y2, q0); 780 expY = fmla(q2, y4, expY); 781 782 auto i32Vec = broadcast(builder.getI32Type(), *shape); 783 784 // exp2(k) 785 Value k = builder.create<arith::FPToSIOp>(kF32, i32Vec); 786 Value exp2KValue = exp2I32(builder, k); 787 788 // exp(x) = exp(y) * exp2(k) 789 expY = mul(expY, exp2KValue); 790 791 // Handle overflow, inf and underflow of exp(x). exp(x) range is [0, inf], its 792 // partitioned as the following: 793 // exp(x) = 0, x <= -inf 794 // exp(x) = underflow (min_float), x <= -88 795 // exp(x) = inf (min_float), x >= 88 796 // Note: |k| = 127 is the value where the 8-bits exponent saturates. 797 Value zerof32Const = bcast(f32Cst(builder, 0)); 798 auto constPosInfinity = 799 bcast(f32Cst(builder, std::numeric_limits<float>::infinity())); 800 auto constNegIfinity = 801 bcast(f32Cst(builder, -std::numeric_limits<float>::infinity())); 802 auto underflow = bcast(f32Cst(builder, std::numeric_limits<float>::min())); 803 804 Value kMaxConst = bcast(i32Cst(builder, 127)); 805 Value kMaxNegConst = bcast(i32Cst(builder, -127)); 806 Value rightBound = 807 builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, k, kMaxConst); 808 Value leftBound = 809 builder.create<arith::CmpIOp>(arith::CmpIPredicate::sge, k, kMaxNegConst); 810 811 Value isNegInfinityX = builder.create<arith::CmpFOp>( 812 arith::CmpFPredicate::OEQ, x, constNegIfinity); 813 Value isPosInfinityX = builder.create<arith::CmpFOp>( 814 arith::CmpFPredicate::OEQ, x, constPosInfinity); 815 Value isPostiveX = 816 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zerof32Const); 817 Value isComputable = builder.create<arith::AndIOp>(rightBound, leftBound); 818 819 expY = builder.create<SelectOp>( 820 isNegInfinityX, zerof32Const, 821 builder.create<SelectOp>( 822 isPosInfinityX, constPosInfinity, 823 builder.create<SelectOp>(isComputable, expY, 824 builder.create<SelectOp>(isPostiveX, 825 constPosInfinity, 826 underflow)))); 827 828 rewriter.replaceOp(op, expY); 829 830 return success(); 831 } 832 833 //----------------------------------------------------------------------------// 834 // ExpM1 approximation. 835 //----------------------------------------------------------------------------// 836 837 namespace { 838 839 struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> { 840 public: 841 using OpRewritePattern::OpRewritePattern; 842 843 LogicalResult matchAndRewrite(math::ExpM1Op op, 844 PatternRewriter &rewriter) const final; 845 }; 846 } // namespace 847 848 LogicalResult 849 ExpM1Approximation::matchAndRewrite(math::ExpM1Op op, 850 PatternRewriter &rewriter) const { 851 auto shape = vectorShape(op.operand().getType(), isF32); 852 if (!shape.hasValue()) 853 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 854 855 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 856 auto bcast = [&](Value value) -> Value { 857 return broadcast(builder, value, *shape); 858 }; 859 860 // expm1(x) = exp(x) - 1 = u - 1. 861 // We have to handle it carefully when x is near 0, i.e. u ~= 1, 862 // and when the input is ~= -inf, i.e. u - 1 ~= -1. 863 Value cstOne = bcast(f32Cst(builder, 1.0f)); 864 Value cstNegOne = bcast(f32Cst(builder, -1.0f)); 865 Value x = op.operand(); 866 Value u = builder.create<math::ExpOp>(x); 867 Value uEqOne = 868 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne); 869 Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne); 870 Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>( 871 arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne); 872 // logU = log(u) ~= x 873 Value logU = builder.create<math::LogOp>(u); 874 875 // Detect exp(x) = +inf; written this way to avoid having to form +inf. 876 Value isInf = 877 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u); 878 879 // (u - 1) * (x / ~x) 880 Value expm1 = builder.create<arith::MulFOp>( 881 uMinusOne, builder.create<arith::DivFOp>(x, logU)); 882 expm1 = builder.create<SelectOp>(isInf, u, expm1); 883 Value approximation = builder.create<SelectOp>( 884 uEqOne, x, builder.create<SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1)); 885 rewriter.replaceOp(op, approximation); 886 return success(); 887 } 888 889 //----------------------------------------------------------------------------// 890 // Sin and Cos approximation. 891 //----------------------------------------------------------------------------// 892 893 namespace { 894 895 template <bool isSine, typename OpTy> 896 struct SinAndCosApproximation : public OpRewritePattern<OpTy> { 897 public: 898 using OpRewritePattern<OpTy>::OpRewritePattern; 899 900 LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final; 901 }; 902 } // namespace 903 904 #define TWO_OVER_PI \ 905 0.6366197723675813430755350534900574481378385829618257949906693762L 906 #define PI_OVER_2 \ 907 1.5707963267948966192313216916397514420985846996875529104874722961L 908 909 // Approximates sin(x) or cos(x) by finding the best approximation polynomial in 910 // the reduced range [0, pi/2] for both sin(x) and cos(x). Then given y in the 911 // reduced range sin(x) will be computed as sin(y), -sin(y), cos(y) or -cos(y). 912 template <bool isSine, typename OpTy> 913 LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite( 914 OpTy op, PatternRewriter &rewriter) const { 915 static_assert( 916 llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value, 917 "SinAndCosApproximation pattern expects math::SinOp or math::CosOp"); 918 auto shape = vectorShape(op.operand().getType(), isF32); 919 if (!shape.hasValue()) 920 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 921 922 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 923 auto bcast = [&](Value value) -> Value { 924 return broadcast(builder, value, *shape); 925 }; 926 auto mul = [&](Value a, Value b) -> Value { 927 return builder.create<arith::MulFOp>(a, b); 928 }; 929 auto sub = [&](Value a, Value b) -> Value { 930 return builder.create<arith::SubFOp>(a, b); 931 }; 932 auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); }; 933 934 auto i32Vec = broadcast(builder.getI32Type(), *shape); 935 auto fPToSingedInteger = [&](Value a) -> Value { 936 return builder.create<arith::FPToSIOp>(a, i32Vec); 937 }; 938 939 auto modulo4 = [&](Value a) -> Value { 940 return builder.create<arith::AndIOp>(a, bcast(i32Cst(builder, 3))); 941 }; 942 943 auto isEqualTo = [&](Value a, Value b) -> Value { 944 return builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b); 945 }; 946 947 auto isGreaterThan = [&](Value a, Value b) -> Value { 948 return builder.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b); 949 }; 950 951 auto select = [&](Value cond, Value t, Value f) -> Value { 952 return builder.create<SelectOp>(cond, t, f); 953 }; 954 955 auto fmla = [&](Value a, Value b, Value c) { 956 return builder.create<math::FmaOp>(a, b, c); 957 }; 958 959 auto bitwiseOr = [&](Value a, Value b) { 960 return builder.create<arith::OrIOp>(a, b); 961 }; 962 963 Value twoOverPi = bcast(f32Cst(builder, TWO_OVER_PI)); 964 Value piOverTwo = bcast(f32Cst(builder, PI_OVER_2)); 965 966 Value x = op.operand(); 967 968 Value k = floor(mul(x, twoOverPi)); 969 970 Value y = sub(x, mul(k, piOverTwo)); 971 972 Value cstOne = bcast(f32Cst(builder, 1.0)); 973 Value cstNegativeOne = bcast(f32Cst(builder, -1.0)); 974 975 Value cstSC2 = bcast(f32Cst(builder, -0.16666667163372039794921875f)); 976 Value cstSC4 = bcast(f32Cst(builder, 8.333347737789154052734375e-3f)); 977 Value cstSC6 = bcast(f32Cst(builder, -1.9842604524455964565277099609375e-4f)); 978 Value cstSC8 = 979 bcast(f32Cst(builder, 2.760012648650445044040679931640625e-6f)); 980 Value cstSC10 = 981 bcast(f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f)); 982 983 Value cstCC2 = bcast(f32Cst(builder, -0.5f)); 984 Value cstCC4 = bcast(f32Cst(builder, 4.166664183139801025390625e-2f)); 985 Value cstCC6 = bcast(f32Cst(builder, -1.388833043165504932403564453125e-3f)); 986 Value cstCC8 = bcast(f32Cst(builder, 2.47562347794882953166961669921875e-5f)); 987 Value cstCC10 = 988 bcast(f32Cst(builder, -2.59630184018533327616751194000244140625e-7f)); 989 990 Value kMod4 = modulo4(fPToSingedInteger(k)); 991 992 Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, 0))); 993 Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, 1))); 994 Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, 2))); 995 Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, 3))); 996 997 Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2); 998 Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(i32Cst(builder, 1))) 999 : bitwiseOr(kR1, kR2); 1000 1001 Value y2 = mul(y, y); 1002 1003 Value base = select(sinuseCos, cstOne, y); 1004 Value cstC2 = select(sinuseCos, cstCC2, cstSC2); 1005 Value cstC4 = select(sinuseCos, cstCC4, cstSC4); 1006 Value cstC6 = select(sinuseCos, cstCC6, cstSC6); 1007 Value cstC8 = select(sinuseCos, cstCC8, cstSC8); 1008 Value cstC10 = select(sinuseCos, cstCC10, cstSC10); 1009 1010 Value v1 = fmla(y2, cstC10, cstC8); 1011 Value v2 = fmla(y2, v1, cstC6); 1012 Value v3 = fmla(y2, v2, cstC4); 1013 Value v4 = fmla(y2, v3, cstC2); 1014 Value v5 = fmla(y2, v4, cstOne); 1015 Value v6 = mul(base, v5); 1016 1017 Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6); 1018 1019 rewriter.replaceOp(op, approximation); 1020 1021 return success(); 1022 } 1023 1024 //----------------------------------------------------------------------------// 1025 // Rsqrt approximation. 1026 //----------------------------------------------------------------------------// 1027 1028 namespace { 1029 struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> { 1030 using OpRewritePattern::OpRewritePattern; 1031 1032 LogicalResult matchAndRewrite(math::RsqrtOp op, 1033 PatternRewriter &rewriter) const final; 1034 }; 1035 } // namespace 1036 1037 LogicalResult 1038 RsqrtApproximation::matchAndRewrite(math::RsqrtOp op, 1039 PatternRewriter &rewriter) const { 1040 auto shape = vectorShape(op.operand().getType(), isF32); 1041 // Only support already-vectorized rsqrt's. 1042 if (!shape.hasValue() || shape->back() % 8 != 0) 1043 return rewriter.notifyMatchFailure(op, "unsupported operand type"); 1044 1045 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 1046 auto bcast = [&](Value value) -> Value { 1047 return broadcast(builder, value, *shape); 1048 }; 1049 1050 Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u)); 1051 Value cstOnePointFive = bcast(f32Cst(builder, 1.5f)); 1052 Value cstNegHalf = bcast(f32Cst(builder, -0.5f)); 1053 Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u)); 1054 1055 Value negHalf = builder.create<arith::MulFOp>(op.operand(), cstNegHalf); 1056 1057 // Select only the inverse sqrt of positive normals (denormals are 1058 // flushed to zero). 1059 Value ltMinMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, 1060 op.operand(), cstMinNormPos); 1061 Value infMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, 1062 op.operand(), cstPosInf); 1063 Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask); 1064 1065 // Compute an approximate result. 1066 Value yApprox = handleMultidimensionalVectors( 1067 builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value { 1068 return builder.create<x86vector::RsqrtOp>(operands); 1069 }); 1070 1071 // Do a single step of Newton-Raphson iteration to improve the approximation. 1072 // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n). 1073 // It is essential to evaluate the inner term like this because forming 1074 // y_n^2 may over- or underflow. 1075 Value inner = builder.create<arith::MulFOp>(negHalf, yApprox); 1076 Value fma = builder.create<math::FmaOp>(yApprox, inner, cstOnePointFive); 1077 Value yNewton = builder.create<arith::MulFOp>(yApprox, fma); 1078 1079 // Select the result of the Newton-Raphson step for positive normal arguments. 1080 // For other arguments, choose the output of the intrinsic. This will 1081 // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if 1082 // x is zero or a positive denormalized float (equivalent to flushing positive 1083 // denormalized inputs to zero). 1084 Value res = builder.create<SelectOp>(notNormalFiniteMask, yApprox, yNewton); 1085 rewriter.replaceOp(op, res); 1086 1087 return success(); 1088 } 1089 1090 //----------------------------------------------------------------------------// 1091 1092 void mlir::populateMathPolynomialApproximationPatterns( 1093 RewritePatternSet &patterns, 1094 const MathPolynomialApproximationOptions &options) { 1095 patterns.add<TanhApproximation, LogApproximation, Log2Approximation, 1096 Log1pApproximation, ErfPolynomialApproximation, ExpApproximation, 1097 ExpM1Approximation, SinAndCosApproximation<true, math::SinOp>, 1098 SinAndCosApproximation<false, math::CosOp>>( 1099 patterns.getContext()); 1100 if (options.enableAvx2) 1101 patterns.add<RsqrtApproximation>(patterns.getContext()); 1102 } 1103