//===- PolynomialApproximation.cpp - Approximate math operations ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements expansion of math operations to fast approximations
// that do not rely on any of the library functions.
//
//===----------------------------------------------------------------------===//

#include <climits>
#include <cstddef>

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Approximation.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::math;
using namespace mlir::vector;

// Returns the vector shape if the type is a vector. Returns an empty shape if
// it is not a vector.
static ArrayRef<int64_t> vectorShape(Type type) {
  auto vectorType = type.dyn_cast<VectorType>();
  return vectorType ? vectorType.getShape() : ArrayRef<int64_t>();
}

static ArrayRef<int64_t> vectorShape(Value value) {
  return vectorShape(value.getType());
}

//----------------------------------------------------------------------------//
// Broadcast scalar types and values into vector types and values.
//----------------------------------------------------------------------------//

// Broadcasts scalar type into vector type (iff shape is non-scalar).
static Type broadcast(Type type, ArrayRef<int64_t> shape) {
  assert(!type.isa<VectorType>() && "must be scalar type");
  return !shape.empty() ? VectorType::get(shape, type) : type;
}

// Broadcasts scalar value into vector (iff shape is non-scalar).
static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
                       ArrayRef<int64_t> shape) {
  assert(!value.getType().isa<VectorType>() && "must be scalar value");
  auto type = broadcast(value.getType(), shape);
  return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value;
}

//----------------------------------------------------------------------------//
// Helper function to handle n-D vectors with 1-D operations.
//----------------------------------------------------------------------------//

// Expands and unrolls n-D vector operands into multiple fixed size 1-D vectors
// and calls the compute function with 1-D vector operands. Stitches back all
// results into the original n-D vector result.
//
// Examples: vectorWidth = 8
//   - vector<4x8xf32> unrolled 4 times
//   - vector<16xf32> expanded to vector<2x8xf32> and unrolled 2 times
//   - vector<4x16xf32> expanded to vector<4x2x8xf32> and unrolled 4*2 times
//
// Some math approximations rely on ISA-specific operations that only accept
// fixed size 1-D vectors (e.g. AVX expects vectors of width 8).
//
// It is the caller's responsibility to verify that the inner dimension is
// divisible by the vectorWidth, and that all operands have the same vector
// shape.
static Value
handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
                              ValueRange operands, int64_t vectorWidth,
                              llvm::function_ref<Value(ValueRange)> compute) {
  assert(!operands.empty() && "operands must not be empty");
  assert(vectorWidth > 0 && "vector width must be larger than 0");

  VectorType inputType = operands[0].getType().cast<VectorType>();
  ArrayRef<int64_t> inputShape = inputType.getShape();

  // If the input shape matches the target vector width, we can just call the
  // user-provided compute function with the operands.
  if (inputShape == llvm::makeArrayRef(vectorWidth))
    return compute(operands);

  // Check if the inner dimension has to be expanded, or if we can directly
  // iterate over the outer dimensions of the vector.
  int64_t innerDim = inputShape.back();
  int64_t expansionDim = innerDim / vectorWidth;
  assert((innerDim % vectorWidth == 0) && "invalid inner dimension size");

  // Maybe expand operands to the higher rank vector shape that we'll use to
  // iterate over and extract one dimensional vectors.
  SmallVector<int64_t> expandedShape(inputShape.begin(), inputShape.end());
  SmallVector<Value> expandedOperands(operands);

  if (expansionDim > 1) {
    // Expand shape from [..., innerDim] to [..., expansionDim, vectorWidth].
    expandedShape.insert(expandedShape.end() - 1, expansionDim);
    expandedShape.back() = vectorWidth;

    for (unsigned i = 0; i < operands.size(); ++i) {
      auto operand = operands[i];
      auto eltType = operand.getType().cast<VectorType>().getElementType();
      auto expandedType = VectorType::get(expandedShape, eltType);
      expandedOperands[i] =
          builder.create<vector::ShapeCastOp>(expandedType, operand);
    }
  }

  // Iterate over all outer dimensions of the compute shape vector type.
  auto iterationDims = ArrayRef<int64_t>(expandedShape).drop_back();
  int64_t maxLinearIndex = computeMaxLinearIndex(iterationDims);

  SmallVector<int64_t> ones(iterationDims.size(), 1);
  auto strides = computeStrides(iterationDims, ones);

  // Compute results for each one dimensional vector.
  SmallVector<Value> results(maxLinearIndex);

  for (int64_t i = 0; i < maxLinearIndex; ++i) {
    auto offsets = delinearize(strides, i);

    SmallVector<Value> extracted(expandedOperands.size());
    for (const auto &tuple : llvm::enumerate(expandedOperands))
      extracted[tuple.index()] =
          builder.create<vector::ExtractOp>(tuple.value(), offsets);

    results[i] = compute(extracted);
  }

  // Stitch results together into one large vector.
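  // The 1-D results are inserted into a zero-initialized vector of the
  // expanded shape and reshaped back to the original n-D shape with a
  // vector.shape_cast at the end.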
  Type resultEltType =
      results[0].getType().cast<VectorType>().getElementType();
  Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
  Value result = builder.create<arith::ConstantOp>(
      resultExpandedType, builder.getZeroAttr(resultExpandedType));

  for (int64_t i = 0; i < maxLinearIndex; ++i)
    result = builder.create<vector::InsertOp>(results[i], result,
                                              delinearize(strides, i));

  // Reshape back to the original vector shape.
  return builder.create<vector::ShapeCastOp>(
      VectorType::get(inputShape, resultEltType), result);
}

//----------------------------------------------------------------------------//
// Helper functions to create constants.
//----------------------------------------------------------------------------//

static Value f32Cst(ImplicitLocOpBuilder &builder, float value) {
  return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value));
}

static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) {
  return builder.create<arith::ConstantOp>(builder.getI32IntegerAttr(value));
}

static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) {
  Value i32Value = i32Cst(builder, static_cast<int32_t>(bits));
  return builder.create<arith::BitcastOp>(builder.getF32Type(), i32Value);
}

//----------------------------------------------------------------------------//
// Helper functions to build math function approximations.
//----------------------------------------------------------------------------//

// Returns the minimum of the two values, or NaN if value is NaN.
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound) {
  return builder.create<arith::SelectOp>(
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, value, bound),
      value, bound);
}

// Returns the maximum of the two values, or NaN if value is NaN.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound) {
  return builder.create<arith::SelectOp>(
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, value, bound),
      value, bound);
}

// Returns the clamped value, or NaN if value is NaN.
static Value clamp(ImplicitLocOpBuilder &builder, Value value,
                   Value lowerBound, Value upperBound) {
  return max(builder, min(builder, value, upperBound), lowerBound);
}

// Decomposes a given floating point value `arg` into a normalized fraction and
// an integral power of two (see std::frexp). Returned values have float type.
static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
                                     bool isPositive = false) {
  assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
  ArrayRef<int64_t> shape = vectorShape(arg);

  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  auto i32 = builder.getIntegerType(32);
  auto i32Vec = broadcast(i32, shape);
  auto f32Vec = broadcast(builder.getF32Type(), shape);

  Value cst126f = f32Cst(builder, 126.0f);
  Value cstHalf = f32Cst(builder, 0.5f);
  Value cstInvMantMask = f32FromBits(builder, ~0x7f800000u);

  // Bitcast to i32 for bitwise operations.
  Value i32Half = builder.create<arith::BitcastOp>(i32, cstHalf);
  Value i32InvMantMask = builder.create<arith::BitcastOp>(i32, cstInvMantMask);
  Value i32Arg = builder.create<arith::BitcastOp>(i32Vec, arg);

  // Compute normalized fraction.
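  // Clearing the exponent bits and OR-ing in the exponent bits of 0.5 rescales
  // the value into the range [0.5, 1.0) without touching the mantissa.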
  Value tmp0 = builder.create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask));
  Value tmp1 = builder.create<arith::OrIOp>(tmp0, bcast(i32Half));
  Value normalizedFraction = builder.create<arith::BitcastOp>(f32Vec, tmp1);

  // Compute exponent.
  Value arg0 = isPositive ? arg : builder.create<math::AbsOp>(arg);
  Value biasedExponentBits = builder.create<arith::ShRUIOp>(
      builder.create<arith::BitcastOp>(i32Vec, arg0),
      bcast(i32Cst(builder, 23)));
  Value biasedExponent =
      builder.create<arith::SIToFPOp>(f32Vec, biasedExponentBits);
  Value exponent =
      builder.create<arith::SubFOp>(biasedExponent, bcast(cst126f));

  return {normalizedFraction, exponent};
}

// Computes exp2 for an i32 argument.
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
  assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type");
  ArrayRef<int64_t> shape = vectorShape(arg);

  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  auto f32Vec = broadcast(builder.getF32Type(), shape);
  // The exponent of an f32 value is stored in bits 23..30.
  auto exponentBitLocation = bcast(i32Cst(builder, 23));
  // The exponent bias (127), added so that exp2(0) maps to 1.0.
  auto bias = bcast(i32Cst(builder, 127));

  Value biasedArg = builder.create<arith::AddIOp>(arg, bias);
  Value exp2ValueInt =
      builder.create<arith::ShLIOp>(biasedArg, exponentBitLocation);
  Value exp2ValueF32 = builder.create<arith::BitcastOp>(f32Vec, exp2ValueInt);

  return exp2ValueF32;
}

namespace {
Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
                                llvm::ArrayRef<Value> coeffs, Value x) {
  assert(getElementTypeOrSelf(x).isF32() && "x must be f32 type");
  ArrayRef<int64_t> shape = vectorShape(x);

  if (coeffs.empty())
    return broadcast(builder, f32Cst(builder, 0.0f), shape);

  if (coeffs.size() == 1)
    return coeffs[0];

  Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1],
                                          coeffs[coeffs.size() - 2]);
  for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
    res = builder.create<math::FmaOp>(x, res, coeffs[i]);
  }
  return res;
}
} // namespace

//----------------------------------------------------------------------------//
// Helper function/pattern to insert casts for reusing F32 bit expansion.
//----------------------------------------------------------------------------//

template <typename T>
LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) {
  // Conservatively require that all operand and result types match exactly.
  Type origType = op->getResultTypes().front();
  for (Type t : llvm::drop_begin(op->getResultTypes()))
    if (origType != t)
      return rewriter.notifyMatchFailure(op, "required all types to match");
  for (Type t : op->getOperandTypes())
    if (origType != t)
      return rewriter.notifyMatchFailure(op, "required all types to match");

  // Skip if already F32 or larger than 32 bits.
  if (getElementTypeOrSelf(origType).isF32() ||
      getElementTypeOrSelf(origType).getIntOrFloatBitWidth() > 32)
    return failure();

  // Create F32 equivalent type.
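  // For example, f16 and bf16 (or vectors thereof) are widened to f32 here and
  // truncated back after the approximated op has been applied.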
  Type newType;
  if (auto shaped = origType.dyn_cast<ShapedType>()) {
    newType = shaped.clone(rewriter.getF32Type());
  } else if (origType.isa<FloatType>()) {
    newType = rewriter.getF32Type();
  } else {
    return rewriter.notifyMatchFailure(op,
                                       "unable to find F32 equivalent type");
  }

  Location loc = op->getLoc();
  SmallVector<Value> operands;
  for (auto operand : op->getOperands())
    operands.push_back(rewriter.create<arith::ExtFOp>(loc, newType, operand));
  auto result = rewriter.create<T>(loc, newType, operands);
  rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, origType, result);
  return success();
}

namespace {
// Pattern to cast to F32 to reuse F32 expansion as fallback for single-result
// op.
// TODO: Consider revising to avoid adding multiple casts for a subgraph that
// is all in lower precision. Currently this is only fallback support and
// performs simplistic casting.
template <typename T>
struct ReuseF32Expansion : public OpRewritePattern<T> {
public:
  using OpRewritePattern<T>::OpRewritePattern;
  LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const final {
    static_assert(
        T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(),
        "requires same operands and result types");
    return insertCasts<T>(op, rewriter);
  }
};
} // namespace

//----------------------------------------------------------------------------//
// AtanOp approximation.
//----------------------------------------------------------------------------//

namespace {
struct AtanApproximation : public OpRewritePattern<math::AtanOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::AtanOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
AtanApproximation::matchAndRewrite(math::AtanOp op,
                                   PatternRewriter &rewriter) const {
  auto operand = op.getOperand();
  if (!getElementTypeOrSelf(operand).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto one = broadcast(builder, f32Cst(builder, 1.0f), shape);

  // Remap the problem onto [0.0, 1.0] by looking at the absolute value and
  // handling the symmetry.
  Value abs = builder.create<math::AbsOp>(operand);
  Value reciprocal = builder.create<arith::DivFOp>(one, abs);
  Value compare =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, abs, reciprocal);
  Value x = builder.create<arith::SelectOp>(compare, abs, reciprocal);

  // Perform the Taylor series approximation for atan over the range
  // [-1.0, 1.0].
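  // The polynomial p(x) = x * (n4 + x * (n3 + x * (n2 + x * n1))) below is
  // evaluated in Horner form using FMAs.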
  auto n1 = broadcast(builder, f32Cst(builder, 0.14418283f), shape);
  auto n2 = broadcast(builder, f32Cst(builder, -0.34999234f), shape);
  auto n3 = broadcast(builder, f32Cst(builder, -0.01067831f), shape);
  auto n4 = broadcast(builder, f32Cst(builder, 1.00209986f), shape);

  Value p = builder.create<math::FmaOp>(x, n1, n2);
  p = builder.create<math::FmaOp>(x, p, n3);
  p = builder.create<math::FmaOp>(x, p, n4);
  p = builder.create<arith::MulFOp>(x, p);

  // Remap the solution from [0.0, 1.0] back to [0.0, inf).
  auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape);
  Value sub = builder.create<arith::SubFOp>(halfPi, p);
  Value select = builder.create<arith::SelectOp>(compare, p, sub);

  // Correct for the sign of the input.
  rewriter.replaceOpWithNewOp<math::CopySignOp>(op, select, operand);
  return success();
}

//----------------------------------------------------------------------------//
// Atan2Op approximation.
//----------------------------------------------------------------------------//

namespace {
struct Atan2Approximation : public OpRewritePattern<math::Atan2Op> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::Atan2Op op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
Atan2Approximation::matchAndRewrite(math::Atan2Op op,
                                    PatternRewriter &rewriter) const {
  auto y = op.getOperand(0);
  auto x = op.getOperand(1);
  if (!getElementTypeOrSelf(x).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  ArrayRef<int64_t> shape = vectorShape(op.getResult());

  // Compute atan in the valid range.
  auto div = builder.create<arith::DivFOp>(y, x);
  auto atan = builder.create<math::AtanOp>(div);

  // Determine what the atan would be for a 180 degree rotation.
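  // atan(y/x) only covers quadrants I and IV; when x < 0 the result must be
  // shifted by +pi (for y >= 0) or -pi (for y < 0) to land in quadrants II
  // and III.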
  auto zero = broadcast(builder, f32Cst(builder, 0.0f), shape);
  auto pi = broadcast(builder, f32Cst(builder, 3.14159265359f), shape);
  auto addPi = builder.create<arith::AddFOp>(atan, pi);
  auto subPi = builder.create<arith::SubFOp>(atan, pi);
  auto atanGt =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, atan, zero);
  auto flippedAtan = builder.create<arith::SelectOp>(atanGt, subPi, addPi);

  // Determine whether to directly use atan or use the 180 degree flip.
  auto xGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zero);
  Value result = builder.create<arith::SelectOp>(xGt, atan, flippedAtan);

  // Handle x = 0, y > 0.
  Value xZero =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, x, zero);
  Value yGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, y, zero);
  Value isHalfPi = builder.create<arith::AndIOp>(xZero, yGt);
  auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape);
  result = builder.create<arith::SelectOp>(isHalfPi, halfPi, result);

  // Handle x = 0, y < 0.
  Value yLt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, y, zero);
  Value isNegativeHalfPiPi = builder.create<arith::AndIOp>(xZero, yLt);
  auto negativeHalfPiPi =
      broadcast(builder, f32Cst(builder, -1.57079632679f), shape);
  result = builder.create<arith::SelectOp>(isNegativeHalfPiPi, negativeHalfPiPi,
                                           result);

  // Handle x = 0, y = 0.
  Value yZero =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, y, zero);
  Value isNan = builder.create<arith::AndIOp>(xZero, yZero);
  Value cstNan = broadcast(builder, f32FromBits(builder, 0x7fc00000), shape);
  result = builder.create<arith::SelectOp>(isNan, cstNan, result);

  rewriter.replaceOp(op, result);
  return success();
}

//----------------------------------------------------------------------------//
// TanhOp approximation.
//----------------------------------------------------------------------------//

namespace {
struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::TanhOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
TanhApproximation::matchAndRewrite(math::TanhOp op,
                                   PatternRewriter &rewriter) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // Clamp the operand into the [minusClamp, plusClamp] range.
  Value minusClamp = bcast(f32Cst(builder, -7.99881172180175781f));
  Value plusClamp = bcast(f32Cst(builder, 7.99881172180175781f));
  Value x = clamp(builder, op.getOperand(), minusClamp, plusClamp);

  // Mask for tiny values that are approximated with `operand`.
  Value tiny = bcast(f32Cst(builder, 0.0004f));
  Value tinyMask = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::OLT, builder.create<math::AbsOp>(op.getOperand()),
      tiny);

  // The monomial coefficients of the numerator polynomial (odd).
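  // (tanh is approximated on the clamped range as x * p(x^2) / q(x^2), with an
  // odd numerator and an even denominator.)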
  Value alpha1 = bcast(f32Cst(builder, 4.89352455891786e-03f));
  Value alpha3 = bcast(f32Cst(builder, 6.37261928875436e-04f));
  Value alpha5 = bcast(f32Cst(builder, 1.48572235717979e-05f));
  Value alpha7 = bcast(f32Cst(builder, 5.12229709037114e-08f));
  Value alpha9 = bcast(f32Cst(builder, -8.60467152213735e-11f));
  Value alpha11 = bcast(f32Cst(builder, 2.00018790482477e-13f));
  Value alpha13 = bcast(f32Cst(builder, -2.76076847742355e-16f));

  // The monomial coefficients of the denominator polynomial (even).
  Value beta0 = bcast(f32Cst(builder, 4.89352518554385e-03f));
  Value beta2 = bcast(f32Cst(builder, 2.26843463243900e-03f));
  Value beta4 = bcast(f32Cst(builder, 1.18534705686654e-04f));
  Value beta6 = bcast(f32Cst(builder, 1.19825839466702e-06f));

  // Since the polynomials are odd/even, we need x^2.
  Value x2 = builder.create<arith::MulFOp>(x, x);

  // Evaluate the numerator polynomial p.
  Value p = builder.create<math::FmaOp>(x2, alpha13, alpha11);
  p = builder.create<math::FmaOp>(x2, p, alpha9);
  p = builder.create<math::FmaOp>(x2, p, alpha7);
  p = builder.create<math::FmaOp>(x2, p, alpha5);
  p = builder.create<math::FmaOp>(x2, p, alpha3);
  p = builder.create<math::FmaOp>(x2, p, alpha1);
  p = builder.create<arith::MulFOp>(x, p);

  // Evaluate the denominator polynomial q.
  Value q = builder.create<math::FmaOp>(x2, beta6, beta4);
  q = builder.create<math::FmaOp>(x2, q, beta2);
  q = builder.create<math::FmaOp>(x2, q, beta0);

  // Divide the numerator by the denominator.
  Value res = builder.create<arith::SelectOp>(
      tinyMask, x, builder.create<arith::DivFOp>(p, q));

  rewriter.replaceOp(op, res);

  return success();
}

#define LN2_VALUE                                                              \
  0.693147180559945309417232121458176568075500134360255254120680009493393621L
#define LOG2E_VALUE                                                            \
  1.442695040888963407359924681001892137426645954152985934135449406931109219L

//----------------------------------------------------------------------------//
// LogOp and Log2Op approximation.
//----------------------------------------------------------------------------//

namespace {
template <typename Op>
struct LogApproximationBase : public OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;

  /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise.
  LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter,
                                   bool base2) const;
};
} // namespace

// This approximation comes from Julien Pommier's SSE math library.
// Link: http://gruntthepeon.free.fr/ssemath
template <typename Op>
LogicalResult
LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
                                             bool base2) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  Value cstZero = bcast(f32Cst(builder, 0.0f));
  Value cstOne = bcast(f32Cst(builder, 1.0f));
  Value cstNegHalf = bcast(f32Cst(builder, -0.5f));

  // The smallest non-denormalized float number.
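  // (0x00800000 bitcast to f32 is 2^-126, i.e. FLT_MIN; smaller positive
  // inputs are denormal and are clamped up to it below.)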
  Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
  Value cstMinusInf = bcast(f32FromBits(builder, 0xff800000u));
  Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
  Value cstNan = bcast(f32FromBits(builder, 0x7fc00000));

  // Polynomial coefficients.
  Value cstCephesSQRTHF = bcast(f32Cst(builder, 0.707106781186547524f));
  Value cstCephesLogP0 = bcast(f32Cst(builder, 7.0376836292E-2f));
  Value cstCephesLogP1 = bcast(f32Cst(builder, -1.1514610310E-1f));
  Value cstCephesLogP2 = bcast(f32Cst(builder, 1.1676998740E-1f));
  Value cstCephesLogP3 = bcast(f32Cst(builder, -1.2420140846E-1f));
  Value cstCephesLogP4 = bcast(f32Cst(builder, +1.4249322787E-1f));
  Value cstCephesLogP5 = bcast(f32Cst(builder, -1.6668057665E-1f));
  Value cstCephesLogP6 = bcast(f32Cst(builder, +2.0000714765E-1f));
  Value cstCephesLogP7 = bcast(f32Cst(builder, -2.4999993993E-1f));
  Value cstCephesLogP8 = bcast(f32Cst(builder, +3.3333331174E-1f));

  Value x = op.getOperand();

  // Truncate input values to the minimum positive normal.
  x = max(builder, x, cstMinNormPos);

  // Extract the significand in the range [0.5, 1) and the exponent.
  std::pair<Value, Value> pair = frexp(builder, x, /*isPositive=*/true);
  x = pair.first;
  Value e = pair.second;

  // Shift the inputs from the range [0.5, 1) to [sqrt(1/2), sqrt(2)) and shift
  // by -1.0. The values are then centered around 0, which improves the
  // stability of the polynomial evaluation:
  //
  //   if( x < SQRTHF ) {
  //     e -= 1;
  //     x = x + x - 1.0;
  //   } else { x = x - 1.0; }
  Value mask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x,
                                             cstCephesSQRTHF);
  Value tmp = builder.create<arith::SelectOp>(mask, x, cstZero);

  x = builder.create<arith::SubFOp>(x, cstOne);
  e = builder.create<arith::SubFOp>(
      e, builder.create<arith::SelectOp>(mask, cstOne, cstZero));
  x = builder.create<arith::AddFOp>(x, tmp);

  Value x2 = builder.create<arith::MulFOp>(x, x);
  Value x3 = builder.create<arith::MulFOp>(x2, x);

  // Evaluate the polynomial approximant of degree 8 in three parts.
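  // The nine Cephes coefficients are split into three quadratic pieces (y0,
  // y1, y2) that can be evaluated independently and are then recombined with
  // powers of x^3, shortening the FMA dependency chain.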
  Value y0, y1, y2;
  y0 = builder.create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1);
  y1 = builder.create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4);
  y2 = builder.create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7);
  y0 = builder.create<math::FmaOp>(y0, x, cstCephesLogP2);
  y1 = builder.create<math::FmaOp>(y1, x, cstCephesLogP5);
  y2 = builder.create<math::FmaOp>(y2, x, cstCephesLogP8);
  y0 = builder.create<math::FmaOp>(y0, x3, y1);
  y0 = builder.create<math::FmaOp>(y0, x3, y2);
  y0 = builder.create<arith::MulFOp>(y0, x3);

  y0 = builder.create<math::FmaOp>(cstNegHalf, x2, y0);
  x = builder.create<arith::AddFOp>(x, y0);

  if (base2) {
    Value cstLog2e = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
    x = builder.create<math::FmaOp>(x, cstLog2e, e);
  } else {
    Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
    x = builder.create<math::FmaOp>(e, cstLn2, x);
  }

  Value invalidMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
                                                    op.getOperand(), cstZero);
  Value zeroMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
                                                 op.getOperand(), cstZero);
  Value posInfMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
                                                   op.getOperand(), cstPosInf);

  // Filter out invalid values:
  //  • x == 0     -> -INF
  //  • x < 0      ->  NAN
  //  • x == +INF  -> +INF
  Value approximation = builder.create<arith::SelectOp>(
      zeroMask, cstMinusInf,
      builder.create<arith::SelectOp>(
          invalidMask, cstNan,
          builder.create<arith::SelectOp>(posInfMask, cstPosInf, x)));

  rewriter.replaceOp(op, approximation);

  return success();
}

namespace {
struct LogApproximation : public LogApproximationBase<math::LogOp> {
  using LogApproximationBase::LogApproximationBase;

  LogicalResult matchAndRewrite(math::LogOp op,
                                PatternRewriter &rewriter) const final {
    return logMatchAndRewrite(op, rewriter, /*base2=*/false);
  }
};
} // namespace

namespace {
struct Log2Approximation : public LogApproximationBase<math::Log2Op> {
  using LogApproximationBase::LogApproximationBase;

  LogicalResult matchAndRewrite(math::Log2Op op,
                                PatternRewriter &rewriter) const final {
    return logMatchAndRewrite(op, rewriter, /*base2=*/true);
  }
};
} // namespace

//----------------------------------------------------------------------------//
// Log1p approximation.
//----------------------------------------------------------------------------//

namespace {
struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::Log1pOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Approximate log(1+x).
LogicalResult
Log1pApproximation::matchAndRewrite(math::Log1pOp op,
                                    PatternRewriter &rewriter) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // Approximate log(1+x) using the following, due to W. Kahan:
  //   u = x + 1.0;
  //   if (u == 1.0 || u == inf) return x;
  //   return x * log(u) / (u - 1.0);
  //          ^^^^^^^^^^^^^^^^^^^^^^
  //          "logLarge" below.
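  //
  // The quotient log(u) / (u - 1.0) varies slowly around u == 1, so the
  // rounding error made when forming u = x + 1.0 largely cancels, and the
  // result stays accurate even for very small x.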
  Value cstOne = bcast(f32Cst(builder, 1.0f));
  Value x = op.getOperand();
  Value u = builder.create<arith::AddFOp>(x, cstOne);
  Value uSmall =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
  Value logU = builder.create<math::LogOp>(u);
  Value uInf =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU);
  Value logLarge = builder.create<arith::MulFOp>(
      x, builder.create<arith::DivFOp>(
             logU, builder.create<arith::SubFOp>(u, cstOne)));
  Value approximation = builder.create<arith::SelectOp>(
      builder.create<arith::OrIOp>(uSmall, uInf), x, logLarge);
  rewriter.replaceOp(op, approximation);
  return success();
}

//----------------------------------------------------------------------------//
// Erf approximation.
//----------------------------------------------------------------------------//

// Approximates erf(x) with
//   a + P(x)/Q(x)
// where P and Q are polynomials of degree 4.
// Different coefficients are chosen based on the value of x.
// The approximation error is ~2.5e-07.
// Boost's minimax tool that utilizes the Remez method was used to find the
// coefficients.
LogicalResult
ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
                                            PatternRewriter &rewriter) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  const int intervalsCount = 3;
  const int polyDegree = 4;

  Value zero = bcast(f32Cst(builder, 0));
  Value one = bcast(f32Cst(builder, 1));
  Value pp[intervalsCount][polyDegree + 1];
  pp[0][0] = bcast(f32Cst(builder, +0.00000000000000000e+00f));
  pp[0][1] = bcast(f32Cst(builder, +1.12837916222975858e+00f));
  pp[0][2] = bcast(f32Cst(builder, -5.23018562988006470e-01f));
  pp[0][3] = bcast(f32Cst(builder, +2.09741709609267072e-01f));
  pp[0][4] = bcast(f32Cst(builder, +2.58146801602987875e-02f));
  pp[1][0] = bcast(f32Cst(builder, +0.00000000000000000e+00f));
  pp[1][1] = bcast(f32Cst(builder, +1.12750687816789140e+00f));
  pp[1][2] = bcast(f32Cst(builder, -3.64721408487825775e-01f));
  pp[1][3] = bcast(f32Cst(builder, +1.18407396425136952e-01f));
  pp[1][4] = bcast(f32Cst(builder, +3.70645533056476558e-02f));
  pp[2][0] = bcast(f32Cst(builder, -3.30093071049483172e-03f));
  pp[2][1] = bcast(f32Cst(builder, +3.51961938357697011e-03f));
  pp[2][2] = bcast(f32Cst(builder, -1.41373622814988039e-03f));
  pp[2][3] = bcast(f32Cst(builder, +2.53447094961941348e-04f));
  pp[2][4] = bcast(f32Cst(builder, -1.71048029455037401e-05f));

  Value qq[intervalsCount][polyDegree + 1];
  qq[0][0] = bcast(f32Cst(builder, +1.000000000000000000e+00f));
  qq[0][1] = bcast(f32Cst(builder, -4.635138185962547255e-01f));
  qq[0][2] = bcast(f32Cst(builder, +5.192301327279782447e-01f));
  qq[0][3] = bcast(f32Cst(builder, -1.318089722204810087e-01f));
  qq[0][4] = bcast(f32Cst(builder, +7.397964654672315005e-02f));
  qq[1][0] = bcast(f32Cst(builder, +1.00000000000000000e+00f));
  qq[1][1] = bcast(f32Cst(builder, -3.27607011824493086e-01f));
  qq[1][2] = bcast(f32Cst(builder, +4.48369090658821977e-01f));
  qq[1][3] = bcast(f32Cst(builder, -8.83462621207857930e-02f));
  qq[1][4] = bcast(f32Cst(builder, +5.72442770283176093e-02f));
  qq[2][0] = bcast(f32Cst(builder, +1.00000000000000000e+00f));
  qq[2][1] = bcast(f32Cst(builder, -2.06069165953913769e+00f));
  qq[2][2] = bcast(f32Cst(builder, +1.62705939945477759e+00f));
  qq[2][3] = bcast(f32Cst(builder, -5.83389859211130017e-01f));
  qq[2][4] = bcast(f32Cst(builder, +8.21908939856640930e-02f));

  Value offsets[intervalsCount];
  offsets[0] = bcast(f32Cst(builder, 0.0f));
  offsets[1] = bcast(f32Cst(builder, 0.0f));
  offsets[2] = bcast(f32Cst(builder, 1.0f));

  Value bounds[intervalsCount];
  bounds[0] = bcast(f32Cst(builder, 0.8f));
  bounds[1] = bcast(f32Cst(builder, 2.0f));
  bounds[2] = bcast(f32Cst(builder, 3.75f));

  Value isNegativeArg = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT,
                                                      op.getOperand(), zero);
  Value negArg = builder.create<arith::NegFOp>(op.getOperand());
  Value x =
      builder.create<arith::SelectOp>(isNegativeArg, negArg, op.getOperand());

  Value offset = offsets[0];
  Value p[polyDegree + 1];
  Value q[polyDegree + 1];
  for (int i = 0; i <= polyDegree; ++i) {
    p[i] = pp[0][i];
    q[i] = qq[0][i];
  }

  // TODO: maybe use vector stacking to reduce the number of selects.
  Value isLessThanBound[intervalsCount];
  for (int j = 0; j < intervalsCount - 1; ++j) {
    isLessThanBound[j] =
        builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[j]);
    for (int i = 0; i <= polyDegree; ++i) {
      p[i] = builder.create<arith::SelectOp>(isLessThanBound[j], p[i],
                                             pp[j + 1][i]);
      q[i] = builder.create<arith::SelectOp>(isLessThanBound[j], q[i],
                                             qq[j + 1][i]);
    }
    offset = builder.create<arith::SelectOp>(isLessThanBound[j], offset,
                                             offsets[j + 1]);
  }
  isLessThanBound[intervalsCount - 1] = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);

  Value pPoly = makePolynomialCalculation(builder, p, x);
  Value qPoly = makePolynomialCalculation(builder, q, x);
  Value rationalPoly = builder.create<arith::DivFOp>(pPoly, qPoly);
  Value formula = builder.create<arith::AddFOp>(offset, rationalPoly);
  formula = builder.create<arith::SelectOp>(isLessThanBound[intervalsCount - 1],
                                            formula, one);

  // erf is an odd function: erf(x) = -erf(-x).
  Value negFormula = builder.create<arith::NegFOp>(formula);
  Value res =
      builder.create<arith::SelectOp>(isNegativeArg, negFormula, formula);

  rewriter.replaceOp(op, res);

  return success();
}

//----------------------------------------------------------------------------//
// Exp approximation.
//----------------------------------------------------------------------------//

namespace {

struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::ExpOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Approximates exp(x) via its reduced-range counterpart: let
// k = floor(x / ln(2)) and y = x - k * ln(2), so that y lies in [0, ln(2));
// then exp(x) = exp(y) * 2^k and only exp(y) needs a polynomial approximation.
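//
// For example, for x = 10: k = floor(10 / ln(2)) = 14, y = 10 - 14 * ln(2)
// ~= 0.2959, and exp(10) ~= exp(0.2959) * 2^14 ~= 1.3444 * 16384 ~= 22026.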
LogicalResult
ExpApproximation::matchAndRewrite(math::ExpOp op,
                                  PatternRewriter &rewriter) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);

  // TODO: Consider a common pattern rewriter with all methods below to
  // write the approximations.
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };
  auto fmla = [&](Value a, Value b, Value c) {
    return builder.create<math::FmaOp>(a, b, c);
  };
  auto mul = [&](Value a, Value b) -> Value {
    return builder.create<arith::MulFOp>(a, b);
  };
  auto sub = [&](Value a, Value b) -> Value {
    return builder.create<arith::SubFOp>(a, b);
  };
  auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };

  Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
  Value cstLog2E = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));

  // Polynomial coefficients.
  Value cstCephesExpP0 = bcast(f32Cst(builder, 1.0));
  Value cstCephesExpP1 = bcast(f32Cst(builder, 1.0));
  Value cstCephesExpP2 = bcast(f32Cst(builder, 0.49970514590562437052f));
  Value cstCephesExpP3 = bcast(f32Cst(builder, 0.16873890085469545053f));
  Value cstCephesExpP4 = bcast(f32Cst(builder, 0.03668965196652099192f));
  Value cstCephesExpP5 = bcast(f32Cst(builder, 0.01314350012789660196f));

  Value x = op.getOperand();

  Value isNan = builder.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, x, x);

  // Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2).
  Value xL2Inv = mul(x, cstLog2E);
  Value kF32 = floor(xL2Inv);
  Value kLn2 = mul(kF32, cstLn2);
  Value y = sub(x, kLn2);

  // Use Estrin's evaluation scheme with 3 independent parts:
  //   P(y) = (c0 + c1 y) + (c2 + c3 y) y^2 + (c4 + c5 y) y^4
  Value y2 = mul(y, y);
  Value y4 = mul(y2, y2);

  Value q0 = fmla(cstCephesExpP1, y, cstCephesExpP0);
  Value q1 = fmla(cstCephesExpP3, y, cstCephesExpP2);
  Value q2 = fmla(cstCephesExpP5, y, cstCephesExpP4);
  Value expY = fmla(q1, y2, q0);
  expY = fmla(q2, y4, expY);

  auto i32Vec = broadcast(builder.getI32Type(), shape);

  // exp2(k)
  Value k = builder.create<arith::FPToSIOp>(i32Vec, kF32);
  Value exp2KValue = exp2I32(builder, k);

  // exp(x) = exp(y) * exp2(k)
  expY = mul(expY, exp2KValue);

  // Handle overflow, infinity, and underflow of exp(x). The range of exp(x) is
  // [0, inf); it is partitioned as follows:
  //   exp(x) = 0                       for x <= -inf
  //   exp(x) = underflow (min_float)   for x <= -88
  //   exp(x) = +inf                    for x >= 88
  // Note: |k| = 127 is the value where the 8-bit exponent saturates.
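  // Beyond |k| = 127 the shift-based construction of 2^k in exp2I32 would
  // overflow the exponent field, so those inputs are routed to 0, min_float,
  // or +inf via the selects below.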
  Value zerof32Const = bcast(f32Cst(builder, 0));
  auto constPosInfinity =
      bcast(f32Cst(builder, std::numeric_limits<float>::infinity()));
  auto constNegInfinity =
      bcast(f32Cst(builder, -std::numeric_limits<float>::infinity()));
  auto underflow = bcast(f32Cst(builder, std::numeric_limits<float>::min()));

  Value kMaxConst = bcast(i32Cst(builder, 127));
  Value kMaxNegConst = bcast(i32Cst(builder, -127));
  Value rightBound =
      builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, k, kMaxConst);
  Value leftBound =
      builder.create<arith::CmpIOp>(arith::CmpIPredicate::sge, k, kMaxNegConst);

  Value isNegInfinityX = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::OEQ, x, constNegInfinity);
  Value isPosInfinityX = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::OEQ, x, constPosInfinity);
  Value isPositiveX =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zerof32Const);
  Value isComputable = builder.create<arith::AndIOp>(rightBound, leftBound);

  expY = builder.create<arith::SelectOp>(
      isNan, x,
      builder.create<arith::SelectOp>(
          isNegInfinityX, zerof32Const,
          builder.create<arith::SelectOp>(
              isPosInfinityX, constPosInfinity,
              builder.create<arith::SelectOp>(
                  isComputable, expY,
                  builder.create<arith::SelectOp>(isPositiveX, constPosInfinity,
                                                  underflow)))));

  rewriter.replaceOp(op, expY);

  return success();
}

//----------------------------------------------------------------------------//
// ExpM1 approximation.
//----------------------------------------------------------------------------//

namespace {

struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::ExpM1Op op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
                                    PatternRewriter &rewriter) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // expm1(x) = exp(x) - 1 = u - 1.
  // We have to handle it carefully when x is near 0, i.e. u ~= 1,
  // and when the input is ~= -inf, i.e. u - 1 ~= -1.
  Value cstOne = bcast(f32Cst(builder, 1.0f));
  Value cstNegOne = bcast(f32Cst(builder, -1.0f));
  Value x = op.getOperand();
  Value u = builder.create<math::ExpOp>(x);
  Value uEqOneOrNaN =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
  Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne);
  Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
  // logU = log(u) ~= x
  Value logU = builder.create<math::LogOp>(u);

  // Detect exp(x) = +inf; written this way to avoid having to form +inf.
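  // (log(u) == u holds only for u == +inf; for any finite u >= 0, log(u) < u.)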
  Value isInf =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u);

  // expm1(x) ~= (u - 1) * (x / log(u)), where log(u) ~= x.
  Value expm1 = builder.create<arith::MulFOp>(
      uMinusOne, builder.create<arith::DivFOp>(x, logU));
  expm1 = builder.create<arith::SelectOp>(isInf, u, expm1);
  Value approximation = builder.create<arith::SelectOp>(
      uEqOneOrNaN, x,
      builder.create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
  rewriter.replaceOp(op, approximation);
  return success();
}

//----------------------------------------------------------------------------//
// Sin and Cos approximation.
//----------------------------------------------------------------------------//

namespace {

template <bool isSine, typename OpTy>
struct SinAndCosApproximation : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
};
} // namespace

#define TWO_OVER_PI                                                            \
  0.6366197723675813430755350534900574481378385829618257949906693762L
#define PI_OVER_2                                                              \
  1.5707963267948966192313216916397514420985846996875529104874722961L

// Approximates sin(x) or cos(x) by finding the best approximation polynomial
// in the reduced range [0, pi/2] for both sin(x) and cos(x). Then, given y in
// the reduced range, sin(x) is computed as sin(y), -sin(y), cos(y) or -cos(y).
template <bool isSine, typename OpTy>
LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
    OpTy op, PatternRewriter &rewriter) const {
  static_assert(
      llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
      "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");

  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };
  auto mul = [&](Value a, Value b) -> Value {
    return builder.create<arith::MulFOp>(a, b);
  };
  auto sub = [&](Value a, Value b) -> Value {
    return builder.create<arith::SubFOp>(a, b);
  };
  auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };

  auto i32Vec = broadcast(builder.getI32Type(), shape);
  auto fpToSignedInteger = [&](Value a) -> Value {
    return builder.create<arith::FPToSIOp>(i32Vec, a);
  };

  auto modulo4 = [&](Value a) -> Value {
    return builder.create<arith::AndIOp>(a, bcast(i32Cst(builder, 3)));
  };

  auto isEqualTo = [&](Value a, Value b) -> Value {
    return builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b);
  };

  auto isGreaterThan = [&](Value a, Value b) -> Value {
    return builder.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b);
  };

  auto select = [&](Value cond, Value t, Value f) -> Value {
    return builder.create<arith::SelectOp>(cond, t, f);
  };

  auto fmla = [&](Value a, Value b, Value c) {
    return builder.create<math::FmaOp>(a, b, c);
  };

  auto bitwiseOr = [&](Value a, Value b) {
    return builder.create<arith::OrIOp>(a, b);
  };

  Value twoOverPi = bcast(f32Cst(builder, (float)TWO_OVER_PI));
  Value piOverTwo = bcast(f32Cst(builder, (float)PI_OVER_2));
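
  // Range reduction: k = floor(x * 2/pi) selects the quadrant and
  // y = x - k * pi/2 falls in [0, pi/2). k mod 4 then decides whether to
  // evaluate sin(y) or cos(y) and whether to negate the result.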
  Value x = op.getOperand();

  Value k = floor(mul(x, twoOverPi));

  Value y = sub(x, mul(k, piOverTwo));

  Value cstOne = bcast(f32Cst(builder, 1.0));
  Value cstNegativeOne = bcast(f32Cst(builder, -1.0));

  Value cstSC2 = bcast(f32Cst(builder, -0.16666667163372039794921875f));
  Value cstSC4 = bcast(f32Cst(builder, 8.333347737789154052734375e-3f));
  Value cstSC6 = bcast(f32Cst(builder, -1.9842604524455964565277099609375e-4f));
  Value cstSC8 =
      bcast(f32Cst(builder, 2.760012648650445044040679931640625e-6f));
  Value cstSC10 =
      bcast(f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f));

  Value cstCC2 = bcast(f32Cst(builder, -0.5f));
  Value cstCC4 = bcast(f32Cst(builder, 4.166664183139801025390625e-2f));
  Value cstCC6 = bcast(f32Cst(builder, -1.388833043165504932403564453125e-3f));
  Value cstCC8 = bcast(f32Cst(builder, 2.47562347794882953166961669921875e-5f));
  Value cstCC10 =
      bcast(f32Cst(builder, -2.59630184018533327616751194000244140625e-7f));

  Value kMod4 = modulo4(fpToSignedInteger(k));

  Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, 0)));
  Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, 1)));
  Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, 2)));
  Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, 3)));

  Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
  Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(i32Cst(builder, 1)))
                               : bitwiseOr(kR1, kR2);

  Value y2 = mul(y, y);

  Value base = select(sinuseCos, cstOne, y);
  Value cstC2 = select(sinuseCos, cstCC2, cstSC2);
  Value cstC4 = select(sinuseCos, cstCC4, cstSC4);
  Value cstC6 = select(sinuseCos, cstCC6, cstSC6);
  Value cstC8 = select(sinuseCos, cstCC8, cstSC8);
  Value cstC10 = select(sinuseCos, cstCC10, cstSC10);

  Value v1 = fmla(y2, cstC10, cstC8);
  Value v2 = fmla(y2, v1, cstC6);
  Value v3 = fmla(y2, v2, cstC4);
  Value v4 = fmla(y2, v3, cstC2);
  Value v5 = fmla(y2, v4, cstOne);
  Value v6 = mul(base, v5);

  Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);

  rewriter.replaceOp(op, approximation);

  return success();
}

//----------------------------------------------------------------------------//
// Rsqrt approximation.
//----------------------------------------------------------------------------//

namespace {
struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::RsqrtOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
                                    PatternRewriter &rewriter) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ArrayRef<int64_t> shape = vectorShape(op.getOperand());

  // Only support already-vectorized rsqrt's.
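  // The x86vector rsqrt intrinsic used below operates on vectors of 8 f32
  // values, so the innermost dimension must be a multiple of 8.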
  if (shape.empty() || shape.back() % 8 != 0)
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
  Value cstOnePointFive = bcast(f32Cst(builder, 1.5f));
  Value cstNegHalf = bcast(f32Cst(builder, -0.5f));
  Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));

  Value negHalf = builder.create<arith::MulFOp>(op.getOperand(), cstNegHalf);

  // Select only the inverse sqrt of positive normals (denormals are
  // flushed to zero).
  Value ltMinMask = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::OLT, op.getOperand(), cstMinNormPos);
  Value infMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
                                                op.getOperand(), cstPosInf);
  Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask);

  // Compute an approximate result.
  Value yApprox = handleMultidimensionalVectors(
      builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value {
        return builder.create<x86vector::RsqrtOp>(operands);
      });

  // Do a single step of Newton-Raphson iteration to improve the approximation.
  // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
  // It is essential to evaluate the inner term like this because forming
  // y_n^2 may over- or underflow.
  Value inner = builder.create<arith::MulFOp>(negHalf, yApprox);
  Value fma = builder.create<math::FmaOp>(yApprox, inner, cstOnePointFive);
  Value yNewton = builder.create<arith::MulFOp>(yApprox, fma);

  // Select the result of the Newton-Raphson step for positive normal arguments.
  // For other arguments, choose the output of the intrinsic. This will
  // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if
  // x is zero or a positive denormalized float (equivalent to flushing positive
  // denormalized inputs to zero).
  Value res =
      builder.create<arith::SelectOp>(notNormalFiniteMask, yApprox, yNewton);
  rewriter.replaceOp(op, res);

  return success();
}

//----------------------------------------------------------------------------//

void mlir::populateMathPolynomialApproximationPatterns(
    RewritePatternSet &patterns,
    const MathPolynomialApproximationOptions &options) {
  patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
               LogApproximation, Log2Approximation, Log1pApproximation,
               ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
               ReuseF32Expansion<math::Atan2Op>,
               SinAndCosApproximation<true, math::SinOp>,
               SinAndCosApproximation<false, math::CosOp>>(
      patterns.getContext());
  if (options.enableAvx2)
    patterns.add<RsqrtApproximation>(patterns.getContext());
}
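
// Example usage (a sketch only, not part of this file's API surface): a pass
// that wants these lowerings can populate the patterns and apply them with the
// greedy pattern rewrite driver, e.g.
//
//   RewritePatternSet patterns(func.getContext());
//   populateMathPolynomialApproximationPatterns(
//       patterns, MathPolynomialApproximationOptions{});
//   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));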