1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" 10 11 #include <memory> 12 #include <type_traits> 13 14 #include "../PassDetail.h" 15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 16 #include "mlir/Dialect/Complex/IR/Complex.h" 17 #include "mlir/Dialect/Math/IR/Math.h" 18 #include "mlir/IR/ImplicitLocOpBuilder.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 22 using namespace mlir; 23 24 namespace { 25 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> { 26 using OpConversionPattern<complex::AbsOp>::OpConversionPattern; 27 28 LogicalResult 29 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, 30 ConversionPatternRewriter &rewriter) const override { 31 auto loc = op.getLoc(); 32 auto type = op.getType(); 33 34 Value real = 35 rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex()); 36 Value imag = 37 rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex()); 38 Value realSqr = rewriter.create<arith::MulFOp>(loc, real, real); 39 Value imagSqr = rewriter.create<arith::MulFOp>(loc, imag, imag); 40 Value sqNorm = rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr); 41 42 rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm); 43 return success(); 44 } 45 }; 46 47 // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2)) 48 struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> { 49 using OpConversionPattern<complex::Atan2Op>::OpConversionPattern; 50 51 LogicalResult 52 matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor, 53 ConversionPatternRewriter &rewriter) const override { 54 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 55 56 auto type = op.getType().cast<ComplexType>(); 57 Type elementType = type.getElementType(); 58 59 Value lhs = adaptor.getLhs(); 60 Value rhs = adaptor.getRhs(); 61 62 Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs); 63 Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs); 64 Value rhsSquaredPlusLhsSquared = 65 b.create<complex::AddOp>(type, rhsSquared, lhsSquared); 66 Value sqrtOfRhsSquaredPlusLhsSquared = 67 b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared); 68 69 Value zero = 70 b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); 71 Value one = b.create<arith::ConstantOp>(elementType, 72 b.getFloatAttr(elementType, 1)); 73 Value i = b.create<complex::CreateOp>(type, zero, one); 74 Value iTimesLhs = b.create<complex::MulOp>(i, lhs); 75 Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs); 76 77 Value divResult = 78 b.create<complex::DivOp>(rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared); 79 Value logResult = b.create<complex::LogOp>(divResult); 80 81 Value negativeOne = b.create<arith::ConstantOp>( 82 elementType, b.getFloatAttr(elementType, -1)); 83 Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne); 84 85 rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult); 86 return success(); 87 } 88 }; 89 90 template <typename ComparisonOp, arith::CmpFPredicate p> 91 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> { 92 using OpConversionPattern<ComparisonOp>::OpConversionPattern; 93 using ResultCombiner = 94 std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value, 95 arith::AndIOp, arith::OrIOp>; 96 97 LogicalResult 98 matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor, 99 ConversionPatternRewriter &rewriter) const override { 100 auto loc = op.getLoc(); 101 auto type = adaptor.getLhs() 102 .getType() 103 .template cast<ComplexType>() 104 .getElementType(); 105 106 Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs()); 107 Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs()); 108 Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs()); 109 Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs()); 110 Value realComparison = 111 rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs); 112 Value imagComparison = 113 rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs); 114 115 rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison, 116 imagComparison); 117 return success(); 118 } 119 }; 120 121 // Default conversion which applies the BinaryStandardOp separately on the real 122 // and imaginary parts. Can for example be used for complex::AddOp and 123 // complex::SubOp. 124 template <typename BinaryComplexOp, typename BinaryStandardOp> 125 struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> { 126 using OpConversionPattern<BinaryComplexOp>::OpConversionPattern; 127 128 LogicalResult 129 matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor, 130 ConversionPatternRewriter &rewriter) const override { 131 auto type = adaptor.getLhs().getType().template cast<ComplexType>(); 132 auto elementType = type.getElementType().template cast<FloatType>(); 133 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 134 135 Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs()); 136 Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs()); 137 Value resultReal = 138 b.create<BinaryStandardOp>(elementType, realLhs, realRhs); 139 Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs()); 140 Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs()); 141 Value resultImag = 142 b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs); 143 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 144 resultImag); 145 return success(); 146 } 147 }; 148 149 template <typename TrigonometricOp> 150 struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> { 151 using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor; 152 153 using OpConversionPattern<TrigonometricOp>::OpConversionPattern; 154 155 LogicalResult 156 matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor, 157 ConversionPatternRewriter &rewriter) const override { 158 auto loc = op.getLoc(); 159 auto type = adaptor.getComplex().getType().template cast<ComplexType>(); 160 auto elementType = type.getElementType().template cast<FloatType>(); 161 162 Value real = 163 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); 164 Value imag = 165 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); 166 167 // Trigonometric ops use a set of common building blocks to convert to real 168 // ops. Here we create these building blocks and call into an op-specific 169 // implementation in the subclass to combine them. 170 Value half = rewriter.create<arith::ConstantOp>( 171 loc, elementType, rewriter.getFloatAttr(elementType, 0.5)); 172 Value exp = rewriter.create<math::ExpOp>(loc, imag); 173 Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp); 174 Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp); 175 Value sin = rewriter.create<math::SinOp>(loc, real); 176 Value cos = rewriter.create<math::CosOp>(loc, real); 177 178 auto resultPair = 179 combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter); 180 181 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first, 182 resultPair.second); 183 return success(); 184 } 185 186 virtual std::pair<Value, Value> 187 combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin, 188 Value cos, ConversionPatternRewriter &rewriter) const = 0; 189 }; 190 191 struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> { 192 using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion; 193 194 std::pair<Value, Value> 195 combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin, 196 Value cos, ConversionPatternRewriter &rewriter) const override { 197 // Complex cosine is defined as; 198 // cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy))) 199 // Plugging in: 200 // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x)) 201 // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x))) 202 // and defining t := exp(y) 203 // We get: 204 // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x 205 // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x 206 Value sum = rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp); 207 Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos); 208 Value diff = rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp); 209 Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin); 210 return {resultReal, resultImag}; 211 } 212 }; 213 214 struct DivOpConversion : public OpConversionPattern<complex::DivOp> { 215 using OpConversionPattern<complex::DivOp>::OpConversionPattern; 216 217 LogicalResult 218 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor, 219 ConversionPatternRewriter &rewriter) const override { 220 auto loc = op.getLoc(); 221 auto type = adaptor.getLhs().getType().cast<ComplexType>(); 222 auto elementType = type.getElementType().cast<FloatType>(); 223 224 Value lhsReal = 225 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs()); 226 Value lhsImag = 227 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs()); 228 Value rhsReal = 229 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs()); 230 Value rhsImag = 231 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs()); 232 233 // Smith's algorithm to divide complex numbers. It is just a bit smarter 234 // way to compute the following formula: 235 // (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i) 236 // = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) / 237 // ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i)) 238 // = ((lhsReal * rhsReal + lhsImag * rhsImag) + 239 // (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2 240 // 241 // Depending on whether |rhsReal| < |rhsImag| we compute either 242 // rhsRealImagRatio = rhsReal / rhsImag 243 // rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio 244 // resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom 245 // resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom 246 // 247 // or 248 // 249 // rhsImagRealRatio = rhsImag / rhsReal 250 // rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio 251 // resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom 252 // resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom 253 // 254 // See https://dl.acm.org/citation.cfm?id=368661 for more details. 255 Value rhsRealImagRatio = 256 rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag); 257 Value rhsRealImagDenom = rewriter.create<arith::AddFOp>( 258 loc, rhsImag, 259 rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal)); 260 Value realNumerator1 = rewriter.create<arith::AddFOp>( 261 loc, rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio), 262 lhsImag); 263 Value resultReal1 = 264 rewriter.create<arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom); 265 Value imagNumerator1 = rewriter.create<arith::SubFOp>( 266 loc, rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio), 267 lhsReal); 268 Value resultImag1 = 269 rewriter.create<arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom); 270 271 Value rhsImagRealRatio = 272 rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal); 273 Value rhsImagRealDenom = rewriter.create<arith::AddFOp>( 274 loc, rhsReal, 275 rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag)); 276 Value realNumerator2 = rewriter.create<arith::AddFOp>( 277 loc, lhsReal, 278 rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio)); 279 Value resultReal2 = 280 rewriter.create<arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom); 281 Value imagNumerator2 = rewriter.create<arith::SubFOp>( 282 loc, lhsImag, 283 rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio)); 284 Value resultImag2 = 285 rewriter.create<arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom); 286 287 // Consider corner cases. 288 // Case 1. Zero denominator, numerator contains at most one NaN value. 289 Value zero = rewriter.create<arith::ConstantOp>( 290 loc, elementType, rewriter.getZeroAttr(elementType)); 291 Value rhsRealAbs = rewriter.create<math::AbsOp>(loc, rhsReal); 292 Value rhsRealIsZero = rewriter.create<arith::CmpFOp>( 293 loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero); 294 Value rhsImagAbs = rewriter.create<math::AbsOp>(loc, rhsImag); 295 Value rhsImagIsZero = rewriter.create<arith::CmpFOp>( 296 loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero); 297 Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>( 298 loc, arith::CmpFPredicate::ORD, lhsReal, zero); 299 Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>( 300 loc, arith::CmpFPredicate::ORD, lhsImag, zero); 301 Value lhsContainsNotNaNValue = 302 rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); 303 Value resultIsInfinity = rewriter.create<arith::AndIOp>( 304 loc, lhsContainsNotNaNValue, 305 rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero)); 306 Value inf = rewriter.create<arith::ConstantOp>( 307 loc, elementType, 308 rewriter.getFloatAttr( 309 elementType, APFloat::getInf(elementType.getFloatSemantics()))); 310 Value infWithSignOfRhsReal = 311 rewriter.create<math::CopySignOp>(loc, inf, rhsReal); 312 Value infinityResultReal = 313 rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal); 314 Value infinityResultImag = 315 rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag); 316 317 // Case 2. Infinite numerator, finite denominator. 318 Value rhsRealFinite = rewriter.create<arith::CmpFOp>( 319 loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf); 320 Value rhsImagFinite = rewriter.create<arith::CmpFOp>( 321 loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf); 322 Value rhsFinite = 323 rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite); 324 Value lhsRealAbs = rewriter.create<math::AbsOp>(loc, lhsReal); 325 Value lhsRealInfinite = rewriter.create<arith::CmpFOp>( 326 loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf); 327 Value lhsImagAbs = rewriter.create<math::AbsOp>(loc, lhsImag); 328 Value lhsImagInfinite = rewriter.create<arith::CmpFOp>( 329 loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf); 330 Value lhsInfinite = 331 rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite); 332 Value infNumFiniteDenom = 333 rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite); 334 Value one = rewriter.create<arith::ConstantOp>( 335 loc, elementType, rewriter.getFloatAttr(elementType, 1)); 336 Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>( 337 loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero), 338 lhsReal); 339 Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>( 340 loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero), 341 lhsImag); 342 Value lhsRealIsInfWithSignTimesRhsReal = 343 rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal); 344 Value lhsImagIsInfWithSignTimesRhsImag = 345 rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag); 346 Value resultReal3 = rewriter.create<arith::MulFOp>( 347 loc, inf, 348 rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal, 349 lhsImagIsInfWithSignTimesRhsImag)); 350 Value lhsRealIsInfWithSignTimesRhsImag = 351 rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag); 352 Value lhsImagIsInfWithSignTimesRhsReal = 353 rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal); 354 Value resultImag3 = rewriter.create<arith::MulFOp>( 355 loc, inf, 356 rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal, 357 lhsRealIsInfWithSignTimesRhsImag)); 358 359 // Case 3: Finite numerator, infinite denominator. 360 Value lhsRealFinite = rewriter.create<arith::CmpFOp>( 361 loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf); 362 Value lhsImagFinite = rewriter.create<arith::CmpFOp>( 363 loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf); 364 Value lhsFinite = 365 rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite); 366 Value rhsRealInfinite = rewriter.create<arith::CmpFOp>( 367 loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf); 368 Value rhsImagInfinite = rewriter.create<arith::CmpFOp>( 369 loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf); 370 Value rhsInfinite = 371 rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite); 372 Value finiteNumInfiniteDenom = 373 rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite); 374 Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>( 375 loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero), 376 rhsReal); 377 Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>( 378 loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero), 379 rhsImag); 380 Value rhsRealIsInfWithSignTimesLhsReal = 381 rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign); 382 Value rhsImagIsInfWithSignTimesLhsImag = 383 rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign); 384 Value resultReal4 = rewriter.create<arith::MulFOp>( 385 loc, zero, 386 rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal, 387 rhsImagIsInfWithSignTimesLhsImag)); 388 Value rhsRealIsInfWithSignTimesLhsImag = 389 rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign); 390 Value rhsImagIsInfWithSignTimesLhsReal = 391 rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign); 392 Value resultImag4 = rewriter.create<arith::MulFOp>( 393 loc, zero, 394 rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag, 395 rhsImagIsInfWithSignTimesLhsReal)); 396 397 Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>( 398 loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); 399 Value resultReal = rewriter.create<arith::SelectOp>( 400 loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2); 401 Value resultImag = rewriter.create<arith::SelectOp>( 402 loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2); 403 Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>( 404 loc, finiteNumInfiniteDenom, resultReal4, resultReal); 405 Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>( 406 loc, finiteNumInfiniteDenom, resultImag4, resultImag); 407 Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>( 408 loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); 409 Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>( 410 loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); 411 Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>( 412 loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2); 413 Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>( 414 loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); 415 416 Value resultRealIsNaN = rewriter.create<arith::CmpFOp>( 417 loc, arith::CmpFPredicate::UNO, resultReal, zero); 418 Value resultImagIsNaN = rewriter.create<arith::CmpFOp>( 419 loc, arith::CmpFPredicate::UNO, resultImag, zero); 420 Value resultIsNaN = 421 rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN); 422 Value resultRealWithSpecialCases = rewriter.create<arith::SelectOp>( 423 loc, resultIsNaN, resultRealSpecialCase1, resultReal); 424 Value resultImagWithSpecialCases = rewriter.create<arith::SelectOp>( 425 loc, resultIsNaN, resultImagSpecialCase1, resultImag); 426 427 rewriter.replaceOpWithNewOp<complex::CreateOp>( 428 op, type, resultRealWithSpecialCases, resultImagWithSpecialCases); 429 return success(); 430 } 431 }; 432 433 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> { 434 using OpConversionPattern<complex::ExpOp>::OpConversionPattern; 435 436 LogicalResult 437 matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor, 438 ConversionPatternRewriter &rewriter) const override { 439 auto loc = op.getLoc(); 440 auto type = adaptor.getComplex().getType().cast<ComplexType>(); 441 auto elementType = type.getElementType().cast<FloatType>(); 442 443 Value real = 444 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); 445 Value imag = 446 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); 447 Value expReal = rewriter.create<math::ExpOp>(loc, real); 448 Value cosImag = rewriter.create<math::CosOp>(loc, imag); 449 Value resultReal = rewriter.create<arith::MulFOp>(loc, expReal, cosImag); 450 Value sinImag = rewriter.create<math::SinOp>(loc, imag); 451 Value resultImag = rewriter.create<arith::MulFOp>(loc, expReal, sinImag); 452 453 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 454 resultImag); 455 return success(); 456 } 457 }; 458 459 struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> { 460 using OpConversionPattern<complex::Expm1Op>::OpConversionPattern; 461 462 LogicalResult 463 matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor, 464 ConversionPatternRewriter &rewriter) const override { 465 auto type = adaptor.getComplex().getType().cast<ComplexType>(); 466 auto elementType = type.getElementType().cast<FloatType>(); 467 468 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 469 Value exp = b.create<complex::ExpOp>(adaptor.getComplex()); 470 471 Value real = b.create<complex::ReOp>(elementType, exp); 472 Value one = b.create<arith::ConstantOp>(elementType, 473 b.getFloatAttr(elementType, 1)); 474 Value realMinusOne = b.create<arith::SubFOp>(real, one); 475 Value imag = b.create<complex::ImOp>(elementType, exp); 476 477 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne, 478 imag); 479 return success(); 480 } 481 }; 482 483 struct LogOpConversion : public OpConversionPattern<complex::LogOp> { 484 using OpConversionPattern<complex::LogOp>::OpConversionPattern; 485 486 LogicalResult 487 matchAndRewrite(complex::LogOp op, OpAdaptor adaptor, 488 ConversionPatternRewriter &rewriter) const override { 489 auto type = adaptor.getComplex().getType().cast<ComplexType>(); 490 auto elementType = type.getElementType().cast<FloatType>(); 491 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 492 493 Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex()); 494 Value resultReal = b.create<math::LogOp>(elementType, abs); 495 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); 496 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); 497 Value resultImag = b.create<math::Atan2Op>(elementType, imag, real); 498 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 499 resultImag); 500 return success(); 501 } 502 }; 503 504 struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> { 505 using OpConversionPattern<complex::Log1pOp>::OpConversionPattern; 506 507 LogicalResult 508 matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor, 509 ConversionPatternRewriter &rewriter) const override { 510 auto type = adaptor.getComplex().getType().cast<ComplexType>(); 511 auto elementType = type.getElementType().cast<FloatType>(); 512 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 513 514 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); 515 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); 516 Value one = b.create<arith::ConstantOp>(elementType, 517 b.getFloatAttr(elementType, 1)); 518 Value realPlusOne = b.create<arith::AddFOp>(real, one); 519 Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag); 520 rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex); 521 return success(); 522 } 523 }; 524 525 struct MulOpConversion : public OpConversionPattern<complex::MulOp> { 526 using OpConversionPattern<complex::MulOp>::OpConversionPattern; 527 528 LogicalResult 529 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor, 530 ConversionPatternRewriter &rewriter) const override { 531 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 532 auto type = adaptor.getLhs().getType().cast<ComplexType>(); 533 auto elementType = type.getElementType().cast<FloatType>(); 534 535 Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs()); 536 Value lhsRealAbs = b.create<math::AbsOp>(lhsReal); 537 Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs()); 538 Value lhsImagAbs = b.create<math::AbsOp>(lhsImag); 539 Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs()); 540 Value rhsRealAbs = b.create<math::AbsOp>(rhsReal); 541 Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs()); 542 Value rhsImagAbs = b.create<math::AbsOp>(rhsImag); 543 544 Value lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal); 545 Value lhsRealTimesRhsRealAbs = b.create<math::AbsOp>(lhsRealTimesRhsReal); 546 Value lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag); 547 Value lhsImagTimesRhsImagAbs = b.create<math::AbsOp>(lhsImagTimesRhsImag); 548 Value real = 549 b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag); 550 551 Value lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal); 552 Value lhsImagTimesRhsRealAbs = b.create<math::AbsOp>(lhsImagTimesRhsReal); 553 Value lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag); 554 Value lhsRealTimesRhsImagAbs = b.create<math::AbsOp>(lhsRealTimesRhsImag); 555 Value imag = 556 b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag); 557 558 // Handle cases where the "naive" calculation results in NaN values. 559 Value realIsNan = 560 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real); 561 Value imagIsNan = 562 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag); 563 Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan); 564 565 Value inf = b.create<arith::ConstantOp>( 566 elementType, 567 b.getFloatAttr(elementType, 568 APFloat::getInf(elementType.getFloatSemantics()))); 569 570 // Case 1. `lhsReal` or `lhsImag` are infinite. 571 Value lhsRealIsInf = 572 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf); 573 Value lhsImagIsInf = 574 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf); 575 Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf); 576 Value rhsRealIsNan = 577 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal); 578 Value rhsImagIsNan = 579 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag); 580 Value zero = 581 b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); 582 Value one = b.create<arith::ConstantOp>(elementType, 583 b.getFloatAttr(elementType, 1)); 584 Value lhsRealIsInfFloat = 585 b.create<arith::SelectOp>(lhsRealIsInf, one, zero); 586 lhsReal = b.create<arith::SelectOp>( 587 lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal), 588 lhsReal); 589 Value lhsImagIsInfFloat = 590 b.create<arith::SelectOp>(lhsImagIsInf, one, zero); 591 lhsImag = b.create<arith::SelectOp>( 592 lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag), 593 lhsImag); 594 Value lhsIsInfAndRhsRealIsNan = 595 b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan); 596 rhsReal = b.create<arith::SelectOp>( 597 lhsIsInfAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal), 598 rhsReal); 599 Value lhsIsInfAndRhsImagIsNan = 600 b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan); 601 rhsImag = b.create<arith::SelectOp>( 602 lhsIsInfAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag), 603 rhsImag); 604 605 // Case 2. `rhsReal` or `rhsImag` are infinite. 606 Value rhsRealIsInf = 607 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf); 608 Value rhsImagIsInf = 609 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf); 610 Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf); 611 Value lhsRealIsNan = 612 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal); 613 Value lhsImagIsNan = 614 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag); 615 Value rhsRealIsInfFloat = 616 b.create<arith::SelectOp>(rhsRealIsInf, one, zero); 617 rhsReal = b.create<arith::SelectOp>( 618 rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal), 619 rhsReal); 620 Value rhsImagIsInfFloat = 621 b.create<arith::SelectOp>(rhsImagIsInf, one, zero); 622 rhsImag = b.create<arith::SelectOp>( 623 rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag), 624 rhsImag); 625 Value rhsIsInfAndLhsRealIsNan = 626 b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan); 627 lhsReal = b.create<arith::SelectOp>( 628 rhsIsInfAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal), 629 lhsReal); 630 Value rhsIsInfAndLhsImagIsNan = 631 b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan); 632 lhsImag = b.create<arith::SelectOp>( 633 rhsIsInfAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag), 634 lhsImag); 635 Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf); 636 637 // Case 3. One of the pairwise products of left hand side with right hand 638 // side is infinite. 639 Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>( 640 arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf); 641 Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>( 642 arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf); 643 Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf, 644 lhsImagTimesRhsImagIsInf); 645 Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>( 646 arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf); 647 isSpecialCase = 648 b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf); 649 Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>( 650 arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf); 651 isSpecialCase = 652 b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf); 653 Type i1Type = b.getI1Type(); 654 Value notRecalc = b.create<arith::XOrIOp>( 655 recalc, 656 b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1))); 657 isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc); 658 Value isSpecialCaseAndLhsRealIsNan = 659 b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan); 660 lhsReal = b.create<arith::SelectOp>( 661 isSpecialCaseAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal), 662 lhsReal); 663 Value isSpecialCaseAndLhsImagIsNan = 664 b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan); 665 lhsImag = b.create<arith::SelectOp>( 666 isSpecialCaseAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag), 667 lhsImag); 668 Value isSpecialCaseAndRhsRealIsNan = 669 b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan); 670 rhsReal = b.create<arith::SelectOp>( 671 isSpecialCaseAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal), 672 rhsReal); 673 Value isSpecialCaseAndRhsImagIsNan = 674 b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan); 675 rhsImag = b.create<arith::SelectOp>( 676 isSpecialCaseAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag), 677 rhsImag); 678 recalc = b.create<arith::OrIOp>(recalc, isSpecialCase); 679 recalc = b.create<arith::AndIOp>(isNan, recalc); 680 681 // Recalculate real part. 682 lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal); 683 lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag); 684 Value newReal = 685 b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag); 686 real = b.create<arith::SelectOp>( 687 recalc, b.create<arith::MulFOp>(inf, newReal), real); 688 689 // Recalculate imag part. 690 lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal); 691 lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag); 692 Value newImag = 693 b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag); 694 imag = b.create<arith::SelectOp>( 695 recalc, b.create<arith::MulFOp>(inf, newImag), imag); 696 697 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag); 698 return success(); 699 } 700 }; 701 702 struct NegOpConversion : public OpConversionPattern<complex::NegOp> { 703 using OpConversionPattern<complex::NegOp>::OpConversionPattern; 704 705 LogicalResult 706 matchAndRewrite(complex::NegOp op, OpAdaptor adaptor, 707 ConversionPatternRewriter &rewriter) const override { 708 auto loc = op.getLoc(); 709 auto type = adaptor.getComplex().getType().cast<ComplexType>(); 710 auto elementType = type.getElementType().cast<FloatType>(); 711 712 Value real = 713 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); 714 Value imag = 715 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); 716 Value negReal = rewriter.create<arith::NegFOp>(loc, real); 717 Value negImag = rewriter.create<arith::NegFOp>(loc, imag); 718 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag); 719 return success(); 720 } 721 }; 722 723 struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> { 724 using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion; 725 726 std::pair<Value, Value> 727 combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin, 728 Value cos, ConversionPatternRewriter &rewriter) const override { 729 // Complex sine is defined as; 730 // sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy))) 731 // Plugging in: 732 // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x)) 733 // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x))) 734 // and defining t := exp(y) 735 // We get: 736 // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x 737 // Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x 738 Value sum = rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp); 739 Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin); 740 Value diff = rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp); 741 Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos); 742 return {resultReal, resultImag}; 743 } 744 }; 745 746 // The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780. 747 struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> { 748 using OpConversionPattern<complex::SqrtOp>::OpConversionPattern; 749 750 LogicalResult 751 matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor, 752 ConversionPatternRewriter &rewriter) const override { 753 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 754 755 auto type = op.getType().cast<ComplexType>(); 756 Type elementType = type.getElementType(); 757 Value arg = adaptor.getComplex(); 758 759 Value zero = 760 b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); 761 762 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); 763 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); 764 765 Value absLhs = b.create<math::AbsOp>(real); 766 Value absArg = b.create<complex::AbsOp>(elementType, arg); 767 Value addAbs = b.create<arith::AddFOp>(absLhs, absArg); 768 769 Value half = b.create<arith::ConstantOp>( 770 elementType, b.getFloatAttr(elementType, 0.5)); 771 Value halfAddAbs = b.create<arith::MulFOp>(addAbs, half); 772 Value sqrtAddAbs = b.create<math::SqrtOp>(halfAddAbs); 773 774 Value realIsNegative = 775 b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero); 776 Value imagIsNegative = 777 b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, imag, zero); 778 779 Value resultReal = sqrtAddAbs; 780 781 Value imagDivTwoResultReal = b.create<arith::DivFOp>( 782 imag, b.create<arith::AddFOp>(resultReal, resultReal)); 783 784 Value negativeResultReal = b.create<arith::NegFOp>(resultReal); 785 786 Value resultImag = b.create<arith::SelectOp>( 787 realIsNegative, 788 b.create<arith::SelectOp>(imagIsNegative, negativeResultReal, 789 resultReal), 790 imagDivTwoResultReal); 791 792 resultReal = b.create<arith::SelectOp>( 793 realIsNegative, 794 b.create<arith::DivFOp>( 795 imag, b.create<arith::AddFOp>(resultImag, resultImag)), 796 resultReal); 797 798 Value realIsZero = 799 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero); 800 Value imagIsZero = 801 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero); 802 Value argIsZero = b.create<arith::AndIOp>(realIsZero, imagIsZero); 803 804 resultReal = b.create<arith::SelectOp>(argIsZero, zero, resultReal); 805 resultImag = b.create<arith::SelectOp>(argIsZero, zero, resultImag); 806 807 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 808 resultImag); 809 return success(); 810 } 811 }; 812 813 struct SignOpConversion : public OpConversionPattern<complex::SignOp> { 814 using OpConversionPattern<complex::SignOp>::OpConversionPattern; 815 816 LogicalResult 817 matchAndRewrite(complex::SignOp op, OpAdaptor adaptor, 818 ConversionPatternRewriter &rewriter) const override { 819 auto type = adaptor.getComplex().getType().cast<ComplexType>(); 820 auto elementType = type.getElementType().cast<FloatType>(); 821 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 822 823 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); 824 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); 825 Value zero = 826 b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); 827 Value realIsZero = 828 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero); 829 Value imagIsZero = 830 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero); 831 Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero); 832 auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex()); 833 Value realSign = b.create<arith::DivFOp>(real, abs); 834 Value imagSign = b.create<arith::DivFOp>(imag, abs); 835 Value sign = b.create<complex::CreateOp>(type, realSign, imagSign); 836 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero, 837 adaptor.getComplex(), sign); 838 return success(); 839 } 840 }; 841 842 struct TanOpConversion : public OpConversionPattern<complex::TanOp> { 843 using OpConversionPattern<complex::TanOp>::OpConversionPattern; 844 845 LogicalResult 846 matchAndRewrite(complex::TanOp op, OpAdaptor adaptor, 847 ConversionPatternRewriter &rewriter) const override { 848 auto loc = op.getLoc(); 849 Value cos = rewriter.create<complex::CosOp>(loc, adaptor.getComplex()); 850 Value sin = rewriter.create<complex::SinOp>(loc, adaptor.getComplex()); 851 rewriter.replaceOpWithNewOp<complex::DivOp>(op, sin, cos); 852 return success(); 853 } 854 }; 855 856 struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> { 857 using OpConversionPattern<complex::TanhOp>::OpConversionPattern; 858 859 LogicalResult 860 matchAndRewrite(complex::TanhOp op, OpAdaptor adaptor, 861 ConversionPatternRewriter &rewriter) const override { 862 auto loc = op.getLoc(); 863 auto type = adaptor.getComplex().getType().cast<ComplexType>(); 864 auto elementType = type.getElementType().cast<FloatType>(); 865 866 // The hyperbolic tangent for complex number can be calculated as follows. 867 // tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + tanh(x) * tan(y)) 868 // See: https://proofwiki.org/wiki/Hyperbolic_Tangent_of_Complex_Number 869 Value real = 870 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); 871 Value imag = 872 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); 873 Value tanhA = rewriter.create<math::TanhOp>(loc, real); 874 Value cosB = rewriter.create<math::CosOp>(loc, imag); 875 Value sinB = rewriter.create<math::SinOp>(loc, imag); 876 Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB); 877 Value numerator = 878 rewriter.create<complex::CreateOp>(loc, type, tanhA, tanB); 879 Value one = rewriter.create<arith::ConstantOp>( 880 loc, elementType, rewriter.getFloatAttr(elementType, 1)); 881 Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB); 882 Value denominator = rewriter.create<complex::CreateOp>(loc, type, one, mul); 883 rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator); 884 return success(); 885 } 886 }; 887 888 } // namespace 889 890 void mlir::populateComplexToStandardConversionPatterns( 891 RewritePatternSet &patterns) { 892 // clang-format off 893 patterns.add< 894 AbsOpConversion, 895 Atan2OpConversion, 896 ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>, 897 ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>, 898 BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>, 899 BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>, 900 CosOpConversion, 901 DivOpConversion, 902 ExpOpConversion, 903 Expm1OpConversion, 904 LogOpConversion, 905 Log1pOpConversion, 906 MulOpConversion, 907 NegOpConversion, 908 SignOpConversion, 909 SinOpConversion, 910 SqrtOpConversion, 911 TanOpConversion, 912 TanhOpConversion>(patterns.getContext()); 913 // clang-format on 914 } 915 916 namespace { 917 struct ConvertComplexToStandardPass 918 : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> { 919 void runOnOperation() override; 920 }; 921 922 void ConvertComplexToStandardPass::runOnOperation() { 923 // Convert to the Standard dialect using the converter defined above. 924 RewritePatternSet patterns(&getContext()); 925 populateComplexToStandardConversionPatterns(patterns); 926 927 ConversionTarget target(getContext()); 928 target.addLegalDialect<arith::ArithmeticDialect, math::MathDialect>(); 929 target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>(); 930 if (failed( 931 applyPartialConversion(getOperation(), target, std::move(patterns)))) 932 signalPassFailure(); 933 } 934 } // namespace 935 936 std::unique_ptr<Pass> mlir::createConvertComplexToStandardPass() { 937 return std::make_unique<ConvertComplexToStandardPass>(); 938 } 939