1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" 10 11 #include <memory> 12 #include <type_traits> 13 14 #include "../PassDetail.h" 15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 16 #include "mlir/Dialect/Complex/IR/Complex.h" 17 #include "mlir/Dialect/Math/IR/Math.h" 18 #include "mlir/Dialect/StandardOps/IR/Ops.h" 19 #include "mlir/IR/ImplicitLocOpBuilder.h" 20 #include "mlir/IR/PatternMatch.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 23 using namespace mlir; 24 25 namespace { 26 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> { 27 using OpConversionPattern<complex::AbsOp>::OpConversionPattern; 28 29 LogicalResult 30 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, 31 ConversionPatternRewriter &rewriter) const override { 32 auto loc = op.getLoc(); 33 auto type = op.getType(); 34 35 Value real = 36 rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex()); 37 Value imag = 38 rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex()); 39 Value realSqr = rewriter.create<arith::MulFOp>(loc, real, real); 40 Value imagSqr = rewriter.create<arith::MulFOp>(loc, imag, imag); 41 Value sqNorm = rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr); 42 43 rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm); 44 return success(); 45 } 46 }; 47 48 template <typename ComparisonOp, arith::CmpFPredicate p> 49 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> { 50 using OpConversionPattern<ComparisonOp>::OpConversionPattern; 51 using ResultCombiner = 52 std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value, 53 arith::AndIOp, arith::OrIOp>; 54 55 LogicalResult 56 matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor, 57 ConversionPatternRewriter &rewriter) const override { 58 auto loc = op.getLoc(); 59 auto type = adaptor.getLhs() 60 .getType() 61 .template cast<ComplexType>() 62 .getElementType(); 63 64 Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs()); 65 Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs()); 66 Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs()); 67 Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs()); 68 Value realComparison = 69 rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs); 70 Value imagComparison = 71 rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs); 72 73 rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison, 74 imagComparison); 75 return success(); 76 } 77 }; 78 79 // Default conversion which applies the BinaryStandardOp separately on the real 80 // and imaginary parts. Can for example be used for complex::AddOp and 81 // complex::SubOp. 82 template <typename BinaryComplexOp, typename BinaryStandardOp> 83 struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> { 84 using OpConversionPattern<BinaryComplexOp>::OpConversionPattern; 85 86 LogicalResult 87 matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor, 88 ConversionPatternRewriter &rewriter) const override { 89 auto type = adaptor.getLhs().getType().template cast<ComplexType>(); 90 auto elementType = type.getElementType().template cast<FloatType>(); 91 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 92 93 Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs()); 94 Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs()); 95 Value resultReal = 96 b.create<BinaryStandardOp>(elementType, realLhs, realRhs); 97 Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs()); 98 Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs()); 99 Value resultImag = 100 b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs); 101 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 102 resultImag); 103 return success(); 104 } 105 }; 106 107 struct DivOpConversion : public OpConversionPattern<complex::DivOp> { 108 using OpConversionPattern<complex::DivOp>::OpConversionPattern; 109 110 LogicalResult 111 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor, 112 ConversionPatternRewriter &rewriter) const override { 113 auto loc = op.getLoc(); 114 auto type = adaptor.getLhs().getType().cast<ComplexType>(); 115 auto elementType = type.getElementType().cast<FloatType>(); 116 117 Value lhsReal = 118 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs()); 119 Value lhsImag = 120 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs()); 121 Value rhsReal = 122 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs()); 123 Value rhsImag = 124 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs()); 125 126 // Smith's algorithm to divide complex numbers. It is just a bit smarter 127 // way to compute the following formula: 128 // (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i) 129 // = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) / 130 // ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i)) 131 // = ((lhsReal * rhsReal + lhsImag * rhsImag) + 132 // (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2 133 // 134 // Depending on whether |rhsReal| < |rhsImag| we compute either 135 // rhsRealImagRatio = rhsReal / rhsImag 136 // rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio 137 // resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom 138 // resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom 139 // 140 // or 141 // 142 // rhsImagRealRatio = rhsImag / rhsReal 143 // rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio 144 // resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom 145 // resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom 146 // 147 // See https://dl.acm.org/citation.cfm?id=368661 for more details. 148 Value rhsRealImagRatio = 149 rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag); 150 Value rhsRealImagDenom = rewriter.create<arith::AddFOp>( 151 loc, rhsImag, 152 rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal)); 153 Value realNumerator1 = rewriter.create<arith::AddFOp>( 154 loc, rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio), 155 lhsImag); 156 Value resultReal1 = 157 rewriter.create<arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom); 158 Value imagNumerator1 = rewriter.create<arith::SubFOp>( 159 loc, rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio), 160 lhsReal); 161 Value resultImag1 = 162 rewriter.create<arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom); 163 164 Value rhsImagRealRatio = 165 rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal); 166 Value rhsImagRealDenom = rewriter.create<arith::AddFOp>( 167 loc, rhsReal, 168 rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag)); 169 Value realNumerator2 = rewriter.create<arith::AddFOp>( 170 loc, lhsReal, 171 rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio)); 172 Value resultReal2 = 173 rewriter.create<arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom); 174 Value imagNumerator2 = rewriter.create<arith::SubFOp>( 175 loc, lhsImag, 176 rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio)); 177 Value resultImag2 = 178 rewriter.create<arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom); 179 180 // Consider corner cases. 181 // Case 1. Zero denominator, numerator contains at most one NaN value. 182 Value zero = rewriter.create<arith::ConstantOp>( 183 loc, elementType, rewriter.getZeroAttr(elementType)); 184 Value rhsRealAbs = rewriter.create<math::AbsOp>(loc, rhsReal); 185 Value rhsRealIsZero = rewriter.create<arith::CmpFOp>( 186 loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero); 187 Value rhsImagAbs = rewriter.create<math::AbsOp>(loc, rhsImag); 188 Value rhsImagIsZero = rewriter.create<arith::CmpFOp>( 189 loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero); 190 Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>( 191 loc, arith::CmpFPredicate::ORD, lhsReal, zero); 192 Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>( 193 loc, arith::CmpFPredicate::ORD, lhsImag, zero); 194 Value lhsContainsNotNaNValue = 195 rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); 196 Value resultIsInfinity = rewriter.create<arith::AndIOp>( 197 loc, lhsContainsNotNaNValue, 198 rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero)); 199 Value inf = rewriter.create<arith::ConstantOp>( 200 loc, elementType, 201 rewriter.getFloatAttr( 202 elementType, APFloat::getInf(elementType.getFloatSemantics()))); 203 Value infWithSignOfRhsReal = 204 rewriter.create<math::CopySignOp>(loc, inf, rhsReal); 205 Value infinityResultReal = 206 rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal); 207 Value infinityResultImag = 208 rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag); 209 210 // Case 2. Infinite numerator, finite denominator. 211 Value rhsRealFinite = rewriter.create<arith::CmpFOp>( 212 loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf); 213 Value rhsImagFinite = rewriter.create<arith::CmpFOp>( 214 loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf); 215 Value rhsFinite = 216 rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite); 217 Value lhsRealAbs = rewriter.create<math::AbsOp>(loc, lhsReal); 218 Value lhsRealInfinite = rewriter.create<arith::CmpFOp>( 219 loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf); 220 Value lhsImagAbs = rewriter.create<math::AbsOp>(loc, lhsImag); 221 Value lhsImagInfinite = rewriter.create<arith::CmpFOp>( 222 loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf); 223 Value lhsInfinite = 224 rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite); 225 Value infNumFiniteDenom = 226 rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite); 227 Value one = rewriter.create<arith::ConstantOp>( 228 loc, elementType, rewriter.getFloatAttr(elementType, 1)); 229 Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>( 230 loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero), 231 lhsReal); 232 Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>( 233 loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero), 234 lhsImag); 235 Value lhsRealIsInfWithSignTimesRhsReal = 236 rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal); 237 Value lhsImagIsInfWithSignTimesRhsImag = 238 rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag); 239 Value resultReal3 = rewriter.create<arith::MulFOp>( 240 loc, inf, 241 rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal, 242 lhsImagIsInfWithSignTimesRhsImag)); 243 Value lhsRealIsInfWithSignTimesRhsImag = 244 rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag); 245 Value lhsImagIsInfWithSignTimesRhsReal = 246 rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal); 247 Value resultImag3 = rewriter.create<arith::MulFOp>( 248 loc, inf, 249 rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal, 250 lhsRealIsInfWithSignTimesRhsImag)); 251 252 // Case 3: Finite numerator, infinite denominator. 253 Value lhsRealFinite = rewriter.create<arith::CmpFOp>( 254 loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf); 255 Value lhsImagFinite = rewriter.create<arith::CmpFOp>( 256 loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf); 257 Value lhsFinite = 258 rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite); 259 Value rhsRealInfinite = rewriter.create<arith::CmpFOp>( 260 loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf); 261 Value rhsImagInfinite = rewriter.create<arith::CmpFOp>( 262 loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf); 263 Value rhsInfinite = 264 rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite); 265 Value finiteNumInfiniteDenom = 266 rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite); 267 Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>( 268 loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero), 269 rhsReal); 270 Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>( 271 loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero), 272 rhsImag); 273 Value rhsRealIsInfWithSignTimesLhsReal = 274 rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign); 275 Value rhsImagIsInfWithSignTimesLhsImag = 276 rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign); 277 Value resultReal4 = rewriter.create<arith::MulFOp>( 278 loc, zero, 279 rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal, 280 rhsImagIsInfWithSignTimesLhsImag)); 281 Value rhsRealIsInfWithSignTimesLhsImag = 282 rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign); 283 Value rhsImagIsInfWithSignTimesLhsReal = 284 rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign); 285 Value resultImag4 = rewriter.create<arith::MulFOp>( 286 loc, zero, 287 rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag, 288 rhsImagIsInfWithSignTimesLhsReal)); 289 290 Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>( 291 loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); 292 Value resultReal = rewriter.create<arith::SelectOp>( 293 loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2); 294 Value resultImag = rewriter.create<arith::SelectOp>( 295 loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2); 296 Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>( 297 loc, finiteNumInfiniteDenom, resultReal4, resultReal); 298 Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>( 299 loc, finiteNumInfiniteDenom, resultImag4, resultImag); 300 Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>( 301 loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); 302 Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>( 303 loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); 304 Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>( 305 loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2); 306 Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>( 307 loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); 308 309 Value resultRealIsNaN = rewriter.create<arith::CmpFOp>( 310 loc, arith::CmpFPredicate::UNO, resultReal, zero); 311 Value resultImagIsNaN = rewriter.create<arith::CmpFOp>( 312 loc, arith::CmpFPredicate::UNO, resultImag, zero); 313 Value resultIsNaN = 314 rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN); 315 Value resultRealWithSpecialCases = rewriter.create<arith::SelectOp>( 316 loc, resultIsNaN, resultRealSpecialCase1, resultReal); 317 Value resultImagWithSpecialCases = rewriter.create<arith::SelectOp>( 318 loc, resultIsNaN, resultImagSpecialCase1, resultImag); 319 320 rewriter.replaceOpWithNewOp<complex::CreateOp>( 321 op, type, resultRealWithSpecialCases, resultImagWithSpecialCases); 322 return success(); 323 } 324 }; 325 326 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> { 327 using OpConversionPattern<complex::ExpOp>::OpConversionPattern; 328 329 LogicalResult 330 matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor, 331 ConversionPatternRewriter &rewriter) const override { 332 auto loc = op.getLoc(); 333 auto type = adaptor.getComplex().getType().cast<ComplexType>(); 334 auto elementType = type.getElementType().cast<FloatType>(); 335 336 Value real = 337 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); 338 Value imag = 339 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); 340 Value expReal = rewriter.create<math::ExpOp>(loc, real); 341 Value cosImag = rewriter.create<math::CosOp>(loc, imag); 342 Value resultReal = rewriter.create<arith::MulFOp>(loc, expReal, cosImag); 343 Value sinImag = rewriter.create<math::SinOp>(loc, imag); 344 Value resultImag = rewriter.create<arith::MulFOp>(loc, expReal, sinImag); 345 346 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 347 resultImag); 348 return success(); 349 } 350 }; 351 352 struct LogOpConversion : public OpConversionPattern<complex::LogOp> { 353 using OpConversionPattern<complex::LogOp>::OpConversionPattern; 354 355 LogicalResult 356 matchAndRewrite(complex::LogOp op, OpAdaptor adaptor, 357 ConversionPatternRewriter &rewriter) const override { 358 auto type = adaptor.getComplex().getType().cast<ComplexType>(); 359 auto elementType = type.getElementType().cast<FloatType>(); 360 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 361 362 Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex()); 363 Value resultReal = b.create<math::LogOp>(elementType, abs); 364 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); 365 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); 366 Value resultImag = b.create<math::Atan2Op>(elementType, imag, real); 367 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, 368 resultImag); 369 return success(); 370 } 371 }; 372 373 struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> { 374 using OpConversionPattern<complex::Log1pOp>::OpConversionPattern; 375 376 LogicalResult 377 matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor, 378 ConversionPatternRewriter &rewriter) const override { 379 auto type = adaptor.getComplex().getType().cast<ComplexType>(); 380 auto elementType = type.getElementType().cast<FloatType>(); 381 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 382 383 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); 384 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); 385 Value one = b.create<arith::ConstantOp>(elementType, 386 b.getFloatAttr(elementType, 1)); 387 Value realPlusOne = b.create<arith::AddFOp>(real, one); 388 Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag); 389 rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex); 390 return success(); 391 } 392 }; 393 394 struct MulOpConversion : public OpConversionPattern<complex::MulOp> { 395 using OpConversionPattern<complex::MulOp>::OpConversionPattern; 396 397 LogicalResult 398 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor, 399 ConversionPatternRewriter &rewriter) const override { 400 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 401 auto type = adaptor.getLhs().getType().cast<ComplexType>(); 402 auto elementType = type.getElementType().cast<FloatType>(); 403 404 Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs()); 405 Value lhsRealAbs = b.create<math::AbsOp>(lhsReal); 406 Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs()); 407 Value lhsImagAbs = b.create<math::AbsOp>(lhsImag); 408 Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs()); 409 Value rhsRealAbs = b.create<math::AbsOp>(rhsReal); 410 Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs()); 411 Value rhsImagAbs = b.create<math::AbsOp>(rhsImag); 412 413 Value lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal); 414 Value lhsRealTimesRhsRealAbs = b.create<math::AbsOp>(lhsRealTimesRhsReal); 415 Value lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag); 416 Value lhsImagTimesRhsImagAbs = b.create<math::AbsOp>(lhsImagTimesRhsImag); 417 Value real = 418 b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag); 419 420 Value lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal); 421 Value lhsImagTimesRhsRealAbs = b.create<math::AbsOp>(lhsImagTimesRhsReal); 422 Value lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag); 423 Value lhsRealTimesRhsImagAbs = b.create<math::AbsOp>(lhsRealTimesRhsImag); 424 Value imag = 425 b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag); 426 427 // Handle cases where the "naive" calculation results in NaN values. 428 Value realIsNan = 429 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real); 430 Value imagIsNan = 431 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag); 432 Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan); 433 434 Value inf = b.create<arith::ConstantOp>( 435 elementType, 436 b.getFloatAttr(elementType, 437 APFloat::getInf(elementType.getFloatSemantics()))); 438 439 // Case 1. `lhsReal` or `lhsImag` are infinite. 440 Value lhsRealIsInf = 441 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf); 442 Value lhsImagIsInf = 443 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf); 444 Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf); 445 Value rhsRealIsNan = 446 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal); 447 Value rhsImagIsNan = 448 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag); 449 Value zero = 450 b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); 451 Value one = b.create<arith::ConstantOp>(elementType, 452 b.getFloatAttr(elementType, 1)); 453 Value lhsRealIsInfFloat = 454 b.create<arith::SelectOp>(lhsRealIsInf, one, zero); 455 lhsReal = b.create<arith::SelectOp>( 456 lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal), 457 lhsReal); 458 Value lhsImagIsInfFloat = 459 b.create<arith::SelectOp>(lhsImagIsInf, one, zero); 460 lhsImag = b.create<arith::SelectOp>( 461 lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag), 462 lhsImag); 463 Value lhsIsInfAndRhsRealIsNan = 464 b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan); 465 rhsReal = b.create<arith::SelectOp>( 466 lhsIsInfAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal), 467 rhsReal); 468 Value lhsIsInfAndRhsImagIsNan = 469 b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan); 470 rhsImag = b.create<arith::SelectOp>( 471 lhsIsInfAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag), 472 rhsImag); 473 474 // Case 2. `rhsReal` or `rhsImag` are infinite. 475 Value rhsRealIsInf = 476 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf); 477 Value rhsImagIsInf = 478 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf); 479 Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf); 480 Value lhsRealIsNan = 481 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal); 482 Value lhsImagIsNan = 483 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag); 484 Value rhsRealIsInfFloat = 485 b.create<arith::SelectOp>(rhsRealIsInf, one, zero); 486 rhsReal = b.create<arith::SelectOp>( 487 rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal), 488 rhsReal); 489 Value rhsImagIsInfFloat = 490 b.create<arith::SelectOp>(rhsImagIsInf, one, zero); 491 rhsImag = b.create<arith::SelectOp>( 492 rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag), 493 rhsImag); 494 Value rhsIsInfAndLhsRealIsNan = 495 b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan); 496 lhsReal = b.create<arith::SelectOp>( 497 rhsIsInfAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal), 498 lhsReal); 499 Value rhsIsInfAndLhsImagIsNan = 500 b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan); 501 lhsImag = b.create<arith::SelectOp>( 502 rhsIsInfAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag), 503 lhsImag); 504 Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf); 505 506 // Case 3. One of the pairwise products of left hand side with right hand 507 // side is infinite. 508 Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>( 509 arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf); 510 Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>( 511 arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf); 512 Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf, 513 lhsImagTimesRhsImagIsInf); 514 Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>( 515 arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf); 516 isSpecialCase = 517 b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf); 518 Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>( 519 arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf); 520 isSpecialCase = 521 b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf); 522 Type i1Type = b.getI1Type(); 523 Value notRecalc = b.create<arith::XOrIOp>( 524 recalc, 525 b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1))); 526 isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc); 527 Value isSpecialCaseAndLhsRealIsNan = 528 b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan); 529 lhsReal = b.create<arith::SelectOp>( 530 isSpecialCaseAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal), 531 lhsReal); 532 Value isSpecialCaseAndLhsImagIsNan = 533 b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan); 534 lhsImag = b.create<arith::SelectOp>( 535 isSpecialCaseAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag), 536 lhsImag); 537 Value isSpecialCaseAndRhsRealIsNan = 538 b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan); 539 rhsReal = b.create<arith::SelectOp>( 540 isSpecialCaseAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal), 541 rhsReal); 542 Value isSpecialCaseAndRhsImagIsNan = 543 b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan); 544 rhsImag = b.create<arith::SelectOp>( 545 isSpecialCaseAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag), 546 rhsImag); 547 recalc = b.create<arith::OrIOp>(recalc, isSpecialCase); 548 recalc = b.create<arith::AndIOp>(isNan, recalc); 549 550 // Recalculate real part. 551 lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal); 552 lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag); 553 Value newReal = 554 b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag); 555 real = b.create<arith::SelectOp>( 556 recalc, b.create<arith::MulFOp>(inf, newReal), real); 557 558 // Recalculate imag part. 559 lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal); 560 lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag); 561 Value newImag = 562 b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag); 563 imag = b.create<arith::SelectOp>( 564 recalc, b.create<arith::MulFOp>(inf, newImag), imag); 565 566 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag); 567 return success(); 568 } 569 }; 570 571 struct NegOpConversion : public OpConversionPattern<complex::NegOp> { 572 using OpConversionPattern<complex::NegOp>::OpConversionPattern; 573 574 LogicalResult 575 matchAndRewrite(complex::NegOp op, OpAdaptor adaptor, 576 ConversionPatternRewriter &rewriter) const override { 577 auto loc = op.getLoc(); 578 auto type = adaptor.getComplex().getType().cast<ComplexType>(); 579 auto elementType = type.getElementType().cast<FloatType>(); 580 581 Value real = 582 rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); 583 Value imag = 584 rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); 585 Value negReal = rewriter.create<arith::NegFOp>(loc, real); 586 Value negImag = rewriter.create<arith::NegFOp>(loc, imag); 587 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag); 588 return success(); 589 } 590 }; 591 592 struct SignOpConversion : public OpConversionPattern<complex::SignOp> { 593 using OpConversionPattern<complex::SignOp>::OpConversionPattern; 594 595 LogicalResult 596 matchAndRewrite(complex::SignOp op, OpAdaptor adaptor, 597 ConversionPatternRewriter &rewriter) const override { 598 auto type = adaptor.getComplex().getType().cast<ComplexType>(); 599 auto elementType = type.getElementType().cast<FloatType>(); 600 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); 601 602 Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); 603 Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); 604 Value zero = 605 b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); 606 Value realIsZero = 607 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero); 608 Value imagIsZero = 609 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero); 610 Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero); 611 auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex()); 612 Value realSign = b.create<arith::DivFOp>(real, abs); 613 Value imagSign = b.create<arith::DivFOp>(imag, abs); 614 Value sign = b.create<complex::CreateOp>(type, realSign, imagSign); 615 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero, 616 adaptor.getComplex(), sign); 617 return success(); 618 } 619 }; 620 } // namespace 621 622 void mlir::populateComplexToStandardConversionPatterns( 623 RewritePatternSet &patterns) { 624 // clang-format off 625 patterns.add< 626 AbsOpConversion, 627 ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>, 628 ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>, 629 BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>, 630 BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>, 631 DivOpConversion, 632 ExpOpConversion, 633 LogOpConversion, 634 Log1pOpConversion, 635 MulOpConversion, 636 NegOpConversion, 637 SignOpConversion>(patterns.getContext()); 638 // clang-format on 639 } 640 641 namespace { 642 struct ConvertComplexToStandardPass 643 : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> { 644 void runOnOperation() override; 645 }; 646 647 void ConvertComplexToStandardPass::runOnOperation() { 648 auto function = getOperation(); 649 650 // Convert to the Standard dialect using the converter defined above. 651 RewritePatternSet patterns(&getContext()); 652 populateComplexToStandardConversionPatterns(patterns); 653 654 ConversionTarget target(getContext()); 655 target.addLegalDialect<arith::ArithmeticDialect, StandardOpsDialect, 656 math::MathDialect>(); 657 target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>(); 658 if (failed(applyPartialConversion(function, target, std::move(patterns)))) 659 signalPassFailure(); 660 } 661 } // namespace 662 663 std::unique_ptr<OperationPass<FuncOp>> 664 mlir::createConvertComplexToStandardPass() { 665 return std::make_unique<ConvertComplexToStandardPass>(); 666 } 667