//===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"

#include <memory>
#include <type_traits>

#include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

namespace {
// Lowers `complex.abs` to sqrt(re * re + im * im), computed on the op's
// (real) result type.
// NOTE(review): the squares are summed without scaling, so the intermediate
// value can overflow to +inf (or underflow to 0) for components of very large
// (or very small) magnitude even when the true modulus is representable. A
// hypot-style scaled formulation would avoid this, but would change the
// emitted IR.
struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
  using OpConversionPattern<complex::AbsOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    // The result type of complex.abs is the element (float) type, so it is
    // also the type of the extracted real/imaginary parts.
    auto type = op.getType();

    Value real =
        rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
    Value imag =
        rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
    Value realSqr = rewriter.create<arith::MulFOp>(loc, real, real);
    Value imagSqr = rewriter.create<arith::MulFOp>(loc, imag, imag);
    Value sqNorm = rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr);

    rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
    return success();
  }
};

// atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
//
// Lowers `complex.atan2(lhs, rhs)` with the identity above, taking y = lhs and
// x = rhs. The expansion is expressed with complex-dialect ops (mul, add,
// sqrt, div, log) rather than with arith/math ops directly; conversion
// patterns for several of those ops (MulOp, DivOp, LogOp, SqrtOp) appear in
// this file.
struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
  using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    auto type = op.getType().cast<ComplexType>();
    Type elementType = type.getElementType();

    Value lhs = adaptor.getLhs();
    Value rhs = adaptor.getRhs();

    // sqrt(x**2 + y**2), computed in complex arithmetic.
    Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs);
    Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs);
    Value rhsSquaredPlusLhsSquared =
        b.create<complex::AddOp>(type, rhsSquared, lhsSquared);
    Value sqrtOfRhsSquaredPlusLhsSquared =
        b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared);

    // x + i * y, with i built as the complex constant (0, 1).
    Value zero =
        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
    Value one = b.create<arith::ConstantOp>(elementType,
                                            b.getFloatAttr(elementType, 1));
    Value i = b.create<complex::CreateOp>(type, zero, one);
    Value iTimesLhs = b.create<complex::MulOp>(i, lhs);
    Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs);

    Value divResult =
        b.create<complex::DivOp>(rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared);
    Value logResult = b.create<complex::LogOp>(divResult);

    // Multiply by -i, built as the complex constant (0, -1).
    Value negativeOne = b.create<arith::ConstantOp>(
        elementType, b.getFloatAttr(elementType, -1));
    Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);

    rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult);
    return success();
  }
};

// Lowers a complex comparison by comparing the real parts and the imaginary
// parts separately with the floating-point predicate `p`. The two i1 results
// are combined with `arith.andi` when ComparisonOp is complex::EqualOp, and
// with `arith.ori` for any other ComparisonOp.
template <typename ComparisonOp, arith::CmpFPredicate p>
struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
  using OpConversionPattern<ComparisonOp>::OpConversionPattern;
  using ResultCombiner =
      std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
                         arith::AndIOp, arith::OrIOp>;

  LogicalResult
  matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = adaptor.getLhs()
                    .getType()
                    .template cast<ComplexType>()
                    .getElementType();

    Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
    Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
    Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs());
    Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs());
    Value realComparison =
        rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
    Value imagComparison =
        rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);

    rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
                                                imagComparison);
    return success();
  }
};

// Default conversion which applies the BinaryStandardOp separately on the real
// and imaginary parts. Can for example be used for complex::AddOp and
// complex::SubOp. The result is reassembled with complex.create.
template <typename BinaryComplexOp, typename BinaryStandardOp>
struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
  using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = adaptor.getLhs().getType().template cast<ComplexType>();
    auto elementType = type.getElementType().template cast<FloatType>();
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
    Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
    Value resultReal =
        b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
    Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
    Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
    Value resultImag =
        b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                   resultImag);
    return success();
  }
};

// Shared lowering for complex trigonometric ops. For an input x + iy it emits
//   scaledExp     = 0.5 * exp(y)
//   reciprocalExp = 0.5 / exp(y)
//   sin           = sin(x)
//   cos           = cos(x)
// and delegates to the op-specific `combine` hook, which turns these into the
// (real, imaginary) parts of the result.
template <typename TrigonometricOp>
struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
  using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;

  using OpConversionPattern<TrigonometricOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = adaptor.getComplex().getType().template cast<ComplexType>();
    auto elementType = type.getElementType().template cast<FloatType>();

    Value real =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
    Value imag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());

    // Trigonometric ops use a set of common building blocks to convert to real
    // ops. Here we create these building blocks and call into an op-specific
    // implementation in the subclass to combine them.
    Value half = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
    Value exp = rewriter.create<math::ExpOp>(loc, imag);
    Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp);
    Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp);
    Value sin = rewriter.create<math::SinOp>(loc, real);
    Value cos = rewriter.create<math::CosOp>(loc, real);

    auto resultPair =
        combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
                                                   resultPair.second);
    return success();
  }

  // Returns {real part, imaginary part} of the op's result, built from the
  // building blocks described above.
  virtual std::pair<Value, Value>
  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
          Value cos, ConversionPatternRewriter &rewriter) const = 0;
};

// Lowers `complex.cos` via the shared trigonometric building blocks.
struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
  using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;

  std::pair<Value, Value>
  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
          Value cos, ConversionPatternRewriter &rewriter) const override {
    // Complex cosine is defined as:
    //   cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
    // Plugging in:
    //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
    //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
    // and defining t := exp(y)
    // We get:
    //   Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
    //   Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
    Value sum = rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp);
    Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos);
    Value diff = rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp);
    Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin);
    return {resultReal, resultImag};
  }
};

// Lowers `complex.div` to element-type arithmetic. The main path is Smith's
// algorithm. Three special cases are also computed: zero denominator,
// infinite numerator with finite denominator, and finite numerator with
// infinite denominator. They override the Smith result only when both of its
// components are NaN. Every path is computed unconditionally and chosen with
// `arith.select`, so the emitted IR contains no control flow.
// NOTE(review): the special-case handling appears to follow the recovery
// logic of C99 Annex G style complex division (e.g. compiler-rt's __divdc3);
// confirm against that reference before changing it.
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
  using OpConversionPattern<complex::DivOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = adaptor.getLhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value lhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
    Value lhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs());
    Value rhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs());
    Value rhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs());

    // Smith's algorithm to divide complex numbers. It is just a bit smarter
    // way to compute the following formula:
    //  (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
    //    = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
    //          ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
    //    = ((lhsReal * rhsReal + lhsImag * rhsImag) +
    //          (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
    //
    // Depending on whether |rhsReal| < |rhsImag| we compute either
    //   rhsRealImagRatio = rhsReal / rhsImag
    //   rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
    //   resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
    //   resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
    //
    // or
    //
    //   rhsImagRealRatio = rhsImag / rhsReal
    //   rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
    //   resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
    //   resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
    //
    // See https://dl.acm.org/citation.cfm?id=368661 for more details.
    //
    // Both variants are emitted; the choice between them is made with
    // `arith.select` further below.
    Value rhsRealImagRatio =
        rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag);
    Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
        loc, rhsImag,
        rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal));
    Value realNumerator1 = rewriter.create<arith::AddFOp>(
        loc, rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio),
        lhsImag);
    Value resultReal1 =
        rewriter.create<arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom);
    Value imagNumerator1 = rewriter.create<arith::SubFOp>(
        loc, rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio),
        lhsReal);
    Value resultImag1 =
        rewriter.create<arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom);

    Value rhsImagRealRatio =
        rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal);
    Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
        loc, rhsReal,
        rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag));
    Value realNumerator2 = rewriter.create<arith::AddFOp>(
        loc, lhsReal,
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio));
    Value resultReal2 =
        rewriter.create<arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom);
    Value imagNumerator2 = rewriter.create<arith::SubFOp>(
        loc, lhsImag,
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio));
    Value resultImag2 =
        rewriter.create<arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom);

    // Consider corner cases.
    // Case 1. Zero denominator, numerator contains at most one NaN value.
    // The result is copysign(inf, rhsReal) multiplied into each lhs component.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getZeroAttr(elementType));
    Value rhsRealAbs = rewriter.create<math::AbsOp>(loc, rhsReal);
    Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
    Value rhsImagAbs = rewriter.create<math::AbsOp>(loc, rhsImag);
    Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
    // ORD against the constant zero is true exactly when the operand is not
    // NaN.
    Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsReal, zero);
    Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsImag, zero);
    Value lhsContainsNotNaNValue =
        rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
    Value resultIsInfinity = rewriter.create<arith::AndIOp>(
        loc, lhsContainsNotNaNValue,
        rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
    Value inf = rewriter.create<arith::ConstantOp>(
        loc, elementType,
        rewriter.getFloatAttr(
            elementType, APFloat::getInf(elementType.getFloatSemantics())));
    Value infWithSignOfRhsReal =
        rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
    Value infinityResultReal =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
    Value infinityResultImag =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag);

    // Case 2. Infinite numerator, finite denominator.
    // Each lhs component becomes copysign(1, c) if it is infinite and
    // copysign(0, c) otherwise (call these lhs'), and the result is
    //   inf * (lhsReal' * rhsReal + lhsImag' * rhsImag)
    //   + i * inf * (lhsImag' * rhsReal - lhsReal' * rhsImag).
    // ONE compares false for NaN, so a NaN rhs component is not "finite".
    Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
    Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
    Value rhsFinite =
        rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
    Value lhsRealAbs = rewriter.create<math::AbsOp>(loc, lhsReal);
    Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagAbs = rewriter.create<math::AbsOp>(loc, lhsImag);
    Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsInfinite =
        rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
    Value infNumFiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
    Value one = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getFloatAttr(elementType, 1));
    Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
        lhsReal);
    Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
        lhsImag);
    Value lhsRealIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
    Value lhsImagIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
    Value resultReal3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
                                       lhsImagIsInfWithSignTimesRhsImag));
    Value lhsRealIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
    Value lhsImagIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
    Value resultImag3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
                                       lhsRealIsInfWithSignTimesRhsImag));

    // Case 3. Finite numerator, infinite denominator.
    // Each rhs component becomes copysign(1, c) if it is infinite and
    // copysign(0, c) otherwise (call these rhs'), and the result is
    //   0 * (lhsReal * rhsReal' + lhsImag * rhsImag')
    //   + i * 0 * (lhsImag * rhsReal' - lhsReal * rhsImag'),
    // i.e. a zero whose sign depends on the operands.
    Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
    Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
    Value lhsFinite =
        rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
    Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsInfinite =
        rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
    Value finiteNumInfiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
    Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
        rhsReal);
    Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
        rhsImag);
    Value rhsRealIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
    Value resultReal4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
                                       rhsImagIsInfWithSignTimesLhsImag));
    Value rhsRealIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
    Value resultImag4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
                                       rhsImagIsInfWithSignTimesLhsReal));

    // Pick the Smith variant that divides by the larger-magnitude rhs
    // component. OLT is false if either magnitude is NaN, in which case the
    // rhsImag / rhsReal variant is used.
    Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
    Value resultReal = rewriter.create<arith::SelectOp>(
        loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
    Value resultImag = rewriter.create<arith::SelectOp>(
        loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
    // Layer the special cases: the innermost select applies Case 3, so when
    // several conditions hold the priority is Case 1 > Case 2 > Case 3.
    Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>(
        loc, finiteNumInfiniteDenom, resultReal4, resultReal);
    Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>(
        loc, finiteNumInfiniteDenom, resultImag4, resultImag);
    Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>(
        loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
    Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>(
        loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
    Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>(
        loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
    Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>(
        loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);

    // Only fall back to the special-case result when the Smith result is NaN
    // in both components; otherwise keep the Smith result.
    Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultReal, zero);
    Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultImag, zero);
    Value resultIsNaN =
        rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
    Value resultRealWithSpecialCases = rewriter.create<arith::SelectOp>(
        loc, resultIsNaN, resultRealSpecialCase1, resultReal);
    Value resultImagWithSpecialCases = rewriter.create<arith::SelectOp>(
        loc, resultIsNaN, resultImagSpecialCase1, resultImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(
        op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
    return success();
  }
};

// Lowers `complex.exp`: exp(x + iy) = exp(x) * cos(y) + i * exp(x) * sin(y).
struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
  using OpConversionPattern<complex::ExpOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = adaptor.getComplex().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value real =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
    Value imag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
    Value expReal = rewriter.create<math::ExpOp>(loc, real);
    Value cosImag = rewriter.create<math::CosOp>(loc, imag);
    Value resultReal = rewriter.create<arith::MulFOp>(loc, expReal, cosImag);
    Value sinImag = rewriter.create<math::SinOp>(loc, imag);
    Value resultImag = rewriter.create<arith::MulFOp>(loc, expReal, sinImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                   resultImag);
    return success();
  }
};

// Lowers `complex.expm1` as complex.exp(z) with 1 subtracted from the real
// part; the imaginary part of exp(z) is passed through unchanged.
// NOTE(review): computing re(exp(z)) - 1 directly cancels catastrophically for
// small |z|, so this does not provide the extra precision an expm1 is
// normally expected to give.
struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
  using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = adaptor.getComplex().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Value exp = b.create<complex::ExpOp>(adaptor.getComplex());

    Value real = b.create<complex::ReOp>(elementType, exp);
    Value one = b.create<arith::ConstantOp>(elementType,
                                            b.getFloatAttr(elementType, 1));
    Value realMinusOne = b.create<arith::SubFOp>(real, one);
    Value imag = b.create<complex::ImOp>(elementType, exp);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne,
                                                   imag);
    return success();
  }
};

// Lowers `complex.log`: log(z) = log(|z|) + i * atan2(im(z), re(z)).
struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
  using OpConversionPattern<complex::LogOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = adaptor.getComplex().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
    Value resultReal = b.create<math::LogOp>(elementType, abs);
    Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
    Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
    Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                   resultImag);
    return success();
  }
};

// Lowers `complex.log1p` as complex.log(complex.create(re(z) + 1, im(z))).
// NOTE(review): adding 1 before taking the log discards low-order bits of
// re(z) when |z| is small, so this does not provide the extra precision a
// log1p is normally expected to give.
struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
  using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = adaptor.getComplex().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
    Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
    Value one = b.create<arith::ConstantOp>(elementType,
                                            b.getFloatAttr(elementType, 1));
    Value realPlusOne = b.create<arith::AddFOp>(real, one);
    Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag);
    rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex);
    return success();
  }
};

// Lowers `complex.mul`. The naive product
//   (a + ib)(c + id) = (ac - bd) + i(ad + bc)
// is computed first. If both of its components are NaN, the operands are
// adjusted for three situations and the product is recomputed and scaled by
// +inf:
// - an infinite lhs component;
// - an infinite rhs component;
// - an infinite pairwise product.
// All paths are computed unconditionally and chosen with `arith.select`.
// NOTE(review): this appears to mirror the C99 Annex G style recovery used by
// e.g. compiler-rt's __muldc3; confirm against that reference before changing.
struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
  using OpConversionPattern<complex::MulOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto type = adaptor.getLhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
    Value lhsRealAbs = b.create<math::AbsOp>(lhsReal);
    Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
    Value lhsImagAbs = b.create<math::AbsOp>(lhsImag);
    Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
    Value rhsRealAbs = b.create<math::AbsOp>(rhsReal);
    Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
    Value rhsImagAbs = b.create<math::AbsOp>(rhsImag);

    Value lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
    Value lhsRealTimesRhsRealAbs = b.create<math::AbsOp>(lhsRealTimesRhsReal);
    Value lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
    Value lhsImagTimesRhsImagAbs = b.create<math::AbsOp>(lhsImagTimesRhsImag);
    Value real =
        b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);

    Value lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
    Value lhsImagTimesRhsRealAbs = b.create<math::AbsOp>(lhsImagTimesRhsReal);
    Value lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
    Value lhsRealTimesRhsImagAbs = b.create<math::AbsOp>(lhsRealTimesRhsImag);
    Value imag =
        b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);

    // Handle cases where the "naive" calculation results in NaN values.
    // From here on, lhsReal/lhsImag/rhsReal/rhsImag are rebound to
    // `arith.select` results that hold the adjusted operands.
    Value realIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
    Value imagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
    Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan);

    Value inf = b.create<arith::ConstantOp>(
        elementType,
        b.getFloatAttr(elementType,
                       APFloat::getInf(elementType.getFloatSemantics())));

    // Case 1. `lhsReal` or `lhsImag` are infinite.
    // The lhs components become copysign(1 or 0, component) depending on
    // whether each is infinite; NaN rhs components become copysign(0, c).
    Value lhsRealIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
    Value rhsRealIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
    Value rhsImagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
    Value zero =
        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
    Value one = b.create<arith::ConstantOp>(elementType,
                                            b.getFloatAttr(elementType, 1));
    Value lhsRealIsInfFloat =
        b.create<arith::SelectOp>(lhsRealIsInf, one, zero);
    lhsReal = b.create<arith::SelectOp>(
        lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
        lhsReal);
    Value lhsImagIsInfFloat =
        b.create<arith::SelectOp>(lhsImagIsInf, one, zero);
    lhsImag = b.create<arith::SelectOp>(
        lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
        lhsImag);
    Value lhsIsInfAndRhsRealIsNan =
        b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
    rhsReal = b.create<arith::SelectOp>(
        lhsIsInfAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
        rhsReal);
    Value lhsIsInfAndRhsImagIsNan =
        b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
    rhsImag = b.create<arith::SelectOp>(
        lhsIsInfAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
        rhsImag);

    // Case 2. `rhsReal` or `rhsImag` are infinite.
    // Symmetric to Case 1 with the roles of lhs and rhs swapped. The NaN
    // tests on lhsReal/lhsImag below see the Case 1-adjusted values.
    Value rhsRealIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
    Value lhsRealIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
    Value lhsImagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
    Value rhsRealIsInfFloat =
        b.create<arith::SelectOp>(rhsRealIsInf, one, zero);
    rhsReal = b.create<arith::SelectOp>(
        rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
        rhsReal);
    Value rhsImagIsInfFloat =
        b.create<arith::SelectOp>(rhsImagIsInf, one, zero);
    rhsImag = b.create<arith::SelectOp>(
        rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
        rhsImag);
    Value rhsIsInfAndLhsRealIsNan =
        b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
    lhsReal = b.create<arith::SelectOp>(
        rhsIsInfAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
        lhsReal);
    Value rhsIsInfAndLhsImagIsNan =
        b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
    lhsImag = b.create<arith::SelectOp>(
        rhsIsInfAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
        lhsImag);
    Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf);

    // Case 3. One of the pairwise products of left hand side with right hand
    // side is infinite.
    // Only applies when neither Case 1 nor Case 2 did (`notRecalc`). On that
    // path no operand was adjusted, so the NaN flags computed earlier still
    // describe the current operand values. NaN operands become copysign(0, x).
    Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
    Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
    Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
                                                 lhsImagTimesRhsImagIsInf);
    Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
    isSpecialCase =
        b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
    Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
    isSpecialCase =
        b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
    Type i1Type = b.getI1Type();
    // Logical NOT of `recalc`, expressed as XOR with the i1 constant 1.
    Value notRecalc = b.create<arith::XOrIOp>(
        recalc,
        b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
    isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc);
    Value isSpecialCaseAndLhsRealIsNan =
        b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
    lhsReal = b.create<arith::SelectOp>(
        isSpecialCaseAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
        lhsReal);
    Value isSpecialCaseAndLhsImagIsNan =
        b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
    lhsImag = b.create<arith::SelectOp>(
        isSpecialCaseAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
        lhsImag);
    Value isSpecialCaseAndRhsRealIsNan =
        b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
    rhsReal = b.create<arith::SelectOp>(
        isSpecialCaseAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
        rhsReal);
    Value isSpecialCaseAndRhsImagIsNan =
        b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
    rhsImag = b.create<arith::SelectOp>(
        isSpecialCaseAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
        rhsImag);
    recalc = b.create<arith::OrIOp>(recalc, isSpecialCase);
    // The recomputed value is used only if the naive product was NaN in both
    // components.
    recalc = b.create<arith::AndIOp>(isNan, recalc);

    // Recalculate real part.
    lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
    lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
    Value newReal =
        b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
    real = b.create<arith::SelectOp>(
        recalc, b.create<arith::MulFOp>(inf, newReal), real);

    // Recalculate imag part.
    lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
    lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
    Value newImag =
        b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
    imag = b.create<arith::SelectOp>(
        recalc, b.create<arith::MulFOp>(inf, newImag), imag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
    return success();
  }
};

// Lowers `complex.neg` by negating the real and imaginary parts with
// `arith.negf`.
struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
  using OpConversionPattern<complex::NegOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = adaptor.getComplex().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value real =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
    Value imag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
    Value negReal = rewriter.create<arith::NegFOp>(loc, real);
    Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
    return success();
  }
};

// Lowers `complex.sin` via the shared trigonometric building blocks.
struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
  using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;

  std::pair<Value, Value>
  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
          Value cos, ConversionPatternRewriter &rewriter) const override {
    // Complex sine is defined as:
    //   sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
    // Plugging in:
    //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
    //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
    // and defining t := exp(y)
    // We get:
    //   Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
    //   Im(sin(x + iy)) = (0.5*t - 0.5/t) * cos x
    Value sum = rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp);
    Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin);
    Value diff = rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp);
    Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos);
    return {resultReal, resultImag};
  }
};

// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
  using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    auto type = op.getType().cast<ComplexType>();
    Type elementType = type.getElementType();
    Value arg = adaptor.getComplex();

    Value zero =
        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));

    Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
    Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());

    Value absLhs = b.create<math::AbsOp>(real);
    Value absArg = b.create<complex::AbsOp>(elementType, arg);
    Value addAbs = b.create<arith::AddFOp>(absLhs, absArg);

    Value half = b.create<arith::ConstantOp>(elementType,
                                             b.getFloatAttr(elementType, 0.5));
    Value halfAddAbs = b.create<arith::MulFOp>(addAbs, half);
772 Value sqrtAddAbs = b.create<math::SqrtOp>(halfAddAbs); 773 774 Value realIsNegative = 775 b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero); 776 Value imagIsNegative = 777 b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, imag, zero); 778 779 Value resultReal = sqrtAddAbs; 780 781 Value imagDivTwoResultReal = b.create<arith::DivFOp>( 782 imag, b.create<arith::AddFOp>(resultReal, resultReal)); 783 784 Value negativeResultReal = b.create<arith::NegFOp>(resultReal); 785 786 Value resultImag = b.create<arith::SelectOp>( 787 realIsNegative, 788 b.create<arith::SelectOp>(imagIsNegative, negativeResultReal, 789 resultReal), 790 imagDivTwoResultReal); 791 792 resultReal = b.create<arith::SelectOp>( 793 realIsNegative, 794 b.create<arith::DivFOp>( 795 imag, b.create<arith::AddFOp>(resultImag, resultImag)), 796 resultReal); 797 798 Value realIsZero = 799 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero); 800 Value imagIsZero = 801 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero); 802 Value argIsZero = b.create<arith::AndIOp>(realIsZero, imagIsZero); 803 804 resultReal = b.create<arith::SelectOp>(argIsZero, zero, resultReal); 805 resultImag = b.create<arith::SelectOp>(argIsZero, zero, resultImag); 806 807 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 808 resultImag); 809 return success(); 810 } 811 }; 812 813 struct SignOpConversion : public OpConversionPattern<complex::SignOp> { 814 using OpConversionPattern<complex::SignOp>::OpConversionPattern; 815 816 LogicalResult 817 matchAndRewrite(complex::SignOp op, OpAdaptor adaptor, 818 ConversionPatternRewriter &rewriter) const override { 819 auto type = adaptor.getComplex().getType().cast<ComplexType>(); 820 auto elementType = type.getElementType().cast<FloatType>(); 821 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 822 823 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); 824 Value imag = b.create<complex::ImOp>(elementType, 
adaptor.getComplex()); 825 Value zero = 826 b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); 827 Value realIsZero = 828 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero); 829 Value imagIsZero = 830 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero); 831 Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero); 832 auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex()); 833 Value realSign = b.create<arith::DivFOp>(real, abs); 834 Value imagSign = b.create<arith::DivFOp>(imag, abs); 835 Value sign = b.create<complex::CreateOp>(type, realSign, imagSign); 836 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero, 837 adaptor.getComplex(), sign); 838 return success(); 839 } 840 }; 841 842 struct TanOpConversion : public OpConversionPattern<complex::TanOp> { 843 using OpConversionPattern<complex::TanOp>::OpConversionPattern; 844 845 LogicalResult 846 matchAndRewrite(complex::TanOp op, OpAdaptor adaptor, 847 ConversionPatternRewriter &rewriter) const override { 848 auto loc = op.getLoc(); 849 Value cos = rewriter.create<complex::CosOp>(loc, adaptor.getComplex()); 850 Value sin = rewriter.create<complex::SinOp>(loc, adaptor.getComplex()); 851 rewriter.replaceOpWithNewOp<complex::DivOp>(op, sin, cos); 852 return success(); 853 } 854 }; 855 856 struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> { 857 using OpConversionPattern<complex::TanhOp>::OpConversionPattern; 858 859 LogicalResult 860 matchAndRewrite(complex::TanhOp op, OpAdaptor adaptor, 861 ConversionPatternRewriter &rewriter) const override { 862 auto loc = op.getLoc(); 863 auto type = adaptor.getComplex().getType().cast<ComplexType>(); 864 auto elementType = type.getElementType().cast<FloatType>(); 865 866 // The hyperbolic tangent for complex number can be calculated as follows. 
867 // tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + tanh(x) * tan(y)) 868 // See: https://proofwiki.org/wiki/Hyperbolic_Tangent_of_Complex_Number 869 Value real = 870 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); 871 Value imag = 872 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); 873 Value tanhA = rewriter.create<math::TanhOp>(loc, real); 874 Value cosB = rewriter.create<math::CosOp>(loc, imag); 875 Value sinB = rewriter.create<math::SinOp>(loc, imag); 876 Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB); 877 Value numerator = 878 rewriter.create<complex::CreateOp>(loc, type, tanhA, tanB); 879 Value one = rewriter.create<arith::ConstantOp>( 880 loc, elementType, rewriter.getFloatAttr(elementType, 1)); 881 Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB); 882 Value denominator = rewriter.create<complex::CreateOp>(loc, type, one, mul); 883 rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator); 884 return success(); 885 } 886 }; 887 888 struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> { 889 using OpConversionPattern<complex::ConjOp>::OpConversionPattern; 890 891 LogicalResult 892 matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor, 893 ConversionPatternRewriter &rewriter) const override { 894 auto loc = op.getLoc(); 895 auto type = adaptor.getComplex().getType().cast<ComplexType>(); 896 auto elementType = type.getElementType().cast<FloatType>(); 897 Value real = 898 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); 899 Value imag = 900 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); 901 Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag); 902 903 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag); 904 905 return success(); 906 } 907 }; 908 909 /// Coverts x^y = (a+bi)^(c+di) to 910 /// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), 911 /// where q = 
c*atan2(b,a)+0.5d*ln(a*a+b*b) 912 static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder, 913 ComplexType type, Value a, Value b, Value c, 914 Value d) { 915 auto elementType = type.getElementType().cast<FloatType>(); 916 917 // Compute (a*a+b*b)^(0.5c). 918 Value aaPbb = builder.create<arith::AddFOp>( 919 builder.create<arith::MulFOp>(a, a), builder.create<arith::MulFOp>(b, b)); 920 Value half = builder.create<arith::ConstantOp>( 921 elementType, builder.getFloatAttr(elementType, 0.5)); 922 Value halfC = builder.create<arith::MulFOp>(half, c); 923 Value aaPbbTohalfC = builder.create<math::PowFOp>(aaPbb, halfC); 924 925 // Compute exp(-d*atan2(b,a)). 926 Value negD = builder.create<arith::NegFOp>(d); 927 Value argX = builder.create<math::Atan2Op>(b, a); 928 Value negDArgX = builder.create<arith::MulFOp>(negD, argX); 929 Value eToNegDArgX = builder.create<math::ExpOp>(negDArgX); 930 931 // Compute (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)). 932 Value coeff = builder.create<arith::MulFOp>(aaPbbTohalfC, eToNegDArgX); 933 934 // Compute c*atan2(b,a)+0.5d*ln(a*a+b*b). 
935 Value lnAaPbb = builder.create<math::LogOp>(aaPbb); 936 Value halfD = builder.create<arith::MulFOp>(half, d); 937 Value q = builder.create<arith::AddFOp>( 938 builder.create<arith::MulFOp>(c, argX), 939 builder.create<arith::MulFOp>(halfD, lnAaPbb)); 940 941 Value cosQ = builder.create<math::CosOp>(q); 942 Value sinQ = builder.create<math::SinOp>(q); 943 Value zero = builder.create<arith::ConstantOp>( 944 elementType, builder.getFloatAttr(elementType, 0)); 945 Value one = builder.create<arith::ConstantOp>( 946 elementType, builder.getFloatAttr(elementType, 1)); 947 948 Value xEqZero = 949 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, aaPbb, zero); 950 Value yGeZero = builder.create<arith::AndIOp>( 951 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, c, zero), 952 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero)); 953 Value cEqZero = 954 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero); 955 Value complexZero = builder.create<complex::CreateOp>(type, zero, zero); 956 Value complexOne = builder.create<complex::CreateOp>(type, one, zero); 957 Value complexOther = builder.create<complex::CreateOp>( 958 type, builder.create<arith::MulFOp>(coeff, cosQ), 959 builder.create<arith::MulFOp>(coeff, sinQ)); 960 961 // x^y is 0 if x is 0 and y > 0. 0^0 is defined to be 1.0, see 962 // Branch Cuts for Complex Elementary Functions or Much Ado About 963 // Nothing's Sign Bit, W. Kahan, Section 10. 
964 return builder.create<arith::SelectOp>( 965 builder.create<arith::AndIOp>(xEqZero, yGeZero), 966 builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero), 967 complexOther); 968 } 969 970 struct PowOpConversion : public OpConversionPattern<complex::PowOp> { 971 using OpConversionPattern<complex::PowOp>::OpConversionPattern; 972 973 LogicalResult 974 matchAndRewrite(complex::PowOp op, OpAdaptor adaptor, 975 ConversionPatternRewriter &rewriter) const override { 976 mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter); 977 auto type = adaptor.getLhs().getType().cast<ComplexType>(); 978 auto elementType = type.getElementType().cast<FloatType>(); 979 980 Value a = builder.create<complex::ReOp>(elementType, adaptor.getLhs()); 981 Value b = builder.create<complex::ImOp>(elementType, adaptor.getLhs()); 982 Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs()); 983 Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs()); 984 985 rewriter.replaceOp(op, {powOpConversionImpl(builder, type, a, b, c, d)}); 986 return success(); 987 } 988 }; 989 990 struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> { 991 using OpConversionPattern<complex::RsqrtOp>::OpConversionPattern; 992 993 LogicalResult 994 matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor, 995 ConversionPatternRewriter &rewriter) const override { 996 mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter); 997 auto type = adaptor.getComplex().getType().cast<ComplexType>(); 998 auto elementType = type.getElementType().cast<FloatType>(); 999 1000 Value a = builder.create<complex::ReOp>(elementType, adaptor.getComplex()); 1001 Value b = builder.create<complex::ImOp>(elementType, adaptor.getComplex()); 1002 Value c = builder.create<arith::ConstantOp>( 1003 elementType, builder.getFloatAttr(elementType, -0.5)); 1004 Value d = builder.create<arith::ConstantOp>( 1005 elementType, builder.getFloatAttr(elementType, 0)); 1006 1007 rewriter.replaceOp(op, 
{powOpConversionImpl(builder, type, a, b, c, d)}); 1008 return success(); 1009 } 1010 }; 1011 1012 struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> { 1013 using OpConversionPattern<complex::AngleOp>::OpConversionPattern; 1014 1015 LogicalResult 1016 matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor, 1017 ConversionPatternRewriter &rewriter) const override { 1018 auto loc = op.getLoc(); 1019 auto type = op.getType(); 1020 1021 Value real = 1022 rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex()); 1023 Value imag = 1024 rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex()); 1025 1026 rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real); 1027 1028 return success(); 1029 } 1030 }; 1031 1032 } // namespace 1033 1034 void mlir::populateComplexToStandardConversionPatterns( 1035 RewritePatternSet &patterns) { 1036 // clang-format off 1037 patterns.add< 1038 AbsOpConversion, 1039 AngleOpConversion, 1040 Atan2OpConversion, 1041 BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>, 1042 BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>, 1043 ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>, 1044 ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>, 1045 ConjOpConversion, 1046 CosOpConversion, 1047 DivOpConversion, 1048 ExpOpConversion, 1049 Expm1OpConversion, 1050 Log1pOpConversion, 1051 LogOpConversion, 1052 MulOpConversion, 1053 NegOpConversion, 1054 SignOpConversion, 1055 SinOpConversion, 1056 SqrtOpConversion, 1057 TanOpConversion, 1058 TanhOpConversion, 1059 PowOpConversion, 1060 RsqrtOpConversion 1061 >(patterns.getContext()); 1062 // clang-format on 1063 } 1064 1065 namespace { 1066 struct ConvertComplexToStandardPass 1067 : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> { 1068 void runOnOperation() override; 1069 }; 1070 1071 void ConvertComplexToStandardPass::runOnOperation() { 1072 // Convert to the Standard dialect using the converter defined 
above. 1073 RewritePatternSet patterns(&getContext()); 1074 populateComplexToStandardConversionPatterns(patterns); 1075 1076 ConversionTarget target(getContext()); 1077 target.addLegalDialect<arith::ArithmeticDialect, math::MathDialect>(); 1078 target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>(); 1079 if (failed( 1080 applyPartialConversion(getOperation(), target, std::move(patterns)))) 1081 signalPassFailure(); 1082 } 1083 } // namespace 1084 1085 std::unique_ptr<Pass> mlir::createConvertComplexToStandardPass() { 1086 return std::make_unique<ConvertComplexToStandardPass>(); 1087 } 1088