1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" 10 11 #include <memory> 12 #include <type_traits> 13 14 #include "../PassDetail.h" 15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 16 #include "mlir/Dialect/Complex/IR/Complex.h" 17 #include "mlir/Dialect/Math/IR/Math.h" 18 #include "mlir/IR/ImplicitLocOpBuilder.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 22 using namespace mlir; 23 24 namespace { 25 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> { 26 using OpConversionPattern<complex::AbsOp>::OpConversionPattern; 27 28 LogicalResult 29 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, 30 ConversionPatternRewriter &rewriter) const override { 31 auto loc = op.getLoc(); 32 auto type = op.getType(); 33 34 Value real = 35 rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex()); 36 Value imag = 37 rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex()); 38 Value realSqr = rewriter.create<arith::MulFOp>(loc, real, real); 39 Value imagSqr = rewriter.create<arith::MulFOp>(loc, imag, imag); 40 Value sqNorm = rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr); 41 42 rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm); 43 return success(); 44 } 45 }; 46 47 template <typename ComparisonOp, arith::CmpFPredicate p> 48 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> { 49 using OpConversionPattern<ComparisonOp>::OpConversionPattern; 50 using ResultCombiner = 51 std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value, 52 arith::AndIOp, arith::OrIOp>; 
53 54 LogicalResult 55 matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor, 56 ConversionPatternRewriter &rewriter) const override { 57 auto loc = op.getLoc(); 58 auto type = adaptor.getLhs() 59 .getType() 60 .template cast<ComplexType>() 61 .getElementType(); 62 63 Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs()); 64 Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs()); 65 Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs()); 66 Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs()); 67 Value realComparison = 68 rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs); 69 Value imagComparison = 70 rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs); 71 72 rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison, 73 imagComparison); 74 return success(); 75 } 76 }; 77 78 // Default conversion which applies the BinaryStandardOp separately on the real 79 // and imaginary parts. Can for example be used for complex::AddOp and 80 // complex::SubOp. 
template <typename BinaryComplexOp, typename BinaryStandardOp>
struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
  using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The lhs complex type is reused as the result type of complex.create.
    auto type = adaptor.getLhs().getType().template cast<ComplexType>();
    auto elementType = type.getElementType().template cast<FloatType>();
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
    Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
    Value resultReal =
        b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
    Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
    Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
    Value resultImag =
        b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                   resultImag);
    return success();
  }
};

/// Shared lowering for complex trigonometric ops (used by complex.cos and
/// complex.sin below). For an input x + iy it materializes
///   sin(x), cos(x), scaledExp = 0.5 * exp(y), reciprocalExp = 0.5 / exp(y)
/// and delegates to the subclass's `combine` to turn these into the
/// (real, imaginary) pair of the result.
template <typename TrigonometricOp>
struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
  using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;

  using OpConversionPattern<TrigonometricOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = adaptor.getComplex().getType().template cast<ComplexType>();
    auto elementType = type.getElementType().template cast<FloatType>();

    Value real =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
    Value imag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());

    // Trigonometric ops use a set of common building blocks to convert to real
    // ops. Here we create these building blocks and call into an op-specific
    // implementation in the subclass to combine them.
    Value half = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
    Value exp = rewriter.create<math::ExpOp>(loc, imag);
    Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp);
    Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp);
    Value sin = rewriter.create<math::SinOp>(loc, real);
    Value cos = rewriter.create<math::CosOp>(loc, real);

    auto resultPair =
        combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
                                                   resultPair.second);
    return success();
  }

  /// Builds the (real, imaginary) parts of the result from the shared building
  /// blocks: scaledExp = 0.5*exp(y), reciprocalExp = 0.5/exp(y), sin = sin(x),
  /// cos = cos(x), where the input is x + iy.
  virtual std::pair<Value, Value>
  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
          Value cos, ConversionPatternRewriter &rewriter) const = 0;
};

/// Lowers complex.cos using the shared trigonometric building blocks.
struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
  using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;

  std::pair<Value, Value>
  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
          Value cos, ConversionPatternRewriter &rewriter) const override {
    // Complex cosine is defined as:
    //   cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
    // Plugging in:
    //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
    //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
    // and defining t := exp(y)
    // We get:
    //   Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
    //   Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
    Value sum = rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp);
    Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos);
    Value diff = rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp);
    Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin);
    return {resultReal, resultImag};
  }
};

/// Lowers complex.div with Smith's algorithm and then patches the result for
/// zero / infinite operands. The special cases are only consulted when both
/// components of the Smith result are NaN. Their structure (zero denominator,
/// infinite numerator with finite denominator, finite numerator with infinite
/// denominator) mirrors the recovery steps of the sample `_Cdivd` in C11
/// Annex G. NOTE(review): keep the select order below in sync with that
/// priority if the cases are modified.
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
  using OpConversionPattern<complex::DivOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = adaptor.getLhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value lhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
    Value lhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs());
    Value rhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs());
    Value rhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs());

    // Smith's algorithm to divide complex numbers. It is just a bit smarter
    // way to compute the following formula:
    //  (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
    //    = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
    //          ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
    //    = ((lhsReal * rhsReal + lhsImag * rhsImag) +
    //          (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
    //
    // Depending on whether |rhsReal| < |rhsImag| we compute either
    //   rhsRealImagRatio = rhsReal / rhsImag
    //   rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
    //   resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
    //   resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
    //
    // or
    //
    //   rhsImagRealRatio = rhsImag / rhsReal
    //   rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
    //   resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
    //   resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
    //
    // See https://dl.acm.org/citation.cfm?id=368661 for more details.
    // Both variants are emitted unconditionally; a select on
    // |rhsReal| < |rhsImag| further below picks one of them.
    Value rhsRealImagRatio =
        rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag);
    Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
        loc, rhsImag,
        rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal));
    Value realNumerator1 = rewriter.create<arith::AddFOp>(
        loc, rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio),
        lhsImag);
    Value resultReal1 =
        rewriter.create<arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom);
    Value imagNumerator1 = rewriter.create<arith::SubFOp>(
        loc, rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio),
        lhsReal);
    Value resultImag1 =
        rewriter.create<arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom);

    Value rhsImagRealRatio =
        rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal);
    Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
        loc, rhsReal,
        rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag));
    Value realNumerator2 = rewriter.create<arith::AddFOp>(
        loc, lhsReal,
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio));
    Value resultReal2 =
        rewriter.create<arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom);
    Value imagNumerator2 = rewriter.create<arith::SubFOp>(
        loc, lhsImag,
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio));
    Value resultImag2 =
        rewriter.create<arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom);

    // Consider corner cases.
    // Case 1. Zero denominator, numerator contains at most one NaN value.
    // The result is copysign(inf, rhsReal) times each numerator component.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getZeroAttr(elementType));
    Value rhsRealAbs = rewriter.create<math::AbsOp>(loc, rhsReal);
    Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
    Value rhsImagAbs = rewriter.create<math::AbsOp>(loc, rhsImag);
    Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
    // ORD(x, 0) is true iff x is not NaN (the constant zero is never NaN).
    Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsReal, zero);
    Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsImag, zero);
    Value lhsContainsNotNaNValue =
        rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
    Value resultIsInfinity = rewriter.create<arith::AndIOp>(
        loc, lhsContainsNotNaNValue,
        rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
    Value inf = rewriter.create<arith::ConstantOp>(
        loc, elementType,
        rewriter.getFloatAttr(
            elementType, APFloat::getInf(elementType.getFloatSemantics())));
    Value infWithSignOfRhsReal =
        rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
    Value infinityResultReal =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
    Value infinityResultImag =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag);

    // Case 2. Infinite numerator, finite denominator.
    // ONE(|x|, inf) is true iff x is finite: it is false for +-inf and for NaN.
    Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
    Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
    Value rhsFinite =
        rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
    Value lhsRealAbs = rewriter.create<math::AbsOp>(loc, lhsReal);
    Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagAbs = rewriter.create<math::AbsOp>(loc, lhsImag);
    Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsInfinite =
        rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
    Value infNumFiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
    Value one = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getFloatAttr(elementType, 1));
    // Each numerator component becomes +-1 if it is infinite, +-0 otherwise,
    // keeping its original sign.
    Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
        lhsReal);
    Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
        lhsImag);
    Value lhsRealIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
    Value lhsImagIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
    Value resultReal3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
                                       lhsImagIsInfWithSignTimesRhsImag));
    Value lhsRealIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
    Value lhsImagIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
    Value resultImag3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
                                       lhsRealIsInfWithSignTimesRhsImag));

    // Case 3: Finite numerator, infinite denominator.
    // Mirrors case 2 with the roles swapped and 0 in place of inf as the scale.
    Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
    Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
    Value lhsFinite =
        rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
    Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsInfinite =
        rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
    Value finiteNumInfiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
    Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
        rhsReal);
    Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
        rhsImag);
    Value rhsRealIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
    Value resultReal4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
                                       rhsImagIsInfWithSignTimesLhsImag));
    Value rhsRealIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
    Value resultImag4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
                                       rhsImagIsInfWithSignTimesLhsReal));

    // Pick the Smith variant, then layer the special cases with priority
    // case 1 > case 2 > case 3 (the outermost select wins).
    Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
    Value resultReal = rewriter.create<arith::SelectOp>(
        loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
    Value resultImag = rewriter.create<arith::SelectOp>(
        loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
    Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>(
        loc, finiteNumInfiniteDenom, resultReal4, resultReal);
    Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>(
        loc, finiteNumInfiniteDenom, resultImag4, resultImag);
    Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>(
        loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
    Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>(
        loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
    Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>(
        loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
    Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>(
        loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);

    // The special-case result is used only if the Smith result is NaN in
    // both components.
    Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultReal, zero);
    Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultImag, zero);
    Value resultIsNaN =
        rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
    Value resultRealWithSpecialCases = rewriter.create<arith::SelectOp>(
        loc, resultIsNaN, resultRealSpecialCase1, resultReal);
    Value resultImagWithSpecialCases = rewriter.create<arith::SelectOp>(
        loc, resultIsNaN, resultImagSpecialCase1, resultImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(
        op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
    return success();
  }
};

/// Lowers complex.exp: exp(x + iy) = exp(x) * cos(y) + i * exp(x) * sin(y).
struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
  using OpConversionPattern<complex::ExpOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = adaptor.getComplex().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value real =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
    Value imag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
    Value expReal = rewriter.create<math::ExpOp>(loc, real);
    Value cosImag = rewriter.create<math::CosOp>(loc, imag);
    Value resultReal = rewriter.create<arith::MulFOp>(loc, expReal, cosImag);
    Value sinImag = rewriter.create<math::SinOp>(loc, imag);
    Value resultImag = rewriter.create<arith::MulFOp>(loc, expReal, sinImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                   resultImag);
    return success();
  }
};

/// Lowers complex.expm1 as exp(z) - 1. It emits a complex.exp (which the
/// ExpOpConversion pattern registered in this file lowers further) and
/// subtracts 1 from its real part.
/// NOTE(review): subtracting 1 after exp cancels catastrophically for z near
/// 0, so this is less accurate than a dedicated expm1 for small inputs.
struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
  using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = adaptor.getComplex().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Value exp = b.create<complex::ExpOp>(adaptor.getComplex());

    Value real = b.create<complex::ReOp>(elementType, exp);
    Value one = b.create<arith::ConstantOp>(elementType,
                                            b.getFloatAttr(elementType, 1));
    Value realMinusOne = b.create<arith::SubFOp>(real, one);
    Value imag = b.create<complex::ImOp>(elementType, exp);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne,
                                                   imag);
    return success();
  }
};

/// Lowers complex.log: log(z) = log(|z|) + i * atan2(Im z, Re z).
/// |z| is emitted as complex.abs and lowered by AbsOpConversion.
struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
  using OpConversionPattern<complex::LogOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = adaptor.getComplex().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
    Value resultReal = b.create<math::LogOp>(elementType, abs);
    Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
    Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
    Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                   resultImag);
    return success();
  }
};

/// Lowers complex.log1p as complex.log((Re z + 1) + i * Im z); the emitted
/// complex.log is lowered by LogOpConversion.
/// NOTE(review): forming 1 + z explicitly loses precision for |z| << 1
/// compared with a dedicated log1p.
struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
  using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = adaptor.getComplex().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
    Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
    Value one = b.create<arith::ConstantOp>(elementType,
                                            b.getFloatAttr(elementType, 1));
    Value realPlusOne = b.create<arith::AddFOp>(real, one);
    Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag);
    rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex);
    return success();
  }
};

/// Lowers complex.mul. It first computes the naive product
///   (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
/// Only if both naive components are NaN does it recompute the product with
/// infinities normalized to +-1 and NaNs to +-0, scaled by inf. This mirrors
/// the recovery logic of the sample `_Cmultd` in C11 Annex G.
struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
  using OpConversionPattern<complex::MulOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto type = adaptor.getLhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
    Value lhsRealAbs = b.create<math::AbsOp>(lhsReal);
    Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
    Value lhsImagAbs = b.create<math::AbsOp>(lhsImag);
    Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
    Value rhsRealAbs = b.create<math::AbsOp>(rhsReal);
    Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
    Value rhsImagAbs = b.create<math::AbsOp>(rhsImag);

    Value lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
    Value lhsRealTimesRhsRealAbs = b.create<math::AbsOp>(lhsRealTimesRhsReal);
    Value lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
    Value lhsImagTimesRhsImagAbs = b.create<math::AbsOp>(lhsImagTimesRhsImag);
    Value real =
        b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);

    Value lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
    Value lhsImagTimesRhsRealAbs = b.create<math::AbsOp>(lhsImagTimesRhsReal);
    Value lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
    Value lhsRealTimesRhsImagAbs = b.create<math::AbsOp>(lhsRealTimesRhsImag);
    Value imag =
        b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);

    // Handle cases where the "naive" calculation results in NaN values.
    Value realIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
    Value imagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
    Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan);

    Value inf = b.create<arith::ConstantOp>(
        elementType,
        b.getFloatAttr(elementType,
                       APFloat::getInf(elementType.getFloatSemantics())));

    // Case 1. `lhsReal` or `lhsImag` are infinite.
    // lhs parts become copysign(isinf ? 1 : 0, part); NaN rhs parts become
    // copysign(0, part). Later cases read the updated lhsReal/lhsImag/rhsReal/
    // rhsImag values, so statement order matters here.
    Value lhsRealIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
    Value rhsRealIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
    Value rhsImagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
    Value zero =
        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
    Value one = b.create<arith::ConstantOp>(elementType,
                                            b.getFloatAttr(elementType, 1));
    Value lhsRealIsInfFloat =
        b.create<arith::SelectOp>(lhsRealIsInf, one, zero);
    lhsReal = b.create<arith::SelectOp>(
        lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
        lhsReal);
    Value lhsImagIsInfFloat =
        b.create<arith::SelectOp>(lhsImagIsInf, one, zero);
    lhsImag = b.create<arith::SelectOp>(
        lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
        lhsImag);
    Value lhsIsInfAndRhsRealIsNan =
        b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
    rhsReal = b.create<arith::SelectOp>(
        lhsIsInfAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
        rhsReal);
    Value lhsIsInfAndRhsImagIsNan =
        b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
    rhsImag = b.create<arith::SelectOp>(
        lhsIsInfAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
        rhsImag);

    // Case 2. `rhsReal` or `rhsImag` are infinite.
    // Symmetric to case 1 with lhs and rhs swapped.
    Value rhsRealIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
    Value lhsRealIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
    Value lhsImagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
    Value rhsRealIsInfFloat =
        b.create<arith::SelectOp>(rhsRealIsInf, one, zero);
    rhsReal = b.create<arith::SelectOp>(
        rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
        rhsReal);
    Value rhsImagIsInfFloat =
        b.create<arith::SelectOp>(rhsImagIsInf, one, zero);
    rhsImag = b.create<arith::SelectOp>(
        rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
        rhsImag);
    Value rhsIsInfAndLhsRealIsNan =
        b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
    lhsReal = b.create<arith::SelectOp>(
        rhsIsInfAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
        lhsReal);
    Value rhsIsInfAndLhsImagIsNan =
        b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
    lhsImag = b.create<arith::SelectOp>(
        rhsIsInfAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
        lhsImag);
    Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf);

    // Case 3. One of the pairwise products of left hand side with right hand
    // side is infinite.
    // Only applies when neither case 1 nor case 2 did (`notRecalc`); NaN
    // operand parts are then replaced by copysign(0, part).
    Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
    Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
    Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
                                                 lhsImagTimesRhsImagIsInf);
    Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
    isSpecialCase =
        b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
    Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
    isSpecialCase =
        b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
    Type i1Type = b.getI1Type();
    // Logical NOT of `recalc`, expressed as XOR with i1 true.
    Value notRecalc = b.create<arith::XOrIOp>(
        recalc,
        b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
    isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc);
    Value isSpecialCaseAndLhsRealIsNan =
        b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
    lhsReal = b.create<arith::SelectOp>(
        isSpecialCaseAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
        lhsReal);
    Value isSpecialCaseAndLhsImagIsNan =
        b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
    lhsImag = b.create<arith::SelectOp>(
        isSpecialCaseAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
        lhsImag);
    Value isSpecialCaseAndRhsRealIsNan =
        b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
    rhsReal = b.create<arith::SelectOp>(
        isSpecialCaseAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
        rhsReal);
    Value isSpecialCaseAndRhsImagIsNan =
        b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
    rhsImag = b.create<arith::SelectOp>(
        isSpecialCaseAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
        rhsImag);
    recalc = b.create<arith::OrIOp>(recalc, isSpecialCase);
    // The recomputed value is used only when the naive result was NaN in both
    // components.
    recalc = b.create<arith::AndIOp>(isNan, recalc);

    // Recalculate real part.
    lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
    lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
    Value newReal =
        b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
    real = b.create<arith::SelectOp>(
        recalc, b.create<arith::MulFOp>(inf, newReal), real);

    // Recalculate imag part.
    lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
    lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
    Value newImag =
        b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
    imag = b.create<arith::SelectOp>(
        recalc, b.create<arith::MulFOp>(inf, newImag), imag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
    return success();
  }
};

/// Lowers complex.neg by negating the real and imaginary parts independently.
struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
  using OpConversionPattern<complex::NegOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = adaptor.getComplex().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value real =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
    Value imag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
    Value negReal = rewriter.create<arith::NegFOp>(loc, real);
    Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
    return success();
  }
};

/// Lowers complex.sin using the shared trigonometric building blocks.
struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
  using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;

  std::pair<Value, Value>
  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
          Value cos, ConversionPatternRewriter &rewriter) const override {
    // Complex sine is defined as:
    //   sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
    // Plugging in:
    //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
    //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
    // and defining t := exp(y)
    // We get:
    //   Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
    //   Im(sin(x + iy)) = (0.5*t - 0.5/t) * cos x
    Value sum = rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp);
    Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin);
    Value diff = rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp);
    Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos);
    return {resultReal, resultImag};
  }
};

/// Lowers complex.sign as z / |z|. When both parts of z compare equal (OEQ) to
/// zero, the input is returned unchanged instead of dividing by zero.
struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
  using OpConversionPattern<complex::SignOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = adaptor.getComplex().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
    Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
    Value zero =
        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
    Value realIsZero =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
    Value imagIsZero =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
    Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
    auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
    Value realSign = b.create<arith::DivFOp>(real, abs);
    Value imagSign = b.create<arith::DivFOp>(imag, abs);
    Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
                                                 adaptor.getComplex(), sign);
    return success();
  }
};

/// Lowers complex.tan as complex.sin(z) / complex.cos(z). Each emitted complex
/// op is lowered by its own pattern in this file.
struct TanOpConversion : public OpConversionPattern<complex::TanOp> {
  using OpConversionPattern<complex::TanOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::TanOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    Value cos = rewriter.create<complex::CosOp>(loc, adaptor.getComplex());
    Value sin = rewriter.create<complex::SinOp>(loc, adaptor.getComplex());
    rewriter.replaceOpWithNewOp<complex::DivOp>(op, sin, cos);
    return success();
  }
};

/// Lowers complex.tanh via the addition formula for tanh. The final division
/// is emitted as complex.div and lowered by DivOpConversion.
struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
  using OpConversionPattern<complex::TanhOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::TanhOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = adaptor.getComplex().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    // The hyperbolic tangent for complex number can be calculated as follows.
    // tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + i * tanh(x) * tan(y))
    // (using tanh(i * y) = i * tan(y)); hence the denominator below is built
    // as the complex number (1, tanh(x) * tan(y)).
    // See: https://proofwiki.org/wiki/Hyperbolic_Tangent_of_Complex_Number
    Value real =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
    Value imag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
    Value tanhA = rewriter.create<math::TanhOp>(loc, real);
    Value cosB = rewriter.create<math::CosOp>(loc, imag);
    Value sinB = rewriter.create<math::SinOp>(loc, imag);
    Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB);
    Value numerator =
        rewriter.create<complex::CreateOp>(loc, type, tanhA, tanB);
    Value one = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getFloatAttr(elementType, 1));
    Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB);
    Value denominator = rewriter.create<complex::CreateOp>(loc, type, one, mul);
    rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator);
    return success();
  }
};

} // namespace

/// Adds every Complex-to-Standard lowering pattern defined in this file to
/// `patterns`.
void mlir::populateComplexToStandardConversionPatterns(
    RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AbsOpConversion,
      ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
      ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
      BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
      BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
      CosOpConversion,
      DivOpConversion,
      ExpOpConversion,
      Expm1OpConversion,
      LogOpConversion,
      Log1pOpConversion,
      MulOpConversion,
      NegOpConversion,
      SignOpConversion,
      SinOpConversion,
      TanOpConversion,
      TanhOpConversion>(patterns.getContext());
  // clang-format on
}

namespace {
/// Pass wrapper for the Complex-to-Standard conversion; `runOnOperation` is
/// declared here and defined out of line.
struct ConvertComplexToStandardPass
    : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
  void runOnOperation() override;
};

void
ConvertComplexToStandardPass::runOnOperation() { 811 // Convert to the Standard dialect using the converter defined above. 812 RewritePatternSet patterns(&getContext()); 813 populateComplexToStandardConversionPatterns(patterns); 814 815 ConversionTarget target(getContext()); 816 target.addLegalDialect<arith::ArithmeticDialect, math::MathDialect>(); 817 target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>(); 818 if (failed( 819 applyPartialConversion(getOperation(), target, std::move(patterns)))) 820 signalPassFailure(); 821 } 822 } // namespace 823 824 std::unique_ptr<Pass> mlir::createConvertComplexToStandardPass() { 825 return std::make_unique<ConvertComplexToStandardPass>(); 826 } 827