1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" 10 11 #include <memory> 12 #include <type_traits> 13 14 #include "../PassDetail.h" 15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 16 #include "mlir/Dialect/Complex/IR/Complex.h" 17 #include "mlir/Dialect/Math/IR/Math.h" 18 #include "mlir/IR/ImplicitLocOpBuilder.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 22 using namespace mlir; 23 24 namespace { 25 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> { 26 using OpConversionPattern<complex::AbsOp>::OpConversionPattern; 27 28 LogicalResult 29 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, 30 ConversionPatternRewriter &rewriter) const override { 31 auto loc = op.getLoc(); 32 auto type = op.getType(); 33 34 Value real = 35 rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex()); 36 Value imag = 37 rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex()); 38 Value realSqr = rewriter.create<arith::MulFOp>(loc, real, real); 39 Value imagSqr = rewriter.create<arith::MulFOp>(loc, imag, imag); 40 Value sqNorm = rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr); 41 42 rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm); 43 return success(); 44 } 45 }; 46 47 template <typename ComparisonOp, arith::CmpFPredicate p> 48 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> { 49 using OpConversionPattern<ComparisonOp>::OpConversionPattern; 50 using ResultCombiner = 51 std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value, 52 arith::AndIOp, arith::OrIOp>; 
53 54 LogicalResult 55 matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor, 56 ConversionPatternRewriter &rewriter) const override { 57 auto loc = op.getLoc(); 58 auto type = adaptor.getLhs() 59 .getType() 60 .template cast<ComplexType>() 61 .getElementType(); 62 63 Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs()); 64 Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs()); 65 Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs()); 66 Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs()); 67 Value realComparison = 68 rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs); 69 Value imagComparison = 70 rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs); 71 72 rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison, 73 imagComparison); 74 return success(); 75 } 76 }; 77 78 // Default conversion which applies the BinaryStandardOp separately on the real 79 // and imaginary parts. Can for example be used for complex::AddOp and 80 // complex::SubOp. 
template <typename BinaryComplexOp, typename BinaryStandardOp>
struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
  using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = adaptor.getLhs().getType().template cast<ComplexType>();
    auto elementType = type.getElementType().template cast<FloatType>();
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    // Both halves use the same element-wise op; the results are recombined
    // with complex.create below.
    Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
    Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
    Value resultReal =
        b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
    Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
    Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
    Value resultImag =
        b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                   resultImag);
    return success();
  }
};

/// Lowers `complex.div`.
///
/// The quotient is computed with Smith's algorithm (both branches are emitted
/// and one is chosen with arith.select). When both parts of that quotient are
/// NaN, three special cases are tried in priority order to recover a
/// meaningful infinity or zero. The cases are modeled on the recovery rules
/// of the C standard's Annex G reference complex division:
///   1. zero divisor, dividend has at least one non-NaN part -> infinity;
///   2. infinite dividend, finite divisor -> infinity;
///   3. finite dividend, infinite divisor -> zero.
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
  using OpConversionPattern<complex::DivOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = adaptor.getLhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value lhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
    Value lhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs());
    Value rhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs());
    Value rhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs());

    // Smith's algorithm to divide complex numbers. It is a numerically safer
    // way to compute the following formula, because it never squares the
    // divisor's components:
    //  (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
    //    = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
    //          ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
    //    = ((lhsReal * rhsReal + lhsImag * rhsImag) +
    //          (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
    //
    // Depending on whether |rhsReal| < |rhsImag| we compute either
    //   rhsRealImagRatio = rhsReal / rhsImag
    //   rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
    //   resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
    //   resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
    //
    // or
    //
    //   rhsImagRealRatio = rhsImag / rhsReal
    //   rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
    //   resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
    //   resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
    //
    // Both variants are materialized below; the choice is made later with
    // arith.select on |rhsReal| < |rhsImag|.
    //
    // See https://dl.acm.org/citation.cfm?id=368661 for more details.
    Value rhsRealImagRatio =
        rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag);
    Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
        loc, rhsImag,
        rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal));
    Value realNumerator1 = rewriter.create<arith::AddFOp>(
        loc, rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio),
        lhsImag);
    Value resultReal1 =
        rewriter.create<arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom);
    Value imagNumerator1 = rewriter.create<arith::SubFOp>(
        loc, rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio),
        lhsReal);
    Value resultImag1 =
        rewriter.create<arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom);

    Value rhsImagRealRatio =
        rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal);
    Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
        loc, rhsReal,
        rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag));
    Value realNumerator2 = rewriter.create<arith::AddFOp>(
        loc, lhsReal,
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio));
    Value resultReal2 =
        rewriter.create<arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom);
    Value imagNumerator2 = rewriter.create<arith::SubFOp>(
        loc, lhsImag,
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio));
    Value resultImag2 =
        rewriter.create<arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom);

    // Consider corner cases. Each case computes its own candidate result; the
    // candidates are only used when the Smith result is NaN in both parts
    // (see the selects at the end).
    // Case 1. Zero denominator, numerator contains at most one NaN value.
    // Result: copysign(inf, rhsReal) * lhs, component-wise.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getZeroAttr(elementType));
    Value rhsRealAbs = rewriter.create<math::AbsOp>(loc, rhsReal);
    Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
    Value rhsImagAbs = rewriter.create<math::AbsOp>(loc, rhsImag);
    Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
    // ORD against the (non-NaN) constant zero is true iff the value is not
    // NaN.
    Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsReal, zero);
    Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsImag, zero);
    Value lhsContainsNotNaNValue =
        rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
    Value resultIsInfinity = rewriter.create<arith::AndIOp>(
        loc, lhsContainsNotNaNValue,
        rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
    Value inf = rewriter.create<arith::ConstantOp>(
        loc, elementType,
        rewriter.getFloatAttr(
            elementType, APFloat::getInf(elementType.getFloatSemantics())));
    Value infWithSignOfRhsReal =
        rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
    Value infinityResultReal =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
    Value infinityResultImag =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag);

    // Case 2. Infinite numerator, finite denominator.
    // Infinite lhs components are replaced by +-1 (finite ones by +-0) and the
    // textbook numerator is scaled by inf.
    // ONE (ordered and not equal) is false for NaN, so `|x| ONE inf` acts as
    // an isfinite(x) test.
    Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
    Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
    Value rhsFinite =
        rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
    Value lhsRealAbs = rewriter.create<math::AbsOp>(loc, lhsReal);
    Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagAbs = rewriter.create<math::AbsOp>(loc, lhsImag);
    Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsInfinite =
        rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
    Value infNumFiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
    Value one = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getFloatAttr(elementType, 1));
    Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
        lhsReal);
    Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
        lhsImag);
    Value lhsRealIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
    Value lhsImagIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
    Value resultReal3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
                                       lhsImagIsInfWithSignTimesRhsImag));
    Value lhsRealIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
    Value lhsImagIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
    Value resultImag3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
                                       lhsRealIsInfWithSignTimesRhsImag));

    // Case 3: Finite numerator, infinite denominator.
    // Infinite rhs components are replaced by +-1 (finite ones by +-0) and the
    // textbook numerator is scaled by 0, producing a correctly signed zero.
    Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
    Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
    Value lhsFinite =
        rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
    Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsInfinite =
        rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
    Value finiteNumInfiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
    Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
        rhsReal);
    Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
        rhsImag);
    Value rhsRealIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
    Value resultReal4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
                                       rhsImagIsInfWithSignTimesLhsImag));
    Value rhsRealIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
    Value resultImag4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
                                       rhsImagIsInfWithSignTimesLhsReal));

    // Pick the Smith branch: variant 1 when |rhsReal| < |rhsImag|, otherwise
    // variant 2 (also taken when the comparison is unordered).
    Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
    Value resultReal = rewriter.create<arith::SelectOp>(
        loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
    Value resultImag = rewriter.create<arith::SelectOp>(
        loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
    // Nested selects give the special cases priority 1 > 2 > 3 > Smith result.
    Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>(
        loc, finiteNumInfiniteDenom, resultReal4, resultReal);
    Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>(
        loc, finiteNumInfiniteDenom, resultImag4, resultImag);
    Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>(
        loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
    Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>(
        loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
    Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>(
        loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
    Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>(
        loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);

    // The special-case results are only used when both Smith parts are NaN
    // (UNO against the constant zero is an isnan() test).
    Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultReal, zero);
    Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultImag, zero);
    Value resultIsNaN =
        rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
    Value resultRealWithSpecialCases = rewriter.create<arith::SelectOp>(
        loc, resultIsNaN, resultRealSpecialCase1, resultReal);
    Value resultImagWithSpecialCases = rewriter.create<arith::SelectOp>(
        loc, resultIsNaN, resultImagSpecialCase1, resultImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(
        op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
    return success();
  }
};

/// Lowers `complex.exp` using exp(a + bi) = e^a * (cos(b) + i * sin(b)).
struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
  using OpConversionPattern<complex::ExpOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = adaptor.getComplex().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value real =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
    Value imag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
    Value expReal = rewriter.create<math::ExpOp>(loc, real);
    Value cosImag = rewriter.create<math::CosOp>(loc, imag);
    Value resultReal = rewriter.create<arith::MulFOp>(loc, expReal, cosImag);
    Value sinImag = rewriter.create<math::SinOp>(loc, imag);
    Value resultImag = rewriter.create<arith::MulFOp>(loc, expReal, sinImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                   resultImag);
    return success();
  }
};

/// Lowers `complex.log` using log(z) = log(|z|) + i * atan2(im(z), re(z)).
///
/// The complex.abs created here is not a legal op for the conversion target,
/// so it is itself lowered by AbsOpConversion during the same conversion.
struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
  using OpConversionPattern<complex::LogOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = adaptor.getComplex().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
    Value resultReal = b.create<math::LogOp>(elementType, abs);
    Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
    Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
    Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                   resultImag);
    return success();
  }
};

/// Lowers `complex.log1p` as complex.log((re(z) + 1) + i * im(z)). The
/// resulting complex.log is lowered in turn by LogOpConversion.
///
/// NOTE(review): forming `re + 1` explicitly rounds away the low-order bits
/// of small |z|, which is the precision log1p exists to preserve.
struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
  using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = adaptor.getComplex().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
    Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
    Value one = b.create<arith::ConstantOp>(elementType,
                                            b.getFloatAttr(elementType, 1));
    Value realPlusOne = b.create<arith::AddFOp>(real, one);
    Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag);
    rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex);
    return success();
  }
};

/// Lowers `complex.mul`.
///
/// The product is first computed with the textbook formula
///   (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
/// If both parts of that result are NaN, the operands are rewritten and the
/// product recomputed scaled by inf. Infinite components become +-1, other
/// components of an operand with an infinity become +-0, and NaN partners
/// become +-0. The recomputation is modeled on the recovery steps of the C
/// standard's Annex G reference complex multiplication.
///
/// The three cases below are order-dependent. Each later case reads the
/// lhs/rhs values as already rewritten by the earlier ones, mirroring the
/// sequential semantics of the reference algorithm.
struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
  using OpConversionPattern<complex::MulOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto type = adaptor.getLhs().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
    Value lhsRealAbs = b.create<math::AbsOp>(lhsReal);
    Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
    Value lhsImagAbs = b.create<math::AbsOp>(lhsImag);
    Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
    Value rhsRealAbs = b.create<math::AbsOp>(rhsReal);
    Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
    Value rhsImagAbs = b.create<math::AbsOp>(rhsImag);

    Value lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
    Value lhsRealTimesRhsRealAbs = b.create<math::AbsOp>(lhsRealTimesRhsReal);
    Value lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
    Value lhsImagTimesRhsImagAbs = b.create<math::AbsOp>(lhsImagTimesRhsImag);
    Value real =
        b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);

    Value lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
    Value lhsImagTimesRhsRealAbs = b.create<math::AbsOp>(lhsImagTimesRhsReal);
    Value lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
    Value lhsRealTimesRhsImagAbs = b.create<math::AbsOp>(lhsRealTimesRhsImag);
    Value imag =
        b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);

    // Handle cases where the "naive" calculation results in NaN values.
    // `x UNO x` is true iff x is NaN.
    Value realIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
    Value imagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
    Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan);

    Value inf = b.create<arith::ConstantOp>(
        elementType,
        b.getFloatAttr(elementType,
                       APFloat::getInf(elementType.getFloatSemantics())));

    // Case 1. `lhsReal` or `lhsImag` are infinite.
    // Box lhs to +-1/+-0 and turn NaN rhs components into +-0.
    Value lhsRealIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
    Value rhsRealIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
    Value rhsImagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
    Value zero =
        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
    Value one = b.create<arith::ConstantOp>(elementType,
                                            b.getFloatAttr(elementType, 1));
    Value lhsRealIsInfFloat =
        b.create<arith::SelectOp>(lhsRealIsInf, one, zero);
    lhsReal = b.create<arith::SelectOp>(
        lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
        lhsReal);
    Value lhsImagIsInfFloat =
        b.create<arith::SelectOp>(lhsImagIsInf, one, zero);
    lhsImag = b.create<arith::SelectOp>(
        lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
        lhsImag);
    Value lhsIsInfAndRhsRealIsNan =
        b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
    rhsReal = b.create<arith::SelectOp>(
        lhsIsInfAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
        rhsReal);
    Value lhsIsInfAndRhsImagIsNan =
        b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
    rhsImag = b.create<arith::SelectOp>(
        lhsIsInfAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
        rhsImag);

    // Case 2. `rhsReal` or `rhsImag` are infinite.
    // Symmetric to case 1: box rhs to +-1/+-0 and turn NaN lhs components into
    // +-0. The infinity tests use the original |rhs| values; case 1 only
    // replaces NaN rhs components, which were never infinite.
    Value rhsRealIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
    // NaN tests on lhs as already rewritten by case 1.
    Value lhsRealIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
    Value lhsImagIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
    Value rhsRealIsInfFloat =
        b.create<arith::SelectOp>(rhsRealIsInf, one, zero);
    rhsReal = b.create<arith::SelectOp>(
        rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
        rhsReal);
    Value rhsImagIsInfFloat =
        b.create<arith::SelectOp>(rhsImagIsInf, one, zero);
    rhsImag = b.create<arith::SelectOp>(
        rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
        rhsImag);
    Value rhsIsInfAndLhsRealIsNan =
        b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
    lhsReal = b.create<arith::SelectOp>(
        rhsIsInfAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
        lhsReal);
    Value rhsIsInfAndLhsImagIsNan =
        b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
    lhsImag = b.create<arith::SelectOp>(
        rhsIsInfAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
        lhsImag);
    Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf);

    // Case 3. One of the pairwise products of left hand side with right hand
    // side is infinite. This case only applies when neither case 1 nor case 2
    // fired (overflow with finite operands). It then replaces every NaN
    // component by +-0.
    Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
    Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
    Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
                                                 lhsImagTimesRhsImagIsInf);
    Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
    isSpecialCase =
        b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
    Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>(
        arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
    isSpecialCase =
        b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
    Type i1Type = b.getI1Type();
    // XOR with i1 true is a logical NOT.
    Value notRecalc = b.create<arith::XOrIOp>(
        recalc,
        b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
    isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc);
    Value isSpecialCaseAndLhsRealIsNan =
        b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
    lhsReal = b.create<arith::SelectOp>(
        isSpecialCaseAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
        lhsReal);
    Value isSpecialCaseAndLhsImagIsNan =
        b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
    lhsImag = b.create<arith::SelectOp>(
        isSpecialCaseAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
        lhsImag);
    Value isSpecialCaseAndRhsRealIsNan =
        b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
    rhsReal = b.create<arith::SelectOp>(
        isSpecialCaseAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
        rhsReal);
    Value isSpecialCaseAndRhsImagIsNan =
        b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
    rhsImag = b.create<arith::SelectOp>(
        isSpecialCaseAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
        rhsImag);
    recalc = b.create<arith::OrIOp>(recalc, isSpecialCase);
    // Only override the naive result when it was NaN in both parts.
    recalc = b.create<arith::AndIOp>(isNan, recalc);

    // Recalculate real part.
    lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
    lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
    Value newReal =
        b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
    real = b.create<arith::SelectOp>(
        recalc, b.create<arith::MulFOp>(inf, newReal), real);

    // Recalculate imag part.
    lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
    lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
    Value newImag =
        b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
    imag = b.create<arith::SelectOp>(
        recalc, b.create<arith::MulFOp>(inf, newImag), imag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
    return success();
  }
};

/// Lowers `complex.neg` by negating the real and imaginary parts with
/// arith.negf.
struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
  using OpConversionPattern<complex::NegOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = adaptor.getComplex().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();

    Value real =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
    Value imag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
    Value negReal = rewriter.create<arith::NegFOp>(loc, real);
    Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
    return success();
  }
};

/// Lowers `complex.sign` as z / |z|. The operand is returned unchanged when
/// both parts compare equal to zero; this avoids the 0 / 0 division.
struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
  using OpConversionPattern<complex::SignOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = adaptor.getComplex().getType().cast<ComplexType>();
    auto elementType = type.getElementType().cast<FloatType>();
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
    Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
    Value zero =
        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
    Value realIsZero =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
    Value imagIsZero =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
    Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
    // complex.abs is lowered further by AbsOpConversion.
    auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
    Value realSign = b.create<arith::DivFOp>(real, abs);
    Value imagSign = b.create<arith::DivFOp>(imag, abs);
    Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
                                                 adaptor.getComplex(), sign);
    return success();
  }
};
} // namespace

/// Adds all patterns that lower complex-dialect arithmetic ops (abs, eq/neq,
/// add, sub, div, exp, log, log1p, mul, neg, sign) to the arith and math
/// dialects.
void mlir::populateComplexToStandardConversionPatterns(
    RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AbsOpConversion,
      ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
      ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
      BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
      BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
      DivOpConversion,
      ExpOpConversion,
      LogOpConversion,
      Log1pOpConversion,
      MulOpConversion,
      NegOpConversion,
      SignOpConversion>(patterns.getContext());
  // clang-format on
}

namespace {
/// Pass wrapper that applies populateComplexToStandardConversionPatterns as
/// a partial conversion.
struct ConvertComplexToStandardPass
    : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
  void runOnOperation() override;
};

void ConvertComplexToStandardPass::runOnOperation() {
  // Lower complex ops to the arith and math dialects with the patterns
  // defined above.
  RewritePatternSet patterns(&getContext());
  populateComplexToStandardConversionPatterns(patterns);

  ConversionTarget target(getContext());
  target.addLegalDialect<arith::ArithmeticDialect, math::MathDialect>();
  // complex.create/re/im stay legal: the patterns use them to take complex
  // values apart and to rebuild their results.
  target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}
} // namespace

/// Creates the Complex -> arith/math conversion pass.
std::unique_ptr<Pass> mlir::createConvertComplexToStandardPass() {
  return std::make_unique<ConvertComplexToStandardPass>();
}